From e4355aeff0b1b18d74af520b3fa032a6178bf045 Mon Sep 17 00:00:00 2001 From: YulNaumenko Date: Wed, 29 May 2024 20:42:15 -0700 Subject: [PATCH 01/23] [Security AI Assistant] Chat complete API --- .../kbn-elastic-assistant-common/constants.ts | 2 + .../chat/post_chat_complete_route.gen.ts | 124 +++++ .../chat/post_chat_complete_route.schema.yaml | 171 +++++++ .../impl/schemas/index.ts | 3 + .../language_models/chat_openai.test.ts | 3 +- .../server/language_models/chat_openai.ts | 22 +- .../server/language_models/llm.test.ts | 3 +- .../server/language_models/llm.ts | 22 +- .../language_models/simple_chat_model.test.ts | 3 +- .../language_models/simple_chat_model.ts | 21 +- .../server/lib/executor.test.ts | 20 +- .../elastic_assistant/server/lib/executor.ts | 12 +- .../execute_custom_llm_chain/index.ts | 5 +- .../executors/openai_functions_executor.ts | 5 +- .../server/lib/langchain/executors/types.ts | 5 +- .../routes/chat/chat_complete_route.test.ts | 143 ++++++ .../server/routes/chat/chat_complete_route.ts | 173 +++++++ .../server/routes/helpers.ts | 449 +++++++++++++++++- .../routes/post_actions_connector_execute.ts | 331 +++---------- .../server/routes/register_routes.ts | 8 +- 20 files changed, 1175 insertions(+), 350 deletions(-) create mode 100644 x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts create mode 100644 x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml create mode 100644 x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts create mode 100644 x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts diff --git a/x-pack/packages/kbn-elastic-assistant-common/constants.ts b/x-pack/packages/kbn-elastic-assistant-common/constants.ts index f30cb053d4ce1..8ef915fd95111 100755 --- a/x-pack/packages/kbn-elastic-assistant-common/constants.ts +++ b/x-pack/packages/kbn-elastic-assistant-common/constants.ts @@ -14,6 +14,8 @@ export const ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL = `${ELASTIC_AI_ASSISTANT_IN export const ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_BY_ID = `${ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL}/{id}`; export const ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_BY_ID_MESSAGES = `${ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_BY_ID}/messages`; +export const ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL = `${ELASTIC_AI_ASSISTANT_URL}/chat/complete`; + export const ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_BULK_ACTION = `${ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL}/_bulk_action`; export const ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL_FIND = `${ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL}/_find`; diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts new file mode 100644 index 0000000000000..ec17fb447fe5d --- /dev/null +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { z } from 'zod'; + +/* + * NOTICE: Do not edit this file manually. + * This file is automatically generated by the OpenAPI Generator, @kbn/openapi-generator. + * + * info: + * title: Chat Complete API endpoint + * version: 2023-10-31 + */ + +import { NonEmptyString } from '../common_attributes.gen'; + +export type RootContext = z.infer; +export const RootContext = z.literal('security'); + +/** + * Message role. + */ +export type ChatMessageRole = z.infer; +export const ChatMessageRole = z.enum(['system', 'user', 'assistant', 'function', 'elastic']); +export type ChatMessageRoleEnum = typeof ChatMessageRole.enum; +export const ChatMessageRoleEnum = ChatMessageRole.enum; + +/** + * Message role. + */ +export type TriggerType = z.infer; +export const TriggerType = z.enum(['user', 'assistant', 'elastic']); +export type TriggerTypeEnum = typeof TriggerType.enum; +export const TriggerTypeEnum = TriggerType.enum; + +export type TriggerArguments = z.infer; +export const TriggerArguments = z.object({}).catchall(z.unknown()); + +export type TriggerData = z.infer; +export const TriggerData = z.object({}).catchall(z.unknown()); + +/** + * AI assistant message. + */ +export type InstructionsObject = z.infer; +export const InstructionsObject = z.object({ + doc_id: z.string().optional(), + text: z.string().optional(), +}); + +/** + * AI assistant message. + */ +export type FunctionCall = z.infer; +export const FunctionCall = z.object({ + /** + * Trigger type. + */ + trigger: TriggerType, + arguments: TriggerArguments.optional(), + data: TriggerData.optional(), +}); + +export type MessageData = z.infer; +export const MessageData = z.object({}).catchall(z.unknown()); + +/** + * AI assistant message. + */ +export type ChatMessage = z.infer; +export const ChatMessage = z.object({ + /** + * Message content. + */ + content: z.string().optional(), + /** + * Message name. + */ + name: z.string().optional(), + /** + * Function definition. + */ + function_call: FunctionCall.optional(), + /** + * Message role. + */ + role: ChatMessageRole, + /** + * The timestamp message was sent or received. + */ + '@timestamp': NonEmptyString, + /** + * ECS objects array to attach to the context of the message. + */ + data: z.array(MessageData).optional(), + fields_to_anonymize: z.array(z.string()).optional(), +}); + +export type ChatCompleteProps = z.infer; +export const ChatCompleteProps = z.object({ + /** + * Solution context. + */ + context: RootContext.optional(), + conversationId: z.string().optional(), + responseLanguage: z.string().optional(), + langSmithProject: z.string().optional(), + langSmithApiKey: z.string().optional(), + disableFunctions: z.boolean().optional(), + connectorId: z.string(), + model: z.string().optional(), + title: z.string().optional(), + persist: z.boolean(), + messages: z.array(ChatMessage), + instructions: z.array(z.union([InstructionsObject, z.string()])).optional(), +}); + +export type ChatCompleteRequestBody = z.infer; +export const ChatCompleteRequestBody = ChatCompleteProps; +export type ChatCompleteRequestBodyInput = z.input; diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml new file mode 100644 index 0000000000000..3e5f7c90d6a3b --- /dev/null +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml @@ -0,0 +1,171 @@ +openapi: 3.0.0 +info: + title: Chat Complete API endpoint + version: '2023-10-31' +paths: + /api/elastic_assistant/chat/complete: + post: + operationId: ChatComplete + x-codegen-enabled: true + description: Creates a model response for the given chat conversation. + summary: Creates a model response for the given chat conversation. + tags: + - Chat Complete API + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/ChatCompleteProps' + responses: + 200: + description: Indicates a successful call. + content: + application/octet-stream: + schema: + type: string + format: binary + 400: + description: Generic Error + content: + application/json: + schema: + type: object + properties: + statusCode: + type: number + error: + type: string + message: + type: string + +components: + schemas: + RootContext: + type: string + enum: + - security + + ChatMessageRole: + type: string + description: Message role. + enum: + - system + - user + - assistant + - function + - elastic + + TriggerType: + type: string + description: Message role. + enum: + - user + - assistant + - elastic + + TriggerArguments: + type: object + additionalProperties: true + + TriggerData: + type: object + additionalProperties: true + + InstructionsObject: + type: object + description: AI assistant message. + properties: + doc_id: + type: string + text: + type: string + + FunctionCall: + type: object + description: AI assistant message. + required: + - 'trigger' + properties: + trigger: + $ref: '#/components/schemas/TriggerType' + description: Trigger type. + arguments: + $ref: '#/components/schemas/TriggerArguments' + data: + $ref: '#/components/schemas/TriggerData' + + MessageData: + type: object + additionalProperties: true + + ChatMessage: + type: object + description: AI assistant message. + required: + - '@timestamp' + - 'role' + properties: + content: + type: string + description: Message content. + name: + type: string + description: Message name. + function_call: + $ref: '#/components/schemas/FunctionCall' + description: Function definition. + role: + $ref: '#/components/schemas/ChatMessageRole' + description: Message role. + '@timestamp': + $ref: '../common_attributes.schema.yaml#/components/schemas/NonEmptyString' + description: The timestamp message was sent or received. + data: + description: ECS objects array to attach to the context of the message. + type: array + items: + $ref: '#/components/schemas/MessageData' + fields_to_anonymize: + type: array + items: + type: string + + ChatCompleteProps: + type: object + properties: + context: + $ref: '#/components/schemas/RootContext' + description: Solution context. + conversationId: + type: string + responseLanguage: + type: string + langSmithProject: + type: string + langSmithApiKey: + type: string + disableFunctions: + type: boolean + connectorId: + type: string + model: + type: string + title: + type: string + persist: + type: boolean + messages: + type: array + items: + $ref: '#/components/schemas/ChatMessage' + instructions: + type: array + items: + oneOf: + - $ref: '#/components/schemas/InstructionsObject' + - type: string + required: + - messages + - persist + - connectorId diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/index.ts b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/index.ts index c9c2d2a8be3c0..b745ba56209a1 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/index.ts +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/index.ts @@ -24,6 +24,9 @@ export * from './common_attributes.gen'; // Attack discovery Schemas export * from './attack_discovery/post_attack_discovery_route.gen'; +// Chat Schemas +export * from './chat/post_chat_complete_route.gen'; + // Evaluation Schemas export * from './evaluation/post_evaluate_route.gen'; export * from './evaluation/get_evaluate_route.gen'; diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_openai.test.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_openai.test.ts index 5c6d389a4ccc4..b73db19853982 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_openai.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_openai.test.ts @@ -60,10 +60,9 @@ const mockRequest = { } as ActionsClientChatOpenAIParams['request']; const defaultArgs = { - actions: mockActions, + actionsClient: mockActions, connectorId, logger: mockLogger, - request: mockRequest, streaming: false, signal, timeout: 999999, diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_openai.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_openai.ts index 7675e2442e598..bd42795d6c1fd 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_openai.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_openai.ts @@ -6,24 +6,24 @@ */ import { v4 as uuidv4 } from 'uuid'; -import { KibanaRequest, Logger } from '@kbn/core/server'; -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import { Logger } from '@kbn/core/server'; +import type { ActionsClient } from '@kbn/actions-plugin/server'; import { get } from 'lodash/fp'; import { ChatOpenAI } from '@langchain/openai'; import { Stream } from 'openai/streaming'; import type OpenAI from 'openai'; +import { PublicMethodsOf } from '@kbn/utility-types'; import { DEFAULT_OPEN_AI_MODEL, DEFAULT_TIMEOUT } from './constants'; import { InvokeAIActionParamsSchema, RunActionParamsSchema } from './types'; const LLM_TYPE = 'ActionsClientChatOpenAI'; export interface ActionsClientChatOpenAIParams { - actions: ActionsPluginStart; + actionsClient: PublicMethodsOf; connectorId: string; llmType?: string; logger: Logger; - request: KibanaRequest; streaming?: boolean; traceId?: string; maxRetries?: number; @@ -53,22 +53,20 @@ export class ActionsClientChatOpenAI extends ChatOpenAI { #temperature?: number; // Kibana variables - #actions: ActionsPluginStart; + #actionsClient: PublicMethodsOf; #connectorId: string; #logger: Logger; - #request: KibanaRequest; #actionResultData: string; #traceId: string; #signal?: AbortSignal; #timeout?: number; constructor({ - actions, + actionsClient, connectorId, traceId = uuidv4(), llmType, logger, - request, maxRetries, model, signal, @@ -89,12 +87,11 @@ export class ActionsClientChatOpenAI extends ChatOpenAI { azureOpenAIApiVersion: 'nothing', openAIApiKey: '', }); - this.#actions = actions; + this.#actionsClient = actionsClient; this.#connectorId = connectorId; this.#traceId = traceId; this.llmType = llmType ?? LLM_TYPE; this.#logger = logger; - this.#request = request; this.#timeout = timeout; this.#actionResultData = ''; this.streaming = streaming; @@ -143,10 +140,7 @@ export class ActionsClientChatOpenAI extends ChatOpenAI { )} ` ); - // create an actions client from the authenticated request context: - const actionsClient = await this.#actions.getActionsClientWithRequest(this.#request); - - const actionResult = await actionsClient.execute(requestBody); + const actionResult = await this.#actionsClient.execute(requestBody); if (actionResult.status === 'error') { throw new Error(`${LLM_TYPE}: ${actionResult?.message} - ${actionResult?.serviceMessage}`); diff --git a/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts b/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts index e0f7b764b625a..db04d3e5d7810 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts @@ -46,10 +46,9 @@ describe('ActionsClientLlm', () => { describe('getActionResultData', () => { it('returns the expected data', async () => { const actionsClientLlm = new ActionsClientLlm({ - actions: mockActions, + actionsClient: mockActions, connectorId, logger: mockLogger, - request: mockRequest, }); const result = await actionsClientLlm._call(prompt); // ignore the result diff --git a/x-pack/packages/kbn-langchain/server/language_models/llm.ts b/x-pack/packages/kbn-langchain/server/language_models/llm.ts index 9708a8d3a5d7f..bad538821ff1d 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/llm.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/llm.ts @@ -5,11 +5,12 @@ * 2.0. */ -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; -import { KibanaRequest, Logger } from '@kbn/core/server'; +import type { ActionsClient } from '@kbn/actions-plugin/server'; +import { Logger } from '@kbn/core/server'; import { LLM } from '@langchain/core/language_models/llms'; import { get } from 'lodash/fp'; import { v4 as uuidv4 } from 'uuid'; +import { PublicMethodsOf } from '@kbn/utility-types'; import { DEFAULT_TIMEOUT, getDefaultArguments } from './constants'; import { getMessageContentAndRole } from './helpers'; @@ -18,11 +19,10 @@ import { TraceOptions } from './types'; const LLM_TYPE = 'ActionsClientLlm'; interface ActionsClientLlmParams { - actions: ActionsPluginStart; + actionsClient: PublicMethodsOf; connectorId: string; llmType?: string; logger: Logger; - request: KibanaRequest; model?: string; temperature?: number; timeout?: number; @@ -31,10 +31,9 @@ interface ActionsClientLlmParams { } export class ActionsClientLlm extends LLM { - #actions: ActionsPluginStart; + #actionsClient: PublicMethodsOf; #connectorId: string; #logger: Logger; - #request: KibanaRequest; #traceId: string; #timeout?: number; @@ -46,13 +45,12 @@ export class ActionsClientLlm extends LLM { temperature?: number; constructor({ - actions, + actionsClient, connectorId, traceId = uuidv4(), llmType, logger, model, - request, temperature, timeout, traceOptions, @@ -61,12 +59,11 @@ export class ActionsClientLlm extends LLM { callbacks: [...(traceOptions?.tracers ?? [])], }); - this.#actions = actions; + this.#actionsClient = actionsClient; this.#connectorId = connectorId; this.#traceId = traceId; this.llmType = llmType ?? LLM_TYPE; this.#logger = logger; - this.#request = request; this.#timeout = timeout; this.model = model; this.temperature = temperature; @@ -107,10 +104,7 @@ export class ActionsClientLlm extends LLM { }, }; - // create an actions client from the authenticated request context: - const actionsClient = await this.#actions.getActionsClientWithRequest(this.#request); - - const actionResult = await actionsClient.execute(requestBody); + const actionResult = await this.#actionsClient.execute(requestBody); if (actionResult.status === 'error') { throw new Error( diff --git a/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.test.ts b/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.test.ts index 6a11466f9faa0..4592c8183098c 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.test.ts @@ -87,10 +87,9 @@ const mockRequest: CustomChatModelInput['request'] = { } as CustomChatModelInput['request']; const defaultArgs = { - actions: mockActions, + actionsClient: mockActions, connectorId, logger: mockLogger, - request: mockRequest, streaming: false, }; jest.mock('../utils/bedrock'); diff --git a/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.ts b/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.ts index f13b0a53611ef..9ee0fd0d5cc9a 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.ts @@ -11,12 +11,12 @@ import { type BaseChatModelParams, } from '@langchain/core/language_models/chat_models'; import { type BaseMessage } from '@langchain/core/messages'; -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import type { ActionsClient } from '@kbn/actions-plugin/server'; import { Logger } from '@kbn/logging'; -import { KibanaRequest } from '@kbn/core-http-server'; import { v4 as uuidv4 } from 'uuid'; import { get } from 'lodash/fp'; import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'; +import { PublicMethodsOf } from '@kbn/utility-types'; import { parseBedrockStream } from '../utils/bedrock'; import { getDefaultArguments } from './constants'; @@ -26,22 +26,20 @@ export const getMessageContentAndRole = (prompt: string, role = 'user') => ({ }); export interface CustomChatModelInput extends BaseChatModelParams { - actions: ActionsPluginStart; + actionsClient: PublicMethodsOf; connectorId: string; logger: Logger; llmType?: string; signal?: AbortSignal; model?: string; temperature?: number; - request: KibanaRequest; streaming: boolean; } export class ActionsClientSimpleChatModel extends SimpleChatModel { - #actions: ActionsPluginStart; + #actionsClient: PublicMethodsOf; #connectorId: string; #logger: Logger; - #request: KibanaRequest; #traceId: string; #signal?: AbortSignal; llmType: string; @@ -50,24 +48,22 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel { temperature?: number; constructor({ - actions, + actionsClient, connectorId, llmType, logger, model, - request, temperature, signal, streaming, }: CustomChatModelInput) { super({}); - this.#actions = actions; + this.#actionsClient = actionsClient; this.#connectorId = connectorId; this.#traceId = uuidv4(); this.#logger = logger; this.#signal = signal; - this.#request = request; this.llmType = llmType ?? 'ActionsClientSimpleChatModel'; this.model = model; this.temperature = temperature; @@ -126,10 +122,7 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel { }, }; - // create an actions client from the authenticated request context: - const actionsClient = await this.#actions.getActionsClientWithRequest(this.#request); - - const actionResult = await actionsClient.execute(requestBody); + const actionResult = await this.#actionsClient.execute(requestBody); if (actionResult.status === 'error') { throw new Error( diff --git a/x-pack/plugins/elastic_assistant/server/lib/executor.test.ts b/x-pack/plugins/elastic_assistant/server/lib/executor.test.ts index bacdd6cac1b49..4e9cb7ffaec67 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/executor.test.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/executor.test.ts @@ -15,6 +15,7 @@ import { executeAction, Props } from './executor'; import { PassThrough } from 'stream'; import { KibanaRequest } from '@kbn/core-http-server'; import { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; import { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common'; import { loggerMock } from '@kbn/logging-mocks'; import * as ParseStream from './parse_stream'; @@ -33,8 +34,8 @@ const testProps: Omit = { subActionParams: { messages: [{ content: 'hello', role: 'user' }] }, }, actionTypeId: '.bedrock', - request, connectorId, + actionsClient: actionsClientMock.create(), onLlmResponse, logger: mockLogger, }; @@ -46,7 +47,7 @@ describe('executeAction', () => { jest.clearAllMocks(); }); it('should execute an action and return a StaticResponse when the response from the actions framework is a string', async () => { - const actions = { + /* const actions = { getActionsClientWithRequest: jest.fn().mockResolvedValue({ execute: jest.fn().mockResolvedValue({ data: { @@ -54,9 +55,14 @@ describe('executeAction', () => { }, }), }), - } as unknown as Props['actions']; + } as unknown as Props['actions'];*/ + testProps.actionsClient.execute = jest.fn().mockResolvedValue({ + data: { + message: 'Test message', + }, + }); - const result = await executeAction({ ...testProps, actions }); + const result = await executeAction({ ...testProps }); expect(result).toEqual({ connector_id: connectorId, @@ -68,15 +74,15 @@ describe('executeAction', () => { it('should execute an action and return a Readable object when the response from the actions framework is a stream', async () => { const readableStream = new PassThrough(); - const actions = { + const actionsClient = { getActionsClientWithRequest: jest.fn().mockResolvedValue({ execute: jest.fn().mockResolvedValue({ data: readableStream, }), }), - } as unknown as Props['actions']; + } as unknown as Props['actionsClient']; - const result = await executeAction({ ...testProps, actions }); + const result = await executeAction({ ...testProps, actionsClient }); expect(JSON.stringify(result)).toStrictEqual( JSON.stringify(readableStream.pipe(new PassThrough())) diff --git a/x-pack/plugins/elastic_assistant/server/lib/executor.ts b/x-pack/plugins/elastic_assistant/server/lib/executor.ts index e7797805854ec..bd25a77808dbe 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/executor.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/executor.ts @@ -6,20 +6,18 @@ */ import { get } from 'lodash/fp'; -import { KibanaRequest } from '@kbn/core-http-server'; -import { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import { ActionsClient } from '@kbn/actions-plugin/server'; import { PassThrough, Readable } from 'stream'; -import { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common'; import { Logger } from '@kbn/core/server'; +import { PublicMethodsOf } from '@kbn/utility-types'; import { handleStreamStorage } from './parse_stream'; export interface Props { onLlmResponse?: (content: string) => Promise; abortSignal?: AbortSignal; - actions: ActionsPluginStart; + actionsClient: PublicMethodsOf; connectorId: string; params: InvokeAIActionsParams; - request: KibanaRequest; actionTypeId: string; logger: Logger; } @@ -43,15 +41,13 @@ interface InvokeAIActionsParams { export const executeAction = async ({ onLlmResponse, - actions, + actionsClient, params, connectorId, actionTypeId, - request, logger, abortSignal, }: Props): Promise => { - const actionsClient = await actions.getActionsClientWithRequest(request); const actionResult = await actionsClient.execute({ actionId: connectorId, params: { diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts index 8323712c50aa7..2b7625b255ecd 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts @@ -30,7 +30,7 @@ export const DEFAULT_AGENT_EXECUTOR_ID = 'Elastic AI Assistant Agent Executor'; */ export const callAgentExecutor: AgentExecutor = async ({ abortSignal, - actions, + actionsClient, alertsIndexPattern, anonymizationFields, isEnabledKnowledgeBase, @@ -53,9 +53,8 @@ export const callAgentExecutor: AgentExecutor = async ({ const llmClass = isOpenAI ? ActionsClientChatOpenAI : ActionsClientSimpleChatModel; const llm = new llmClass({ - actions, + actionsClient, connectorId, - request, llmType, logger, // possible client model override, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/openai_functions_executor.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/openai_functions_executor.ts index 62b1d7ac7814f..6aa1aa3ce7890 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/openai_functions_executor.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/openai_functions_executor.ts @@ -25,7 +25,7 @@ export const OPEN_AI_FUNCTIONS_AGENT_EXECUTOR_ID = * NOTE: This is not to be used in production as-is, and must be used with an OpenAI ConnectorId */ export const callOpenAIFunctionsExecutor: AgentExecutor = async ({ - actions, + actionsClient, connectorId, esClient, esStore, @@ -36,9 +36,8 @@ export const callOpenAIFunctionsExecutor: AgentExecutor = async ({ traceOptions, }) => { const llm = new ActionsClientLlm({ - actions, + actionsClient, connectorId, - request, llmType, logger, model: request.body.model, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts index 8acd7f4fcdde2..797945a453242 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts @@ -5,7 +5,7 @@ * 2.0. */ -import { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import { ActionsClient } from '@kbn/actions-plugin/server'; import { ElasticsearchClient } from '@kbn/core-elasticsearch-server'; import { BaseMessage } from '@langchain/core/messages'; import { Logger } from '@kbn/logging'; @@ -14,6 +14,7 @@ import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain'; import { ExecuteConnectorRequestBody, Message, Replacements } from '@kbn/elastic-assistant-common'; import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; import { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas/anonymization_fields/bulk_crud_anonymization_fields_route.gen'; +import { PublicMethodsOf } from '@kbn/utility-types'; import { ResponseBody } from '../types'; import type { AssistantTool } from '../../../types'; import { ElasticsearchStore } from '../elasticsearch_store/elasticsearch_store'; @@ -21,7 +22,7 @@ import { ElasticsearchStore } from '../elasticsearch_store/elasticsearch_store'; export interface AgentExecutorParams { abortSignal?: AbortSignal; alertsIndexPattern?: string; - actions: ActionsPluginStart; + actionsClient: PublicMethodsOf; anonymizationFields?: AnonymizationFieldResponse[]; isEnabledKnowledgeBase: boolean; assistantTools?: AssistantTool[]; diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts new file mode 100644 index 0000000000000..5b2737436fa8c --- /dev/null +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts @@ -0,0 +1,143 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks'; +import { requestContextMock } from '../../__mocks__/request_context'; +import { serverMock } from '../../__mocks__/server'; +import { createConversationRoute } from './create_route'; +import { getBasicEmptySearchResponse, getEmptyFindResult } from '../../__mocks__/response'; +import { getCreateConversationRequest, requestMock } from '../../__mocks__/request'; +import { + getCreateConversationSchemaMock, + getConversationMock, + getQueryConversationParams, +} from '../../__mocks__/conversations_schema.mock'; +import { ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL } from '@kbn/elastic-assistant-common'; +import { AuthenticatedUser } from '@kbn/security-plugin-types-common'; + +describe('Create conversation route', () => { + let server: ReturnType; + let { clients, context } = requestContextMock.createTools(); + const mockUser1 = { + username: 'my_username', + authentication_realm: { + type: 'my_realm_type', + name: 'my_realm_name', + }, + } as AuthenticatedUser; + + beforeEach(() => { + server = serverMock.create(); + ({ clients, context } = requestContextMock.createTools()); + + clients.elasticAssistant.getAIAssistantConversationsDataClient.findDocuments.mockResolvedValue( + Promise.resolve(getEmptyFindResult()) + ); // no current conversations + clients.elasticAssistant.getAIAssistantConversationsDataClient.createConversation.mockResolvedValue( + getConversationMock(getQueryConversationParams()) + ); // creation succeeds + + context.core.elasticsearch.client.asCurrentUser.search.mockResolvedValue( + elasticsearchClientMock.createSuccessTransportRequestPromise(getBasicEmptySearchResponse()) + ); + context.elasticAssistant.getCurrentUser.mockReturnValue(mockUser1); + createConversationRoute(server.router); + }); + + describe('status codes', () => { + test('returns 200 with a conversation created via AIAssistantConversationsDataClient', async () => { + const response = await server.inject( + getCreateConversationRequest(), + requestContextMock.convertContext(context) + ); + expect(response.status).toEqual(200); + }); + + test('returns 401 Unauthorized when request context getCurrentUser is not defined', async () => { + context.elasticAssistant.getCurrentUser.mockReturnValueOnce(null); + const response = await server.inject( + getCreateConversationRequest(), + requestContextMock.convertContext(context) + ); + expect(response.status).toEqual(401); + }); + }); + + describe('unhappy paths', () => { + test('catches error if creation throws', async () => { + clients.elasticAssistant.getAIAssistantConversationsDataClient.createConversation.mockImplementation( + async () => { + throw new Error('Test error'); + } + ); + const response = await server.inject( + getCreateConversationRequest(), + requestContextMock.convertContext(context) + ); + expect(response.status).toEqual(500); + expect(response.body).toEqual({ + message: 'Test error', + status_code: 500, + }); + }); + }); + + describe('request validation', () => { + test('disallows unknown title', async () => { + const request = requestMock.create({ + method: 'post', + path: ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL, + body: { + ...getCreateConversationSchemaMock(), + title: true, + }, + }); + const result = server.validate(request); + + expect(result.badRequest).toHaveBeenCalled(); + }); + }); + describe('conversation containing messages', () => { + const getMessage = (role: string = 'user') => ({ + role, + content: 'test content', + timestamp: '2019-12-13T16:40:33.400Z', + }); + const defaultMessage = getMessage(); + + test('is successful', async () => { + const request = requestMock.create({ + method: 'post', + path: ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL, + body: { + ...getCreateConversationSchemaMock(), + messages: [defaultMessage], + }, + }); + + const response = await server.inject(request, requestContextMock.convertContext(context)); + expect(response.status).toEqual(200); + }); + + test('fails when provided with an unsupported message role', async () => { + const wrongMessage = getMessage('test_thing'); + + const request = requestMock.create({ + method: 'post', + path: ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL, + body: { + ...getCreateConversationSchemaMock(), + messages: [wrongMessage], + }, + }); + const result = await server.validate(request); + expect(result.badRequest).toHaveBeenCalledWith( + `messages.0.role: Invalid enum value. Expected 'system' | 'user' | 'assistant', received 'test_thing'` + ); + }); + }); +}); diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts new file mode 100644 index 0000000000000..be110c82261e1 --- /dev/null +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -0,0 +1,173 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { transformError } from '@kbn/securitysolution-es-utils'; +import { Logger } from '@kbn/core/server'; +import { + ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL, + ChatCompleteProps, + API_VERSIONS, + Message, + Replacements, +} from '@kbn/elastic-assistant-common'; +import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; +import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; +import { ElasticAssistantPluginRouter, GetElser } from '../../types'; +import { buildResponse } from '../utils'; +import { + UPGRADE_LICENSE_MESSAGE, + appendAssistantMessageToConversation, + hasAIAssistantLicense, + langChainExecute, + updateConversationWithUserInput, +} from '../helpers'; + +export const chatCompleteRoute = ( + router: ElasticAssistantPluginRouter, + getElser: GetElser +): void => { + router.versioned + .post({ + access: 'public', + path: ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL, + + options: { + tags: ['access:elasticAssistant'], + }, + }) + .addVersion( + { + version: API_VERSIONS.public.v1, + validate: { + request: { + body: buildRouteValidationWithZod(ChatCompleteProps), + }, + }, + }, + async (context, request, response) => { + const abortSignal = getRequestAbortedSignal(request.events.aborted$); + const assistantResponse = buildResponse(response); + try { + const ctx = await context.resolve(['core', 'elasticAssistant', 'licensing']); + const logger: Logger = ctx.elasticAssistant.logger; + const telemetry = ctx.elasticAssistant.telemetry; + let onLlmResponse; + + const license = ctx.licensing.license; + if (!hasAIAssistantLicense(license)) { + return response.forbidden({ + body: { + message: UPGRADE_LICENSE_MESSAGE, + }, + }); + } + const authenticatedUser = ctx.elasticAssistant.getCurrentUser(); + if (authenticatedUser == null) { + return assistantResponse.error({ + body: `Authenticated user not found`, + statusCode: 401, + }); + } + + const conversationsDataClient = + await ctx.elasticAssistant.getAIAssistantConversationsDataClient(); + + const anonymizationFieldsDataClient = + await ctx.elasticAssistant.getAIAssistantAnonymizationFieldsDataClient(); + + let messages; + const conversationId = request.body.conversationId; + const connectorId = request.body.connectorId; + + let latestReplacements: Replacements = {}; + const onNewReplacements = (newReplacements: Replacements) => { + latestReplacements = { ...latestReplacements, ...newReplacements }; + }; + + // get the actions plugin start contract from the request context: + const actions = (await context.elasticAssistant).actions; + const actionsClient = await actions.getActionsClientWithRequest(request); + const actionTypeId = + (await actionsClient.getAllSystemConnectors()).find((c) => c.id === connectorId) + ?.actionTypeId ?? '.gen-ai'; + + // replacements + + if (request.body.persist && conversationId && conversationsDataClient) { + const updatedConversation = await updateConversationWithUserInput({ + actionsClient, + actionTypeId, + authenticatedUser, + connectorId, + conversationId, + conversationsDataClient, + logger, + replacements: latestReplacements, + newMessages: request.body.messages + .filter((f) => f.role === 'assistant' || f.role === 'user') + .map((m) => ({ + role: m.role, + content: m.content ?? '', + })) as Message[], + model: request.body.model, + }); + if (updatedConversation == null) { + return response.badRequest({ + body: `conversation id: "${conversationId}" not updated`, + }); + } + // messages are anonymized by conversationsDataClient + messages = updatedConversation?.messages?.map((c) => ({ + role: c.role, + content: c.content, + })); + + onLlmResponse = async ( + content: string, + traceData: Message['traceData'] = {}, + isError = false + ): Promise => { + if (updatedConversation && conversationsDataClient) { + await appendAssistantMessageToConversation({ + conversation: updatedConversation, + conversationsDataClient, + messageContent: content, + replacements: latestReplacements, + isError, + traceData, + }); + } + }; + } + + return await langChainExecute({ + abortSignal, + actionsClient, + actionTypeId, + assistantContext: ctx.elasticAssistant, + connectorId, + context, + getElser, + logger, + messages: messages ?? [], + onLlmResponse, + onNewReplacements, + replacements: latestReplacements, + request, + response, + telemetry, + }); + } catch (err) { + const error = transformError(err as Error); + return assistantResponse.error({ + body: error.message, + statusCode: error.statusCode, + }); + } + } + ); +}; diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index 932eafb4a549b..a18a796065e35 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -5,11 +5,46 @@ * 2.0. */ -import { KibanaRequest } from '@kbn/core-http-server'; -import { Logger } from '@kbn/core/server'; -import { Message, TraceData } from '@kbn/elastic-assistant-common'; +import { + AnalyticsServiceSetup, + AuthenticatedUser, + KibanaRequest, + KibanaResponseFactory, + Logger, +} from '@kbn/core/server'; +import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; + +import { + TraceData, + ConversationResponse, + ExecuteConnectorRequestBody, + Message, + Replacements, + replaceAnonymizedValuesWithOriginalValues, +} from '@kbn/elastic-assistant-common'; import { ILicense } from '@kbn/licensing-plugin/server'; +import { i18n } from '@kbn/i18n'; +import { PublicMethodsOf } from '@kbn/utility-types'; +import { ActionsClient } from '@kbn/actions-plugin/server'; import { MINIMUM_AI_ASSISTANT_LICENSE } from '../../common/constants'; +import { ESQL_RESOURCE, KNOWLEDGE_BASE_INDEX_PATTERN } from './knowledge_base/constants'; +import { callAgentExecutor } from '../lib/langchain/execute_custom_llm_chain'; +import { getLlmType } from './utils'; +import { StaticReturnType } from '../lib/langchain/executors/types'; +import { executeAction, StaticResponse } from '../lib/executor'; +import { getLangChainMessages } from '../lib/langchain/helpers'; + +import { getLangSmithTracer } from './evaluate/utils'; +import { EsAnonymizationFieldsSchema } from '../ai_assistant_data_clients/anonymization_fields/types'; +import { transformESSearchToAnonymizationFields } from '../ai_assistant_data_clients/anonymization_fields/helpers'; +import { ElasticsearchStore } from '../lib/langchain/elasticsearch_store/elasticsearch_store'; +import { AIAssistantConversationsDataClient } from '../ai_assistant_data_clients/conversations'; +import { INVOKE_ASSISTANT_SUCCESS_EVENT } from '../lib/telemetry/event_based_telemetry'; +import { + ElasticAssistantApiRequestHandlerContext, + ElasticAssistantRequestHandlerContext, + GetElser, +} from '../types'; interface GetPluginNameFromRequestParams { request: KibanaRequest; @@ -87,3 +122,411 @@ export const hasAIAssistantLicense = (license: ILicense): boolean => export const UPGRADE_LICENSE_MESSAGE = 'Your license does not support AI Assistant. Please upgrade your license.'; + +export interface GenerateTitleForNewChatConversationParams { + conversationsDataClient: AIAssistantConversationsDataClient; + conversationId: string; + message: Pick; + model?: string; + actionTypeId: string; + connectorId: string; + logger: Logger; + actionsClient: PublicMethodsOf; +} +export const generateTitleForNewChatConversation = async ({ + conversationsDataClient, + conversationId, + message, + model, + actionTypeId, + connectorId, + logger, + actionsClient, +}: GenerateTitleForNewChatConversationParams) => { + try { + const autoTitle = (await executeAction({ + actionsClient, + connectorId, + actionTypeId, + params: { + subAction: 'invokeAI', + subActionParams: { + model, + messages: [ + { + role: 'assistant', + content: i18n.translate( + 'xpack.elasticAssistantPlugin.server.autoTitlePromptDescription', + { + defaultMessage: + 'You are a helpful assistant for Elastic Security. Assume the following message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you.', + } + ), + }, + message, + ], + ...(actionTypeId === '.gen-ai' + ? { n: 1, stop: null, temperature: 0.2 } + : { temperature: 0, stopSequences: [] }), + }, + }, + logger, + })) as unknown as StaticResponse; // TODO: Use function overloads in executeAction to avoid this cast when sending subAction: 'invokeAI', + if (autoTitle.status === 'ok') { + try { + // This regular expression captures a string enclosed in single or double quotes. + // It extracts the string content without the quotes. + // Example matches: + // - "Hello, World!" => Captures: Hello, World! + // - 'Another Example' => Captures: Another Example + // - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes + const match = autoTitle.data.match(/^["']?([^"']+)["']?$/); + const title = match ? match[1] : autoTitle.data; + + return await conversationsDataClient.updateConversation({ + conversationUpdateProps: { + id: conversationId, + title, + }, + }); + } catch (e) { + logger.warn(`Failed to update conversation with generated title: ${e.message}`); + } + } + } catch (e) { + /* empty */ + } +}; + +export interface AppendMessageToConversationParams { + conversationsDataClient: AIAssistantConversationsDataClient; + messages: Array>; + replacements: Replacements; + conversation: ConversationResponse; +} +export const appendMessageToConversation = async ({ + conversationsDataClient, + messages, + replacements, + conversation, +}: AppendMessageToConversationParams) => { + const updatedConversation = await conversationsDataClient?.appendConversationMessages({ + existingConversation: conversation, + messages: messages.map((m) => ({ + ...{ + content: replaceAnonymizedValuesWithOriginalValues({ + messageContent: m.content, + replacements, + }), + role: m.role ?? 'user', + }, + timestamp: m.timestamp ?? new Date().toISOString(), + })), + }); + return updatedConversation; +}; + +export interface AppendAssistantMessageToConversationParams { + conversationsDataClient: AIAssistantConversationsDataClient; + messageContent: string; + replacements: Replacements; + conversation: ConversationResponse; + isError?: boolean; + traceData?: Message['traceData']; +} +export const appendAssistantMessageToConversation = async ({ + conversationsDataClient, + messageContent, + replacements, + conversation, + isError = false, + traceData = {}, +}: AppendAssistantMessageToConversationParams) => { + await conversationsDataClient?.appendConversationMessages({ + existingConversation: conversation, + messages: [ + getMessageFromRawResponse({ + rawContent: replaceAnonymizedValuesWithOriginalValues({ + messageContent, + replacements, + }), + traceData, + isError, + }), + ], + }); + if (Object.keys(replacements).length > 0) { + await conversationsDataClient?.updateConversation({ + conversationUpdateProps: { + id: conversation.id, + replacements, + }, + }); + } +}; + +export interface NonLangChainExecuteParams { + request: KibanaRequest; + messages: Array>; + abortSignal: AbortSignal; + actionTypeId: string; + connectorId: string; + logger: Logger; + actionsClient: PublicMethodsOf; + onLlmResponse?: ( + content: string, + traceData?: Message['traceData'], + isError?: boolean + ) => Promise; + response: KibanaResponseFactory; + telemetry: AnalyticsServiceSetup; +} +export const nonLangChainExecute = async ({ + messages, + abortSignal, + actionTypeId, + connectorId, + logger, + actionsClient, + onLlmResponse, + response, + request, + telemetry, +}: NonLangChainExecuteParams) => { + logger.debug('Executing via actions framework directly'); + const result = await executeAction({ + abortSignal, + onLlmResponse, + actionsClient, + connectorId, + actionTypeId, + params: { + subAction: request.body.subAction, + subActionParams: { + model: request.body.model, + messages, + ...(actionTypeId === '.gen-ai' + ? { n: 1, stop: null, temperature: 0.2 } + : { temperature: 0, stopSequences: [] }), + }, + }, + logger, + }); + + telemetry.reportEvent(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { + actionTypeId, + isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase, + isEnabledRAGAlerts: request.body.isEnabledRAGAlerts, + model: request.body.model, + assistantStreamingEnabled: request.body.subAction !== 'invokeAI', + }); + return response.ok({ + body: result, + }); +}; + +export interface LangChainExecuteParams { + messages: Array>; + replacements: Replacements; + onNewReplacements: (newReplacements: Replacements) => void; + abortSignal: AbortSignal; + telemetry: AnalyticsServiceSetup; + actionTypeId: string; + connectorId: string; + assistantContext: ElasticAssistantApiRequestHandlerContext; + context: ElasticAssistantRequestHandlerContext; + actionsClient: PublicMethodsOf; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + request: KibanaRequest; + logger: Logger; + onLlmResponse?: ( + content: string, + traceData?: Message['traceData'], + isError?: boolean + ) => Promise; + getElser: GetElser; + response: KibanaResponseFactory; +} +export const langChainExecute = async ({ + messages, + replacements, + onNewReplacements, + abortSignal, + telemetry, + actionTypeId, + connectorId, + assistantContext, + context, + actionsClient, + request, + logger, + onLlmResponse, + getElser, + response, +}: LangChainExecuteParams) => { + // TODO: Add `traceId` to actions request when calling via langchain + logger.debug( + `Executing via langchain, isEnabledKnowledgeBase: ${request.body.isEnabledKnowledgeBase}, isEnabledRAGAlerts: ${request.body.isEnabledRAGAlerts}` + ); + // Fetch any tools registered by the request's originating plugin + const pluginName = getPluginNameFromRequest({ + request, + defaultPluginName: DEFAULT_PLUGIN_NAME, + logger, + }); + const assistantTools = assistantContext + .getRegisteredTools(pluginName) + .filter((x) => x.id !== 'attack-discovery'); // We don't (yet) support asking the assistant for NEW attack discoveries from a conversation + + // get a scoped esClient for assistant memory + const esClient = (await context.core).elasticsearch.client.asCurrentUser; + + // convert the assistant messages to LangChain messages: + const langChainMessages = getLangChainMessages(messages); + + const elserId = await getElser(); + + const anonymizationFieldsDataClient = + await assistantContext.getAIAssistantAnonymizationFieldsDataClient(); + const anonymizationFieldsRes = + await anonymizationFieldsDataClient?.findDocuments({ + perPage: 1000, + page: 1, + }); + + // Create an ElasticsearchStore for KB interactions + // Setup with kbDataClient if `enableKnowledgeBaseByDefault` FF is enabled + const enableKnowledgeBaseByDefault = + assistantContext.getRegisteredFeatures(pluginName).assistantKnowledgeBaseByDefault; + const kbDataClient = enableKnowledgeBaseByDefault + ? (await assistantContext.getAIAssistantKnowledgeBaseDataClient(false)) ?? undefined + : undefined; + const kbIndex = + enableKnowledgeBaseByDefault && kbDataClient != null + ? kbDataClient.indexTemplateAndPattern.alias + : KNOWLEDGE_BASE_INDEX_PATTERN; + const esStore = new ElasticsearchStore( + esClient, + kbIndex, + logger, + telemetry, + elserId, + ESQL_RESOURCE, + kbDataClient + ); + + const result: StreamResponseWithHeaders | StaticReturnType = await callAgentExecutor({ + abortSignal, + alertsIndexPattern: request.body.alertsIndexPattern, + anonymizationFields: anonymizationFieldsRes + ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) + : undefined, + actionsClient, + isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase ?? false, + assistantTools, + connectorId, + esClient, + esStore, + isStream: request.body.subAction !== 'invokeAI', + llmType: getLlmType(actionTypeId), + langChainMessages, + logger, + onNewReplacements, + onLlmResponse, + request, + replacements, + size: request.body.size, + traceOptions: { + projectName: request.body.langSmithProject, + tracers: getLangSmithTracer({ + apiKey: request.body.langSmithApiKey, + projectName: request.body.langSmithProject, + logger, + }), + }, + }); + + telemetry.reportEvent(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { + actionTypeId, + isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase, + isEnabledRAGAlerts: request.body.isEnabledRAGAlerts, + model: request.body.model, + // TODO rm actionTypeId check when llmClass for bedrock streaming is implemented + // tracked here: https://github.com/elastic/security-team/issues/7363 + assistantStreamingEnabled: request.body.subAction !== 'invokeAI' && actionTypeId === '.gen-ai', + }); + + return response.ok(result); +}; + +export interface UpdateConversationWithParams { + logger: Logger; + conversationsDataClient: AIAssistantConversationsDataClient; + replacements: Replacements; + conversationId: string; + actionTypeId: string; + connectorId: string; + actionsClient: PublicMethodsOf; + newMessages?: Array>; + model?: string; + authenticatedUser: AuthenticatedUser; +} +export const updateConversationWithUserInput = async ({ + logger, + conversationsDataClient, + replacements, + conversationId, + actionTypeId, + connectorId, + actionsClient, + newMessages, + model, + authenticatedUser, +}: UpdateConversationWithParams) => { + const conversation = await conversationsDataClient?.getConversation({ + id: conversationId, + authenticatedUser, + }); + if (conversation == null) { + throw new Error(`conversation id: "${conversationId}" not found`); + } + let updatedConversation = conversation; + + const NEW_CHAT = i18n.translate('xpack.elasticAssistantPlugin.server.newChat', { + defaultMessage: 'New chat', + }); + + const messages = updatedConversation?.messages?.map((c) => ({ + role: c.role, + content: c.content, + timestamp: c.timestamp, + })); + + const lastMessage = newMessages?.[0] ?? messages?.[0]; + + if (conversation?.title === NEW_CHAT && lastMessage) { + const res = await generateTitleForNewChatConversation({ + message: lastMessage, + actionsClient, + actionTypeId, + connectorId, + conversationId, + conversationsDataClient, + logger, + model, + }); + if (res) { + updatedConversation = res; + } + } + + if (newMessages) { + return appendMessageToConversation({ + conversation: updatedConversation, + conversationsDataClient, + messages: newMessages, + replacements, + }); + } + return updatedConversation; +}; diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 2d53106bacf13..446b9b8c73d9b 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -8,7 +8,6 @@ import { IRouter, Logger } from '@kbn/core/server'; import { transformError } from '@kbn/securitysolution-es-utils'; import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; -import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; import { schema } from '@kbn/config-schema'; import { @@ -16,32 +15,18 @@ import { ExecuteConnectorRequestBody, Message, Replacements, - replaceAnonymizedValuesWithOriginalValues, } from '@kbn/elastic-assistant-common'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; -import { i18n } from '@kbn/i18n'; -import { getLlmType } from './utils'; -import { StaticReturnType } from '../lib/langchain/executors/types'; -import { - INVOKE_ASSISTANT_ERROR_EVENT, - INVOKE_ASSISTANT_SUCCESS_EVENT, -} from '../lib/telemetry/event_based_telemetry'; -import { executeAction, StaticResponse } from '../lib/executor'; +import { INVOKE_ASSISTANT_ERROR_EVENT } from '../lib/telemetry/event_based_telemetry'; import { POST_ACTIONS_CONNECTOR_EXECUTE } from '../../common/constants'; -import { getLangChainMessages } from '../lib/langchain/helpers'; import { buildResponse } from '../lib/build_response'; import { ElasticAssistantRequestHandlerContext, GetElser } from '../types'; -import { ESQL_RESOURCE, KNOWLEDGE_BASE_INDEX_PATTERN } from './knowledge_base/constants'; -import { callAgentExecutor } from '../lib/langchain/execute_custom_llm_chain'; import { - DEFAULT_PLUGIN_NAME, - getMessageFromRawResponse, - getPluginNameFromRequest, + appendAssistantMessageToConversation, + langChainExecute, + nonLangChainExecute, + updateConversationWithUserInput, } from './helpers'; -import { getLangSmithTracer } from './evaluate/utils'; -import { EsAnonymizationFieldsSchema } from '../ai_assistant_data_clients/anonymization_fields/types'; -import { transformESSearchToAnonymizationFields } from '../ai_assistant_data_clients/anonymization_fields/helpers'; -import { ElasticsearchStore } from '../lib/langchain/elasticsearch_store/elasticsearch_store'; export const postActionsConnectorExecuteRoute = ( router: IRouter, @@ -83,311 +68,107 @@ export const postActionsConnectorExecuteRoute = ( body: `Authenticated user not found`, }); } - const conversationsDataClient = - await assistantContext.getAIAssistantConversationsDataClient(); - - const anonymizationFieldsDataClient = - await assistantContext.getAIAssistantAnonymizationFieldsDataClient(); - let latestReplacements: Replacements = request.body.replacements; const onNewReplacements = (newReplacements: Replacements) => { latestReplacements = { ...latestReplacements, ...newReplacements }; }; - let prevMessages; - let newMessage: Pick | undefined; + let messages; + let newMessage: Pick | undefined; const conversationId = request.body.conversationId; const actionTypeId = request.body.actionTypeId; - const langSmithProject = request.body.langSmithProject; - const langSmithApiKey = request.body.langSmithApiKey; + const connectorId = decodeURIComponent(request.params.connectorId); // if message is undefined, it means the user is regenerating a message from the stored conversation if (request.body.message) { newMessage = { content: request.body.message, role: 'user', + timestamp: new Date().toISOString(), }; } - const connectorId = decodeURIComponent(request.params.connectorId); - // get the actions plugin start contract from the request context: const actions = (await context.elasticAssistant).actions; + const actionsClient = await actions.getActionsClientWithRequest(request); - if (conversationId) { - const conversation = await conversationsDataClient?.getConversation({ - id: conversationId, + const conversationsDataClient = + await assistantContext.getAIAssistantConversationsDataClient(); + + if (conversationId && conversationsDataClient) { + const updatedConversation = await updateConversationWithUserInput({ + actionsClient, + actionTypeId, authenticatedUser, + connectorId, + conversationId, + conversationsDataClient, + logger, + replacements: latestReplacements, + newMessages: newMessage ? [newMessage] : [], + model: request.body.model, }); - if (conversation == null) { - return response.notFound({ - body: `conversation id: "${conversationId}" not found`, + if (updatedConversation == null) { + return response.badRequest({ + body: `conversation id: "${conversationId}" not updated`, }); } - // messages are anonymized by conversationsDataClient - prevMessages = conversation?.messages?.map((c) => ({ + messages = updatedConversation?.messages?.map((c) => ({ role: c.role, content: c.content, })); - if (request.body.message) { - const res = await conversationsDataClient?.appendConversationMessages({ - existingConversation: conversation, - messages: [ - { - ...{ - content: replaceAnonymizedValuesWithOriginalValues({ - messageContent: request.body.message, - replacements: request.body.replacements, - }), - role: 'user', - }, - timestamp: new Date().toISOString(), - }, - ], - }); - - if (res == null) { - return response.badRequest({ - body: `conversation id: "${conversationId}" not updated`, - }); - } - } - const updatedConversation = await conversationsDataClient?.getConversation({ - id: conversationId, - authenticatedUser, - }); - - if (updatedConversation == null) { - return response.notFound({ - body: `conversation id: "${conversationId}" not found`, - }); - } - - const NEW_CHAT = i18n.translate('xpack.elasticAssistantPlugin.server.newChat', { - defaultMessage: 'New chat', - }); - if (conversation?.title === NEW_CHAT && prevMessages) { - try { - const autoTitle = (await executeAction({ - actions, - request, - connectorId, - actionTypeId, - params: { - subAction: 'invokeAI', - subActionParams: { - model: request.body.model, - messages: [ - { - role: 'assistant', - content: i18n.translate( - 'xpack.elasticAssistantPlugin.server.autoTitlePromptDescription', - { - defaultMessage: - 'You are a helpful assistant for Elastic Security. Assume the following message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you.', - } - ), - }, - newMessage ?? prevMessages?.[0], - ], - ...(actionTypeId === '.gen-ai' - ? { n: 1, stop: null, temperature: 0.2 } - : { temperature: 0, stopSequences: [] }), - }, - }, - logger, - })) as unknown as StaticResponse; // TODO: Use function overloads in executeAction to avoid this cast when sending subAction: 'invokeAI', - if (autoTitle.status === 'ok') { - try { - // This regular expression captures a string enclosed in single or double quotes. - // It extracts the string content without the quotes. - // Example matches: - // - "Hello, World!" => Captures: Hello, World! - // - 'Another Example' => Captures: Another Example - // - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes - const match = autoTitle.data.match(/^["']?([^"']+)["']?$/); - const title = match ? match[1] : autoTitle.data; - - await conversationsDataClient?.updateConversation({ - conversationUpdateProps: { - id: conversationId, - title, - }, - }); - } catch (e) { - logger.warn(`Failed to update conversation with generated title: ${e.message}`); - } - } - } catch (e) { - /* empty */ - } - } - onLlmResponse = async ( content: string, traceData: Message['traceData'] = {}, isError = false ): Promise => { - if (updatedConversation) { - await conversationsDataClient?.appendConversationMessages({ - existingConversation: updatedConversation, - messages: [ - getMessageFromRawResponse({ - rawContent: replaceAnonymizedValuesWithOriginalValues({ - messageContent: content, - replacements: latestReplacements, - }), - traceData, - isError, - }), - ], - }); - } - if (Object.keys(latestReplacements).length > 0) { - await conversationsDataClient?.updateConversation({ - conversationUpdateProps: { - id: conversationId, - replacements: latestReplacements, - }, + if (updatedConversation && conversationsDataClient) { + await appendAssistantMessageToConversation({ + conversation: updatedConversation, + conversationsDataClient, + messageContent: content, + replacements: latestReplacements, + isError, + traceData, }); } }; } - // if not langchain, call execute action directly and return the response: if (!request.body.isEnabledKnowledgeBase && !request.body.isEnabledRAGAlerts) { - logger.debug('Executing via actions framework directly'); - - const result = await executeAction({ + // if not langchain, call execute action directly and return the response: + return await nonLangChainExecute({ abortSignal, - onLlmResponse, - actions, - request, - connectorId, + actionsClient, actionTypeId, - params: { - subAction: request.body.subAction, - subActionParams: { - model: request.body.model, - messages: [...(prevMessages ?? []), ...(newMessage ? [newMessage] : [])], - ...(actionTypeId === '.gen-ai' - ? { n: 1, stop: null, temperature: 0.2 } - : { temperature: 0, stopSequences: [] }), - }, - }, + connectorId, logger, - }); - - telemetry.reportEvent(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { - actionTypeId, - isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase, - isEnabledRAGAlerts: request.body.isEnabledRAGAlerts, - model: request.body.model, - assistantStreamingEnabled: request.body.subAction !== 'invokeAI', - }); - return response.ok({ - body: result, + messages: messages ?? [], + onLlmResponse, + request, + response, + telemetry, }); } - - // TODO: Add `traceId` to actions request when calling via langchain - logger.debug( - `Executing via langchain, isEnabledKnowledgeBase: ${request.body.isEnabledKnowledgeBase}, isEnabledRAGAlerts: ${request.body.isEnabledRAGAlerts}` - ); - - // Fetch any tools registered by the request's originating plugin - const pluginName = getPluginNameFromRequest({ - request, - defaultPluginName: DEFAULT_PLUGIN_NAME, - logger, - }); - const assistantTools = (await context.elasticAssistant) - .getRegisteredTools(pluginName) - .filter((x) => x.id !== 'attack-discovery'); // We don't (yet) support asking the assistant for NEW attack discoveries from a conversation - - // get a scoped esClient for assistant memory - const esClient = (await context.core).elasticsearch.client.asCurrentUser; - - // convert the assistant messages to LangChain messages: - const langChainMessages = getLangChainMessages( - ([...(prevMessages ?? []), ...(newMessage ? [newMessage] : [])] ?? - []) as unknown as Array> - ); - - const elserId = await getElser(); - - const anonymizationFieldsRes = - await anonymizationFieldsDataClient?.findDocuments({ - perPage: 1000, - page: 1, - }); - - // Create an ElasticsearchStore for KB interactions - // Setup with kbDataClient if `enableKnowledgeBaseByDefault` FF is enabled - const enableKnowledgeBaseByDefault = - assistantContext.getRegisteredFeatures(pluginName).assistantKnowledgeBaseByDefault; - const kbDataClient = enableKnowledgeBaseByDefault - ? (await assistantContext.getAIAssistantKnowledgeBaseDataClient(false)) ?? undefined - : undefined; - const kbIndex = - enableKnowledgeBaseByDefault && kbDataClient != null - ? kbDataClient.indexTemplateAndPattern.alias - : KNOWLEDGE_BASE_INDEX_PATTERN; - const esStore = new ElasticsearchStore( - esClient, - kbIndex, - logger, - telemetry, - elserId, - ESQL_RESOURCE, - kbDataClient - ); - - const result: StreamResponseWithHeaders | StaticReturnType = await callAgentExecutor({ + return await langChainExecute({ abortSignal, - alertsIndexPattern: request.body.alertsIndexPattern, - anonymizationFields: anonymizationFieldsRes - ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) - : undefined, - actions, - isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase ?? false, - assistantTools, + actionsClient, + actionTypeId, + assistantContext, connectorId, - esClient, - esStore, - isStream: request.body.subAction !== 'invokeAI', - llmType: getLlmType(actionTypeId), - langChainMessages, + context, + getElser, logger, - onNewReplacements, + messages: messages ?? [], onLlmResponse, + onNewReplacements, + replacements: latestReplacements, request, - replacements: request.body.replacements, - size: request.body.size, - traceOptions: { - projectName: langSmithProject, - tracers: getLangSmithTracer({ - apiKey: langSmithApiKey, - projectName: langSmithProject, - logger, - }), - }, - }); - - telemetry.reportEvent(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { - actionTypeId, - isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase, - isEnabledRAGAlerts: request.body.isEnabledRAGAlerts, - model: request.body.model, - // TODO rm actionTypeId check when llmClass for bedrock streaming is implemented - // tracked here: https://github.com/elastic/security-team/issues/7363 - assistantStreamingEnabled: - request.body.subAction !== 'invokeAI' && actionTypeId === '.gen-ai', + response, + telemetry, }); - - return response.ok(result); } catch (err) { logger.error(err); const error = transformError(err); diff --git a/x-pack/plugins/elastic_assistant/server/routes/register_routes.ts b/x-pack/plugins/elastic_assistant/server/routes/register_routes.ts index fc0e30f4a925c..632a92a7806f8 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/register_routes.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/register_routes.ts @@ -21,18 +21,24 @@ import { getKnowledgeBaseStatusRoute } from './knowledge_base/get_knowledge_base import { postKnowledgeBaseRoute } from './knowledge_base/post_knowledge_base'; import { getEvaluateRoute } from './evaluate/get_evaluate'; import { postEvaluateRoute } from './evaluate/post_evaluate'; -import { postActionsConnectorExecuteRoute } from './post_actions_connector_execute'; import { getCapabilitiesRoute } from './capabilities/get_capabilities_route'; import { bulkPromptsRoute } from './prompts/bulk_actions_route'; import { findPromptsRoute } from './prompts/find_route'; import { bulkActionAnonymizationFieldsRoute } from './anonymization_fields/bulk_actions_route'; import { findAnonymizationFieldsRoute } from './anonymization_fields/find_route'; +import { chatCompleteRoute } from './chat/chat_complete_route'; +import { postActionsConnectorExecuteRoute } from './post_actions_connector_execute'; export const registerRoutes = ( router: ElasticAssistantPluginRouter, logger: Logger, getElserId: GetElser ) => { + /** PUBLIC */ + // Chat + chatCompleteRoute(router); + + /** INTERNAL */ // Capabilities getCapabilitiesRoute(router); From 36b2a5ff07d877f51ea3b0bc56ae6b1774150c84 Mon Sep 17 00:00:00 2001 From: YulNaumenko Date: Thu, 30 May 2024 19:44:48 -0700 Subject: [PATCH 02/23] reduced params to keep security only --- .../chat/post_chat_complete_route.gen.ts | 54 +----- .../chat/post_chat_complete_route.schema.yaml | 62 +------ .../server/routes/chat/chat_complete_route.ts | 7 +- .../server/routes/evaluate/post_evaluate.ts | 5 +- .../server/routes/helpers.ts | 160 +++++++++++++----- 5 files changed, 132 insertions(+), 156 deletions(-) diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts index ec17fb447fe5d..acc984d06caaf 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts @@ -25,46 +25,10 @@ export const RootContext = z.literal('security'); * Message role. */ export type ChatMessageRole = z.infer; -export const ChatMessageRole = z.enum(['system', 'user', 'assistant', 'function', 'elastic']); +export const ChatMessageRole = z.enum(['system', 'user', 'assistant']); export type ChatMessageRoleEnum = typeof ChatMessageRole.enum; export const ChatMessageRoleEnum = ChatMessageRole.enum; -/** - * Message role. - */ -export type TriggerType = z.infer; -export const TriggerType = z.enum(['user', 'assistant', 'elastic']); -export type TriggerTypeEnum = typeof TriggerType.enum; -export const TriggerTypeEnum = TriggerType.enum; - -export type TriggerArguments = z.infer; -export const TriggerArguments = z.object({}).catchall(z.unknown()); - -export type TriggerData = z.infer; -export const TriggerData = z.object({}).catchall(z.unknown()); - -/** - * AI assistant message. - */ -export type InstructionsObject = z.infer; -export const InstructionsObject = z.object({ - doc_id: z.string().optional(), - text: z.string().optional(), -}); - -/** - * AI assistant message. - */ -export type FunctionCall = z.infer; -export const FunctionCall = z.object({ - /** - * Trigger type. - */ - trigger: TriggerType, - arguments: TriggerArguments.optional(), - data: TriggerData.optional(), -}); - export type MessageData = z.infer; export const MessageData = z.object({}).catchall(z.unknown()); @@ -77,14 +41,6 @@ export const ChatMessage = z.object({ * Message content. */ content: z.string().optional(), - /** - * Message name. - */ - name: z.string().optional(), - /** - * Function definition. - */ - function_call: FunctionCall.optional(), /** * Message role. */ @@ -102,21 +58,15 @@ export const ChatMessage = z.object({ export type ChatCompleteProps = z.infer; export const ChatCompleteProps = z.object({ - /** - * Solution context. - */ - context: RootContext.optional(), conversationId: z.string().optional(), + promptId: z.string().optional(), responseLanguage: z.string().optional(), langSmithProject: z.string().optional(), langSmithApiKey: z.string().optional(), - disableFunctions: z.boolean().optional(), connectorId: z.string(), model: z.string().optional(), - title: z.string().optional(), persist: z.boolean(), messages: z.array(ChatMessage), - instructions: z.array(z.union([InstructionsObject, z.string()])).optional(), }); export type ChatCompleteRequestBody = z.infer; diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml index 3e5f7c90d6a3b..a87f0d8b447c3 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml @@ -53,47 +53,6 @@ components: - system - user - assistant - - function - - elastic - - TriggerType: - type: string - description: Message role. - enum: - - user - - assistant - - elastic - - TriggerArguments: - type: object - additionalProperties: true - - TriggerData: - type: object - additionalProperties: true - - InstructionsObject: - type: object - description: AI assistant message. - properties: - doc_id: - type: string - text: - type: string - - FunctionCall: - type: object - description: AI assistant message. - required: - - 'trigger' - properties: - trigger: - $ref: '#/components/schemas/TriggerType' - description: Trigger type. - arguments: - $ref: '#/components/schemas/TriggerArguments' - data: - $ref: '#/components/schemas/TriggerData' MessageData: type: object @@ -109,12 +68,6 @@ components: content: type: string description: Message content. - name: - type: string - description: Message name. - function_call: - $ref: '#/components/schemas/FunctionCall' - description: Function definition. role: $ref: '#/components/schemas/ChatMessageRole' description: Message role. @@ -134,37 +87,26 @@ components: ChatCompleteProps: type: object properties: - context: - $ref: '#/components/schemas/RootContext' - description: Solution context. conversationId: type: string + promptId: + type: string responseLanguage: type: string langSmithProject: type: string langSmithApiKey: type: string - disableFunctions: - type: boolean connectorId: type: string model: type: string - title: - type: string persist: type: boolean messages: type: array items: $ref: '#/components/schemas/ChatMessage' - instructions: - type: array - items: - oneOf: - - $ref: '#/components/schemas/InstructionsObject' - - type: string required: - messages - persist diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index be110c82261e1..d1f3319137c79 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -21,9 +21,9 @@ import { buildResponse } from '../utils'; import { UPGRADE_LICENSE_MESSAGE, appendAssistantMessageToConversation, + createOrUpdateConversationWithUserInput, hasAIAssistantLicense, langChainExecute, - updateConversationWithUserInput, } from '../helpers'; export const chatCompleteRoute = ( @@ -97,14 +97,15 @@ export const chatCompleteRoute = ( // replacements - if (request.body.persist && conversationId && conversationsDataClient) { - const updatedConversation = await updateConversationWithUserInput({ + if (request.body.persist && conversationsDataClient) { + const updatedConversation = await createOrUpdateConversationWithUserInput({ actionsClient, actionTypeId, authenticatedUser, connectorId, conversationId, conversationsDataClient, + promptId: request.body.promptId, logger, replacements: latestReplacements, newMessages: request.body.messages diff --git a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index ef1950b5e90ad..990417b799234 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -192,7 +192,7 @@ export const postEvaluateRoute = ( agents.push({ agentEvaluator: async (langChainMessages, exampleId) => { const evalResult = await AGENT_EXECUTOR_MAP[agentName]({ - actions, + actionsClient, isEnabledKnowledgeBase: true, assistantTools, connectorId, @@ -237,9 +237,8 @@ export const postEvaluateRoute = ( evalModel == null || evalModel === '' ? undefined : new ActionsClientLlm({ - actions, + actionsClient, connectorId: evalModel, - request: skeletonRequest, logger, model: skeletonRequest.body.model, }); diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index a18a796065e35..8ee9e08055607 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -54,6 +54,10 @@ interface GetPluginNameFromRequestParams { export const DEFAULT_PLUGIN_NAME = 'securitySolutionUI'; +export const NEW_CHAT = i18n.translate('xpack.elasticAssistantPlugin.server.newChat', { + defaultMessage: 'New chat', +}); + /** * Attempts to extract the plugin name the request originated from using the request headers. * @@ -124,24 +128,22 @@ export const UPGRADE_LICENSE_MESSAGE = 'Your license does not support AI Assistant. Please upgrade your license.'; export interface GenerateTitleForNewChatConversationParams { - conversationsDataClient: AIAssistantConversationsDataClient; - conversationId: string; - message: Pick; + message: Pick; model?: string; actionTypeId: string; connectorId: string; logger: Logger; actionsClient: PublicMethodsOf; + responseLanguage?: string; } export const generateTitleForNewChatConversation = async ({ - conversationsDataClient, - conversationId, message, model, actionTypeId, connectorId, logger, actionsClient, + responseLanguage = 'English', }: GenerateTitleForNewChatConversationParams) => { try { const autoTitle = (await executeAction({ @@ -155,13 +157,7 @@ export const generateTitleForNewChatConversation = async ({ messages: [ { role: 'assistant', - content: i18n.translate( - 'xpack.elasticAssistantPlugin.server.autoTitlePromptDescription', - { - defaultMessage: - 'You are a helpful assistant for Elastic Security. Assume the following message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you.', - } - ), + content: `You are a helpful assistant for Elastic Security. Assume the following message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you. Please create the title in ${responseLanguage}.`, }, message, ], @@ -173,25 +169,15 @@ export const generateTitleForNewChatConversation = async ({ logger, })) as unknown as StaticResponse; // TODO: Use function overloads in executeAction to avoid this cast when sending subAction: 'invokeAI', if (autoTitle.status === 'ok') { - try { - // This regular expression captures a string enclosed in single or double quotes. - // It extracts the string content without the quotes. - // Example matches: - // - "Hello, World!" => Captures: Hello, World! - // - 'Another Example' => Captures: Another Example - // - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes - const match = autoTitle.data.match(/^["']?([^"']+)["']?$/); - const title = match ? match[1] : autoTitle.data; - - return await conversationsDataClient.updateConversation({ - conversationUpdateProps: { - id: conversationId, - title, - }, - }); - } catch (e) { - logger.warn(`Failed to update conversation with generated title: ${e.message}`); - } + // This regular expression captures a string enclosed in single or double quotes. + // It extracts the string content without the quotes. + // Example matches: + // - "Hello, World!" => Captures: Hello, World! + // - 'Another Example' => Captures: Another Example + // - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes + const match = autoTitle.data.match(/^["']?([^"']+)["']?$/); + const title = match ? match[1] : autoTitle.data; + return title; } } catch (e) { /* empty */ @@ -459,6 +445,78 @@ export const langChainExecute = async ({ return response.ok(result); }; +export interface CreateOrUpdateConversationWithParams { + logger: Logger; + conversationsDataClient: AIAssistantConversationsDataClient; + replacements: Replacements; + conversationId?: string; + promptId?: string; + actionTypeId: string; + connectorId: string; + actionsClient: PublicMethodsOf; + newMessages?: Array>; + model?: string; + authenticatedUser: AuthenticatedUser; + responseLanguage?: string; +} +export const createOrUpdateConversationWithUserInput = async ({ + logger, + conversationsDataClient, + replacements, + conversationId, + actionTypeId, + promptId, + connectorId, + actionsClient, + newMessages, + model, + authenticatedUser, + responseLanguage, +}: CreateOrUpdateConversationWithParams) => { + if (!conversationId) { + if (newMessages && newMessages.length > 0) { + const title = await generateTitleForNewChatConversation({ + message: newMessages[0], + actionsClient, + actionTypeId, + connectorId, + logger, + model, + responseLanguage, + }); + if (title) { + return conversationsDataClient.createConversation({ + conversation: { + title, + messages: newMessages, + replacements, + apiConfig: { + connectorId, + actionTypeId, + model, + defaultSystemPromptId: promptId, + }, + }, + authenticatedUser, + }); + } + } + return; + } + return updateConversationWithUserInput({ + actionsClient, + actionTypeId, + authenticatedUser, + connectorId, + conversationId, + conversationsDataClient, + logger, + replacements, + newMessages, + model, + }); +}; + export interface UpdateConversationWithParams { logger: Logger; conversationsDataClient: AIAssistantConversationsDataClient; @@ -470,6 +528,7 @@ export interface UpdateConversationWithParams { newMessages?: Array>; model?: string; authenticatedUser: AuthenticatedUser; + responseLanguage?: string; } export const updateConversationWithUserInput = async ({ logger, @@ -482,7 +541,32 @@ export const updateConversationWithUserInput = async ({ newMessages, model, authenticatedUser, + responseLanguage, }: UpdateConversationWithParams) => { + if (!conversationId) { + if (newMessages && newMessages.length > 0) { + const title = await generateTitleForNewChatConversation({ + message: newMessages[0], + actionsClient, + actionTypeId, + connectorId, + logger, + responseLanguage, + model, + }); + if (title) { + return conversationsDataClient.createConversation({ + conversation: { + title, + messages: newMessages, + replacements, + }, + authenticatedUser, + }); + } + } + return; + } const conversation = await conversationsDataClient?.getConversation({ id: conversationId, authenticatedUser, @@ -492,10 +576,6 @@ export const updateConversationWithUserInput = async ({ } let updatedConversation = conversation; - const NEW_CHAT = i18n.translate('xpack.elasticAssistantPlugin.server.newChat', { - defaultMessage: 'New chat', - }); - const messages = updatedConversation?.messages?.map((c) => ({ role: c.role, content: c.content, @@ -505,16 +585,20 @@ export const updateConversationWithUserInput = async ({ const lastMessage = newMessages?.[0] ?? messages?.[0]; if (conversation?.title === NEW_CHAT && lastMessage) { - const res = await generateTitleForNewChatConversation({ + const title = await generateTitleForNewChatConversation({ message: lastMessage, actionsClient, actionTypeId, connectorId, - conversationId, - conversationsDataClient, logger, model, }); + const res = await conversationsDataClient.updateConversation({ + conversationUpdateProps: { + id: conversationId, + title, + }, + }); if (res) { updatedConversation = res; } From a057a093518cdf830d2b262275296ac5ae4cf1a3 Mon Sep 17 00:00:00 2001 From: YulNaumenko Date: Thu, 30 May 2024 21:28:53 -0700 Subject: [PATCH 03/23] - --- .../server/routes/chat/chat_complete_route.ts | 43 ++++++++++++++++--- 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index d1f3319137c79..2cb1ba286a77b 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -13,6 +13,8 @@ import { API_VERSIONS, Message, Replacements, + transformRawData, + getAnonymizedValue, } from '@kbn/elastic-assistant-common'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; @@ -26,6 +28,10 @@ import { langChainExecute, } from '../helpers'; +export const SYSTEM_PROMPT_CONTEXT_NON_I18N = (context: string) => { + return `CONTEXT:\n"""\n${context}\n"""`; +}; + export const chatCompleteRoute = ( router: ElasticAssistantPluginRouter, getElser: GetElser @@ -95,7 +101,35 @@ export const chatCompleteRoute = ( (await actionsClient.getAllSystemConnectors()).find((c) => c.id === connectorId) ?.actionTypeId ?? '.gen-ai'; - // replacements + if (request.body.messages) { + // replacements + const systemAnonymizationFields = await anonymizationFieldsDataClient?.findDocuments({ + page: 1, + perPage: 1000, + }); + + messages = request.body.messages.map((m) => { + let content = m.content ?? ''; + if (m.data && m.data.length > 0) { + const anonymizedData = transformRawData({ + anonymizationFields: systemAnonymizationFields?.data, + currentReplacements: latestReplacements, + getAnonymizedValue, + onNewReplacements, + rawData: m.data as unknown as Record, + }); + const wr = `${SYSTEM_PROMPT_CONTEXT_NON_I18N(anonymizedData)}\n`; + + content = `${wr}\n${m.content}`; + } + const transformedMessage = { + role: m.role, + content, + timestamp: m['@timestamp'], + }; + return transformedMessage; + }); + } if (request.body.persist && conversationsDataClient) { const updatedConversation = await createOrUpdateConversationWithUserInput({ @@ -108,12 +142,7 @@ export const chatCompleteRoute = ( promptId: request.body.promptId, logger, replacements: latestReplacements, - newMessages: request.body.messages - .filter((f) => f.role === 'assistant' || f.role === 'user') - .map((m) => ({ - role: m.role, - content: m.content ?? '', - })) as Message[], + newMessages: messages, model: request.body.model, }); if (updatedConversation == null) { From 2cb73f5fcc917aab5986071ceaa5e8063c9f71d1 Mon Sep 17 00:00:00 2001 From: YulNaumenko Date: Wed, 19 Jun 2024 06:21:36 -0700 Subject: [PATCH 04/23] - --- .../conversations/index.ts | 11 ++++++----- .../graphs/default_assistant_graph/graph.ts | 9 +++++++++ .../nodes/generate_chat_title.ts | 8 ++++++++ .../server/routes/helpers.ts | 19 +++++++++++++------ .../routes/user_conversations/create_route.ts | 3 +-- .../routes/user_conversations/update_route.ts | 2 +- 6 files changed, 38 insertions(+), 14 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/index.ts b/x-pack/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/index.ts index c937e86a982f6..e2b0f35b0be3b 100644 --- a/x-pack/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/index.ts +++ b/x-pack/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/index.ts @@ -43,7 +43,7 @@ export class AIAssistantConversationsDataClient extends AIAssistantDataClient { logger: this.options.logger, conversationIndex: this.indexTemplateAndPattern.alias, id, - user: authenticatedUser, + user: authenticatedUser ?? this.options.currentUser, }); }; @@ -84,18 +84,19 @@ export class AIAssistantConversationsDataClient extends AIAssistantDataClient { */ public createConversation = async ({ conversation, - authenticatedUser, }: { conversation: ConversationCreateProps; - authenticatedUser: AuthenticatedUser; }): Promise => { + if (!this.options.currentUser) { + throw new Error('AIAssistantConversationsDataClient currentUser is not defined.'); + } const esClient = await this.options.elasticsearchClientPromise; return createConversation({ esClient, logger: this.options.logger, conversationIndex: this.indexTemplateAndPattern.alias, spaceId: this.spaceId, - user: authenticatedUser, + user: this.options.currentUser, conversation, }); }; @@ -128,7 +129,7 @@ export class AIAssistantConversationsDataClient extends AIAssistantDataClient { conversationIndex: this.indexTemplateAndPattern.alias, conversationUpdateProps, isPatch, - user: authenticatedUser, + user: authenticatedUser ?? this.options.currentUser ?? undefined, }); }; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts index 779bf20a61720..433d8bcac1b4b 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts @@ -19,6 +19,7 @@ import { AssistantDataClients } from '../../executors/types'; import { shouldContinue } from './nodes/should_continue'; import { AGENT_NODE, runAgent } from './nodes/run_agent'; import { executeTools, TOOLS_NODE } from './nodes/execute_tools'; +import { GENERATE_CHAT_TITLE_NODE, generateChatTitle } from './nodes/generate_chat_title'; export const DEFAULT_ASSISTANT_GRAPH_ID = 'Default Security Assistant Graph'; @@ -94,6 +95,13 @@ export const getDefaultAssistantGraph = ({ state, tools, }); + const generateChatTitleNode = (state: AgentState) => + generateChatTitle({ + ...nodeParams, + state, + conversationsDataClient: dataClients?.conversationsDataClient, + conversationId, + }); const shouldContinueEdge = (state: AgentState) => shouldContinue({ ...nodeParams, state }); // Put together a new graph using the nodes and default state from above @@ -101,6 +109,7 @@ export const getDefaultAssistantGraph = ({ channels: graphState, }); // Define the nodes to cycle between + graph.addNode(GENERATE_CHAT_TITLE_NODE, generateChatTitleNode); graph.addNode(AGENT_NODE, runAgentNode); graph.addNode(TOOLS_NODE, executeToolsNode); // Add conditional edge for basic routing diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts index bcba25eab0b0d..ba9c76c2e33dc 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts @@ -32,11 +32,19 @@ export const GENERATE_CHAT_TITLE_NODE = 'generateChatTitle'; export const generateChatTitle = async ({ conversationsDataClient, + conversationId, logger, model, state, }: GenerateChatTitleParams) => { logger.debug(`Node state:\n ${JSON.stringify(state, null, 2)}`); + + if (!conversationId) { + if (state.messages.length > 0) { + + } + } + const conversation = await conversationsDataClient?.getConversation({ id: conversationId }); if (state.messages.length !== 0) { logger.debug('No need to generate chat title, messages already exist'); return; diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index fc16fbfc4259e..c1835109fe580 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -30,7 +30,7 @@ import { MINIMUM_AI_ASSISTANT_LICENSE } from '../../common/constants'; import { ESQL_RESOURCE, KNOWLEDGE_BASE_INDEX_PATTERN } from './knowledge_base/constants'; import { callAgentExecutor } from '../lib/langchain/execute_custom_llm_chain'; import { getLlmType } from './utils'; -import { StaticReturnType } from '../lib/langchain/executors/types'; +import { AgentExecutorParams, AssistantDataClients, StaticReturnType } from '../lib/langchain/executors/types'; import { executeAction, StaticResponse } from '../lib/executor'; import { getLangChainMessages } from '../lib/langchain/helpers'; @@ -384,8 +384,10 @@ export const langChainExecute = async ({ page: 1, }); + const conversationsDataClient = await assistantContext.getAIAssistantConversationsDataClient(); + // Create an ElasticsearchStore for KB interactions - // Setup with kbDataClient if `enableKnowledgeBaseByDefault` FF is enabled + // Setup with kbDataClient if `assistantKnowledgeBaseByDefault` FF is enabled const enableKnowledgeBaseByDefault = assistantContext.getRegisteredFeatures(pluginName).assistantKnowledgeBaseByDefault; const kbDataClient = enableKnowledgeBaseByDefault @@ -405,8 +407,16 @@ export const langChainExecute = async ({ kbDataClient ); - const executorParams = { + const dataClients: AssistantDataClients = { + anonymizationFieldsDataClient: anonymizationFieldsDataClient ?? undefined, + conversationsDataClient: conversationsDataClient ?? undefined, + kbDataClient, + }; + + // Shared executor params + const executorParams: AgentExecutorParams = { abortSignal, + dataClients, alertsIndexPattern: request.body.alertsIndexPattern, anonymizationFields: anonymizationFieldsRes ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) @@ -510,7 +520,6 @@ export const createOrUpdateConversationWithUserInput = async ({ defaultSystemPromptId: promptId, }, }, - authenticatedUser, }); } } @@ -574,7 +583,6 @@ export const updateConversationWithUserInput = async ({ messages: newMessages, replacements, }, - authenticatedUser, }); } } @@ -582,7 +590,6 @@ export const updateConversationWithUserInput = async ({ } const conversation = await conversationsDataClient?.getConversation({ id: conversationId, - authenticatedUser, }); if (conversation == null) { throw new Error(`conversation id: "${conversationId}" not found`); diff --git a/x-pack/plugins/elastic_assistant/server/routes/user_conversations/create_route.ts b/x-pack/plugins/elastic_assistant/server/routes/user_conversations/create_route.ts index e66c83f77510d..2d51939eb528b 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/user_conversations/create_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/user_conversations/create_route.ts @@ -49,7 +49,6 @@ export const createConversationRoute = (router: ElasticAssistantPluginRouter): v }, }); } - const dataClient = await ctx.elasticAssistant.getAIAssistantConversationsDataClient(); const authenticatedUser = ctx.elasticAssistant.getCurrentUser(); if (authenticatedUser == null) { return assistantResponse.error({ @@ -57,6 +56,7 @@ export const createConversationRoute = (router: ElasticAssistantPluginRouter): v statusCode: 401, }); } + const dataClient = await ctx.elasticAssistant.getAIAssistantConversationsDataClient(); const result = await dataClient?.findDocuments({ perPage: 100, @@ -73,7 +73,6 @@ export const createConversationRoute = (router: ElasticAssistantPluginRouter): v const createdConversation = await dataClient?.createConversation({ conversation: request.body, - authenticatedUser, }); if (createdConversation == null) { diff --git a/x-pack/plugins/elastic_assistant/server/routes/user_conversations/update_route.ts b/x-pack/plugins/elastic_assistant/server/routes/user_conversations/update_route.ts index 4a7fd5a9d67cb..046785166a109 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/user_conversations/update_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/user_conversations/update_route.ts @@ -54,7 +54,6 @@ export const updateConversationRoute = (router: ElasticAssistantPluginRouter) => }); } - const dataClient = await ctx.elasticAssistant.getAIAssistantConversationsDataClient(); const authenticatedUser = ctx.elasticAssistant.getCurrentUser(); if (authenticatedUser == null) { return assistantResponse.error({ @@ -62,6 +61,7 @@ export const updateConversationRoute = (router: ElasticAssistantPluginRouter) => statusCode: 401, }); } + const dataClient = await ctx.elasticAssistant.getAIAssistantConversationsDataClient(); const existingConversation = await dataClient?.getConversation({ id, authenticatedUser }); if (existingConversation == null) { From d0e5a160d3637c0af282b06a472c82affa4f6dc7 Mon Sep 17 00:00:00 2001 From: YulNaumenko Date: Sat, 22 Jun 2024 13:35:57 -0700 Subject: [PATCH 05/23] added conversation persistence to LangGraph --- .../execute_custom_llm_chain/index.ts | 14 +++- .../server/lib/langchain/executors/types.ts | 3 +- .../graphs/default_assistant_graph/graph.ts | 77 +++++++++++++++++-- .../graphs/default_assistant_graph/index.ts | 22 ++++-- .../nodes/generate_chat_title.ts | 33 +++----- .../nodes/get_persisted_conversation.ts | 54 +++++++++++++ .../nodes/persist_conversation_changes.ts | 73 ++++++++++++++++++ .../nodes/should_continue.ts | 29 +++++++ .../graphs/default_assistant_graph/types.ts | 3 + .../attack_discovery/post_attack_discovery.ts | 4 +- .../server/routes/chat/chat_complete_route.ts | 35 +++++++-- .../server/routes/helpers.ts | 73 ++++-------------- .../routes/post_actions_connector_execute.ts | 34 +++++--- .../server/lib/get_chat_params.ts | 6 +- 14 files changed, 342 insertions(+), 118 deletions(-) create mode 100644 x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/get_persisted_conversation.ts create mode 100644 x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/persist_conversation_changes.ts diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts index cc06cc094025f..365ecd386c0c7 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts @@ -17,6 +17,8 @@ import { ActionsClientChatOpenAI, ActionsClientSimpleChatModel, } from '@kbn/langchain/server'; +import { EsAnonymizationFieldsSchema } from '../../../ai_assistant_data_clients/anonymization_fields/types'; +import { transformESSearchToAnonymizationFields } from '../../../ai_assistant_data_clients/anonymization_fields/helpers'; import { AgentExecutor } from '../executors/types'; import { APMTracer } from '../tracers/apm_tracer'; import { AssistantToolParams } from '../../../types'; @@ -32,7 +34,6 @@ export const callAgentExecutor: AgentExecutor = async ({ abortSignal, actionsClient, alertsIndexPattern, - anonymizationFields, isEnabledKnowledgeBase, assistantTools = [], connectorId, @@ -48,6 +49,7 @@ export const callAgentExecutor: AgentExecutor = async ({ request, size, traceOptions, + dataClients, }) => { const isOpenAI = llmType === 'openai'; const llmClass = isOpenAI ? ActionsClientChatOpenAI : ActionsClientSimpleChatModel; @@ -70,6 +72,16 @@ export const callAgentExecutor: AgentExecutor = async ({ maxRetries: 0, }); + const anonymizationFieldsRes = + await dataClients?.anonymizationFieldsDataClient?.findDocuments({ + perPage: 1000, + page: 1, + }); + + const anonymizationFields = anonymizationFieldsRes + ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) + : undefined; + const pastMessages = langChainMessages.slice(0, -1); // all but the last message const latestMessage = langChainMessages.slice(-1); // the last message diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts index 3ff8b76e75a10..7af0b459f4bc9 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/executors/types.ts @@ -13,7 +13,6 @@ import { KibanaRequest, KibanaResponseFactory, ResponseHeaders } from '@kbn/core import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain'; import { ExecuteConnectorRequestBody, Message, Replacements } from '@kbn/elastic-assistant-common'; import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server'; -import { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas/anonymization_fields/bulk_crud_anonymization_fields_route.gen'; import { PublicMethodsOf } from '@kbn/utility-types'; import { ResponseBody } from '../types'; import type { AssistantTool } from '../../../types'; @@ -38,7 +37,6 @@ export interface AgentExecutorParams { abortSignal?: AbortSignal; alertsIndexPattern?: string; actionsClient: PublicMethodsOf; - anonymizationFields?: AnonymizationFieldResponse[]; isEnabledKnowledgeBase: boolean; assistantTools?: AssistantTool[]; connectorId: string; @@ -57,6 +55,7 @@ export interface AgentExecutorParams { response?: KibanaResponseFactory; size?: number; traceOptions?: TraceOptions; + responseLanguage?: string; } export interface StaticReturnType { diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts index 433d8bcac1b4b..c9be3e0eb0e48 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts @@ -14,12 +14,25 @@ import type { Logger } from '@kbn/logging'; import { BaseMessage } from '@langchain/core/messages'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import { ConversationResponse, Replacements } from '@kbn/elastic-assistant-common'; import { AgentState, NodeParamsBase } from './types'; import { AssistantDataClients } from '../../executors/types'; -import { shouldContinue } from './nodes/should_continue'; +import { + shouldContinue, + shouldContinueGenerateTitle, + shouldContinueGetConversation, +} from './nodes/should_continue'; import { AGENT_NODE, runAgent } from './nodes/run_agent'; import { executeTools, TOOLS_NODE } from './nodes/execute_tools'; import { GENERATE_CHAT_TITLE_NODE, generateChatTitle } from './nodes/generate_chat_title'; +import { + GET_PERSISTED_CONVERSATION_NODE, + getPersistedConversation, +} from './nodes/get_persisted_conversation'; +import { + PERSIST_CONVERSATION_CHANGES_NODE, + persistConversationChanges, +} from './nodes/persist_conversation_changes'; export const DEFAULT_ASSISTANT_GRAPH_ID = 'Default Security Assistant Graph'; @@ -29,8 +42,9 @@ interface GetDefaultAssistantGraphParams { conversationId?: string; llm: BaseChatModel; logger: Logger; - messages: BaseMessage[]; tools: StructuredTool[]; + responseLanguage: string; + replacements: Replacements; } export type DefaultAssistantGraph = ReturnType; @@ -44,7 +58,7 @@ export const getDefaultAssistantGraph = ({ dataClients, llm, logger, - messages, + responseLanguage, tools, }: GetDefaultAssistantGraphParams) => { try { @@ -67,7 +81,16 @@ export const getDefaultAssistantGraph = ({ }, messages: { value: (x: BaseMessage[], y: BaseMessage[]) => x.concat(y), - default: () => messages, + default: () => [], + }, + chatTitle: { + value: (x: string, y?: string) => y ?? x, + default: () => '', + }, + conversation: { + value: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) => + y ?? x, + default: () => undefined, }, }; @@ -97,25 +120,65 @@ export const getDefaultAssistantGraph = ({ }); const generateChatTitleNode = (state: AgentState) => generateChatTitle({ + ...nodeParams, + state, + responseLanguage, + }); + + const getPersistedConversationNode = (state: AgentState) => + getPersistedConversation({ + ...nodeParams, + state, + conversationsDataClient: dataClients?.conversationsDataClient, + conversationId, + }); + + const persistConversationChangesNode = (state: AgentState) => + persistConversationChanges({ ...nodeParams, state, conversationsDataClient: dataClients?.conversationsDataClient, conversationId, }); const shouldContinueEdge = (state: AgentState) => shouldContinue({ ...nodeParams, state }); + const shouldContinueGenerateTitleEdge = (state: AgentState) => + shouldContinueGenerateTitle({ ...nodeParams, state }); + const shouldContinueGetConversationEdge = (state: AgentState) => + shouldContinueGetConversation({ ...nodeParams, state, conversationId }); // Put together a new graph using the nodes and default state from above - const graph = new StateGraph, '__start__' | 'agent' | 'tools'>({ + const graph = new StateGraph< + AgentState, + Partial, + | '__start__' + | 'agent' + | 'tools' + | 'generateChatTitle' + | 'getPersistedConversation' + | 'persistConversationChanges' + >({ channels: graphState, }); // Define the nodes to cycle between + graph.addNode(GET_PERSISTED_CONVERSATION_NODE, getPersistedConversationNode); graph.addNode(GENERATE_CHAT_TITLE_NODE, generateChatTitleNode); + graph.addNode(PERSIST_CONVERSATION_CHANGES_NODE, persistConversationChangesNode); graph.addNode(AGENT_NODE, runAgentNode); graph.addNode(TOOLS_NODE, executeToolsNode); + + // Add edges, alternating between agent and action until finished + graph.addConditionalEdges(START, shouldContinueGetConversationEdge, { + continue: GET_PERSISTED_CONVERSATION_NODE, + end: AGENT_NODE, + }); + graph.addConditionalEdges(GET_PERSISTED_CONVERSATION_NODE, shouldContinueGenerateTitleEdge, { + continue: GENERATE_CHAT_TITLE_NODE, + end: PERSIST_CONVERSATION_CHANGES_NODE, + }); + graph.addEdge(GENERATE_CHAT_TITLE_NODE, PERSIST_CONVERSATION_CHANGES_NODE); + graph.addEdge(PERSIST_CONVERSATION_CHANGES_NODE, AGENT_NODE); // Add conditional edge for basic routing graph.addConditionalEdges(AGENT_NODE, shouldContinueEdge, { continue: TOOLS_NODE, end: END }); - // Add edges, alternating between agent and action until finished - graph.addEdge(START, AGENT_NODE); graph.addEdge(TOOLS_NODE, AGENT_NODE); // Compile the graph return graph.compile(); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 054ecb08f31f4..0c3a0cccb9f07 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -13,12 +13,14 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server'; import { createOpenAIFunctionsAgent, createStructuredChatAgent } from 'langchain/agents'; +import { EsAnonymizationFieldsSchema } from '../../../../ai_assistant_data_clients/anonymization_fields/types'; import { AssistantToolParams } from '../../../../types'; import { AgentExecutor } from '../../executors/types'; import { openAIFunctionAgentPrompt, structuredChatAgentPrompt } from './prompts'; import { APMTracer } from '../../tracers/apm_tracer'; import { getDefaultAssistantGraph } from './graph'; import { invokeGraph, streamGraph } from './helpers'; +import { transformESSearchToAnonymizationFields } from '../../../../ai_assistant_data_clients/anonymization_fields/helpers'; /** * Drop in replacement for the existing `callAgentExecutor` that uses LangGraph @@ -27,7 +29,6 @@ export const callAssistantGraph: AgentExecutor = async ({ abortSignal, actionsClient, alertsIndexPattern, - anonymizationFields, isEnabledKnowledgeBase, assistantTools = [], connectorId, @@ -45,6 +46,7 @@ export const callAssistantGraph: AgentExecutor = async ({ request, size, traceOptions, + responseLanguage = 'English', }) => { const logger = parentLogger.get('defaultAssistantGraph'); const isOpenAI = llmType === 'openai'; @@ -67,7 +69,16 @@ export const callAssistantGraph: AgentExecutor = async ({ // failure could be due to bad connector, we should deliver that result to the client asap maxRetries: 0, }); - const model = llm; + + const anonymizationFieldsRes = + await dataClients?.anonymizationFieldsDataClient?.findDocuments({ + perPage: 1000, + page: 1, + }); + + const anonymizationFields = anonymizationFieldsRes + ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) + : undefined; const messages = langChainMessages.slice(0, -1); // all but the last message const latestMessage = langChainMessages.slice(-1); // the last message @@ -75,7 +86,7 @@ export const callAssistantGraph: AgentExecutor = async ({ const modelExists = await esStore.isModelInstalled(); // Create a chain that uses the ELSER backed ElasticsearchStore, override k=10 for esql query generation for now - const chain = RetrievalQAChain.fromLLM(model, esStore.asRetriever(10)); + const chain = RetrievalQAChain.fromLLM(llm, esStore.asRetriever(10)); // Fetch any applicable tools that the source plugin may have registered const assistantToolParams: AssistantToolParams = { @@ -85,7 +96,7 @@ export const callAssistantGraph: AgentExecutor = async ({ esClient, isEnabledKnowledgeBase, kbDataClient: dataClients?.kbDataClient, - llm: model, + llm, logger, modelExists, onNewReplacements, @@ -120,8 +131,9 @@ export const callAssistantGraph: AgentExecutor = async ({ dataClients, llm, logger, - messages, tools, + responseLanguage, + replacements, }); const inputs = { input: latestMessage[0].content as string }; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts index ba9c76c2e33dc..d1d60e9bed9b4 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts @@ -8,54 +8,45 @@ import { StringOutputParser } from '@langchain/core/output_parsers'; import { ChatPromptTemplate } from '@langchain/core/prompts'; import { AgentState, NodeParamsBase } from '../types'; -import { AIAssistantConversationsDataClient } from '../../../../../ai_assistant_data_clients/conversations'; -export const GENERATE_CHAT_TITLE_PROMPT = ChatPromptTemplate.fromMessages([ - [ - 'system', - `You are a helpful assistant for Elastic Security. Assume the following user message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you. As an example, for the given MESSAGE, this is the TITLE: +export const GENERATE_CHAT_TITLE_PROMPT = (responseLanguage: string) => + ChatPromptTemplate.fromMessages([ + [ + 'system', + `You are a helpful assistant for Elastic Security. Assume the following user message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you. Please create the title in ${responseLanguage}. As an example, for the given MESSAGE, this is the TITLE: MESSAGE: I am having trouble with the Elastic Security app. TITLE: Troubleshooting Elastic Security app issues `, - ], - ['human', '{input}'], -]); + ], + ['human', '{input}'], + ]); export interface GenerateChatTitleParams extends NodeParamsBase { - conversationsDataClient?: AIAssistantConversationsDataClient; - conversationId?: string; + responseLanguage: string; state: AgentState; } export const GENERATE_CHAT_TITLE_NODE = 'generateChatTitle'; export const generateChatTitle = async ({ - conversationsDataClient, - conversationId, + responseLanguage, logger, model, state, }: GenerateChatTitleParams) => { logger.debug(`Node state:\n ${JSON.stringify(state, null, 2)}`); - if (!conversationId) { - if (state.messages.length > 0) { - - } - } - const conversation = await conversationsDataClient?.getConversation({ id: conversationId }); if (state.messages.length !== 0) { logger.debug('No need to generate chat title, messages already exist'); - return; + return { chatTitle: '' }; } const outputParser = new StringOutputParser(); - const graph = GENERATE_CHAT_TITLE_PROMPT.pipe(model).pipe(outputParser); + const graph = GENERATE_CHAT_TITLE_PROMPT(responseLanguage).pipe(model).pipe(outputParser); const chatTitle = await graph.invoke({ input: JSON.stringify(state.input, null, 2), }); - logger.debug(`chatTitle: ${chatTitle}`); return { diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/get_persisted_conversation.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/get_persisted_conversation.ts new file mode 100644 index 0000000000000..29bc67097a178 --- /dev/null +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/get_persisted_conversation.ts @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { AgentState, NodeParamsBase } from '../types'; +import { AIAssistantConversationsDataClient } from '../../../../../ai_assistant_data_clients/conversations'; +import { getLangChainMessages } from '../../../helpers'; + +export interface GetPersistedConversationParams extends NodeParamsBase { + conversationsDataClient?: AIAssistantConversationsDataClient; + conversationId?: string; + state: AgentState; +} + +export const GET_PERSISTED_CONVERSATION_NODE = 'getPersistedConversation'; + +export const getPersistedConversation = async ({ + conversationsDataClient, + conversationId, + logger, + state, +}: GetPersistedConversationParams) => { + logger.debug(`Node state:\n ${JSON.stringify(state, null, 2)}`); + if (!conversationId) { + logger.debug('Cannot get conversation, because conversationId is undefined'); + return { + conversation: undefined, + messages: [], + chatTitle: '', + }; + } + + const conversation = await conversationsDataClient?.getConversation({ id: conversationId }); + if (!conversation) { + logger.debug('Requested conversation, because conversation is undefined'); + return { + conversation: undefined, + messages: [], + chatTitle: '', + }; + } + + logger.debug(`conversationId: ${conversationId}`); + + const messages = getLangChainMessages(conversation.messages ?? []); + return { + conversation, + messages, + chatTitle: conversation.title, + }; +}; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/persist_conversation_changes.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/persist_conversation_changes.ts new file mode 100644 index 0000000000000..39331310e0f0e --- /dev/null +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/persist_conversation_changes.ts @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { replaceAnonymizedValuesWithOriginalValues } from '@kbn/elastic-assistant-common'; +import { AgentState, NodeParamsBase } from '../types'; +import { AIAssistantConversationsDataClient } from '../../../../../ai_assistant_data_clients/conversations'; +import { getLangChainMessages } from '../../../helpers'; + +export interface PersistConversationChangesParams extends NodeParamsBase { + conversationsDataClient?: AIAssistantConversationsDataClient; + conversationId?: string; + state: AgentState; +} + +export const PERSIST_CONVERSATION_CHANGES_NODE = 'persistConversationChanges'; + +export const persistConversationChanges = async ({ + conversationsDataClient, + conversationId, + logger, + state, +}: PersistConversationChangesParams) => { + logger.debug(`Node state:\n ${JSON.stringify(state, null, 2)}`); + + if (!state.conversation || !conversationId) { + logger.debug('No need to generate chat title, conversationId is undefined'); + return { + conversation: undefined, + messages: [], + }; + } + + let conversation; + if (state.conversation?.title !== state.chatTitle) { + conversation = await conversationsDataClient?.updateConversation({ + conversationUpdateProps: { + id: conversationId, + title: state.chatTitle, + }, + }); + } + + const updatedConversation = await conversationsDataClient?.appendConversationMessages({ + existingConversation: conversation ? conversation : state.conversation, + messages: [ + { + content: replaceAnonymizedValuesWithOriginalValues({ + messageContent: state.input, + replacements: {}, + }), + role: 'user', + timestamp: new Date().toISOString(), + }, + ], + }); + if (!updatedConversation) { + logger.debug('Not updated conversation'); + return { conversation: undefined, messages: [] }; + } + + logger.debug(`conversationId: ${conversationId}`); + const langChainMessages = getLangChainMessages(updatedConversation.messages ?? []); + const messages = langChainMessages.slice(0, -1); // all but the last message + + return { + conversation: updatedConversation, + messages, + }; +}; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/should_continue.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/should_continue.ts index 281963df363a8..046c4a86d4c7a 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/should_continue.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/should_continue.ts @@ -5,6 +5,7 @@ * 2.0. */ +import { NEW_CHAT } from '../../../../../routes/helpers'; import { AgentState, NodeParamsBase } from '../types'; export interface ShouldContinueParams extends NodeParamsBase { @@ -26,3 +27,31 @@ export const shouldContinue = ({ logger, state }: ShouldContinueParams) => { return 'continue'; }; + +export const shouldContinueGenerateTitle = ({ logger, state }: ShouldContinueParams) => { + logger.debug(`Node state:\n${JSON.stringify(state, null, 2)}`); + + if (state.conversation?.title !== NEW_CHAT) { + return 'end'; + } + + return 'continue'; +}; + +export interface ShouldContinueGetConversation extends NodeParamsBase { + state: AgentState; + conversationId?: string; +} +export const shouldContinueGetConversation = ({ + logger, + state, + conversationId, +}: ShouldContinueGetConversation) => { + logger.debug(`Node state:\n${JSON.stringify(state, null, 2)}`); + + if (!conversationId) { + return 'end'; + } + + return 'continue'; +}; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts index 1d19646fb6eb3..4ee4f1ba1b148 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts @@ -9,6 +9,7 @@ import { BaseMessage } from '@langchain/core/messages'; import { AgentAction, AgentFinish, AgentStep } from '@langchain/core/agents'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import type { Logger } from '@kbn/logging'; +import { ConversationResponse } from '@kbn/elastic-assistant-common'; export interface AgentStateBase { agentOutcome?: AgentAction | AgentFinish; @@ -18,6 +19,8 @@ export interface AgentStateBase { export interface AgentState extends AgentStateBase { input: string; messages: BaseMessage[]; + chatTitle: string; + conversation: ConversationResponse | undefined; } export interface NodeParamsBase { diff --git a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.ts b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.ts index 7859d635ccb30..bd3af4b5513ee 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.ts @@ -64,6 +64,7 @@ export const postAttackDiscoveryRoute = ( try { // get the actions plugin start contract from the request context: const actions = (await context.elasticAssistant).actions; + const actionsClient = await actions.getActionsClientWithRequest(request); const pluginName = getPluginNameFromRequest({ request, defaultPluginName: DEFAULT_PLUGIN_NAME, @@ -110,11 +111,10 @@ export const postAttackDiscoveryRoute = ( }; const llm = new ActionsClientLlm({ - actions, + actionsClient, connectorId, llmType: getLlmType(actionTypeId), logger, - request, temperature: 0, // zero temperature for attack discovery, because we want structured JSON output timeout: CONNECTOR_TIMEOUT, traceOptions, diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index 2cb1ba286a77b..78183adb16e23 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -21,12 +21,16 @@ import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; import { ElasticAssistantPluginRouter, GetElser } from '../../types'; import { buildResponse } from '../utils'; import { + DEFAULT_PLUGIN_NAME, UPGRADE_LICENSE_MESSAGE, appendAssistantMessageToConversation, createOrUpdateConversationWithUserInput, + getPluginNameFromRequest, hasAIAssistantLicense, langChainExecute, } from '../helpers'; +import { transformESSearchToAnonymizationFields } from '../../ai_assistant_data_clients/anonymization_fields/helpers'; +import { EsAnonymizationFieldsSchema } from '../../ai_assistant_data_clients/anonymization_fields/types'; export const SYSTEM_PROMPT_CONTEXT_NON_I18N = (context: string) => { return `CONTEXT:\n"""\n${context}\n"""`; @@ -103,16 +107,22 @@ export const chatCompleteRoute = ( if (request.body.messages) { // replacements - const systemAnonymizationFields = await anonymizationFieldsDataClient?.findDocuments({ - page: 1, - perPage: 1000, - }); + const anonymizationFieldsRes = + await anonymizationFieldsDataClient?.findDocuments({ + perPage: 1000, + page: 1, + }); + + const anonymizationFields = anonymizationFieldsRes + ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) + : undefined; + // anonymize messages before sending to LLM messages = request.body.messages.map((m) => { let content = m.content ?? ''; if (m.data && m.data.length > 0) { const anonymizedData = transformRawData({ - anonymizationFields: systemAnonymizationFields?.data, + anonymizationFields, currentReplacements: latestReplacements, getAnonymizedValue, onNewReplacements, @@ -131,11 +141,20 @@ export const chatCompleteRoute = ( }); } - if (request.body.persist && conversationsDataClient) { + // Fetch any tools registered by the request's originating plugin + const pluginName = getPluginNameFromRequest({ + request, + defaultPluginName: DEFAULT_PLUGIN_NAME, + logger, + }); + const enableKnowledgeBaseByDefault = ( + await context.elasticAssistant + ).getRegisteredFeatures(pluginName).assistantKnowledgeBaseByDefault; + // TODO: remove non-graph persistance when KB will be enabled by default + if (!enableKnowledgeBaseByDefault && request.body.persist && conversationsDataClient) { const updatedConversation = await createOrUpdateConversationWithUserInput({ actionsClient, actionTypeId, - authenticatedUser, connectorId, conversationId, conversationsDataClient, @@ -178,7 +197,6 @@ export const chatCompleteRoute = ( abortSignal, actionsClient, actionTypeId, - assistantContext: ctx.elasticAssistant, connectorId, context, getElser, @@ -190,6 +208,7 @@ export const chatCompleteRoute = ( request, response, telemetry, + responseLanguage: request.body.responseLanguage, }); } catch (err) { const error = transformError(err as Error); diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index c1835109fe580..62306d305b7c6 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -7,7 +7,6 @@ import { AnalyticsServiceSetup, - AuthenticatedUser, KibanaRequest, KibanaResponseFactory, Logger, @@ -30,21 +29,19 @@ import { MINIMUM_AI_ASSISTANT_LICENSE } from '../../common/constants'; import { ESQL_RESOURCE, KNOWLEDGE_BASE_INDEX_PATTERN } from './knowledge_base/constants'; import { callAgentExecutor } from '../lib/langchain/execute_custom_llm_chain'; import { getLlmType } from './utils'; -import { AgentExecutorParams, AssistantDataClients, StaticReturnType } from '../lib/langchain/executors/types'; +import { + AgentExecutorParams, + AssistantDataClients, + StaticReturnType, +} from '../lib/langchain/executors/types'; import { executeAction, StaticResponse } from '../lib/executor'; import { getLangChainMessages } from '../lib/langchain/helpers'; import { getLangSmithTracer } from './evaluate/utils'; -import { EsAnonymizationFieldsSchema } from '../ai_assistant_data_clients/anonymization_fields/types'; -import { transformESSearchToAnonymizationFields } from '../ai_assistant_data_clients/anonymization_fields/helpers'; import { ElasticsearchStore } from '../lib/langchain/elasticsearch_store/elasticsearch_store'; import { AIAssistantConversationsDataClient } from '../ai_assistant_data_clients/conversations'; import { INVOKE_ASSISTANT_SUCCESS_EVENT } from '../lib/telemetry/event_based_telemetry'; -import { - ElasticAssistantApiRequestHandlerContext, - ElasticAssistantRequestHandlerContext, - GetElser, -} from '../types'; +import { ElasticAssistantRequestHandlerContext, GetElser } from '../types'; import { callAssistantGraph } from '../lib/langchain/graphs/default_assistant_graph'; interface GetPluginNameFromRequestParams { @@ -190,7 +187,7 @@ export const generateTitleForNewChatConversation = async ({ export interface AppendMessageToConversationParams { conversationsDataClient: AIAssistantConversationsDataClient; - messages: Array>; + messages: Array>; replacements: Replacements; conversation: ConversationResponse; } @@ -210,7 +207,7 @@ export const appendMessageToConversation = async ({ }), role: m.role ?? 'user', }, - timestamp: m.timestamp ?? new Date().toISOString(), + timestamp: new Date().toISOString(), })), }); return updatedConversation; @@ -323,7 +320,7 @@ export interface LangChainExecuteParams { telemetry: AnalyticsServiceSetup; actionTypeId: string; connectorId: string; - assistantContext: ElasticAssistantApiRequestHandlerContext; + conversationId?: string; context: ElasticAssistantRequestHandlerContext; actionsClient: PublicMethodsOf; // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -336,6 +333,7 @@ export interface LangChainExecuteParams { ) => Promise; getElser: GetElser; response: KibanaResponseFactory; + responseLanguage?: string; } export const langChainExecute = async ({ messages, @@ -345,14 +343,15 @@ export const langChainExecute = async ({ telemetry, actionTypeId, connectorId, - assistantContext, context, actionsClient, request, logger, + conversationId, onLlmResponse, getElser, response, + responseLanguage, }: LangChainExecuteParams) => { // TODO: Add `traceId` to actions request when calling via langchain logger.debug( @@ -364,6 +363,7 @@ export const langChainExecute = async ({ defaultPluginName: DEFAULT_PLUGIN_NAME, logger, }); + const assistantContext = await context.elasticAssistant; const assistantTools = assistantContext .getRegisteredTools(pluginName) .filter((x) => x.id !== 'attack-discovery'); // We don't (yet) support asking the assistant for NEW attack discoveries from a conversation @@ -378,12 +378,6 @@ export const langChainExecute = async ({ const anonymizationFieldsDataClient = await assistantContext.getAIAssistantAnonymizationFieldsDataClient(); - const anonymizationFieldsRes = - await anonymizationFieldsDataClient?.findDocuments({ - perPage: 1000, - page: 1, - }); - const conversationsDataClient = await assistantContext.getAIAssistantConversationsDataClient(); // Create an ElasticsearchStore for KB interactions @@ -418,12 +412,10 @@ export const langChainExecute = async ({ abortSignal, dataClients, alertsIndexPattern: request.body.alertsIndexPattern, - anonymizationFields: anonymizationFieldsRes - ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) - : undefined, actionsClient, isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase ?? false, assistantTools, + conversationId, connectorId, esClient, esStore, @@ -435,6 +427,7 @@ export const langChainExecute = async ({ onLlmResponse, request, replacements, + responseLanguage, size: request.body.size, traceOptions: { projectName: request.body.langSmithProject, @@ -454,7 +447,6 @@ export const langChainExecute = async ({ result = await callAgentExecutor(executorParams); } - telemetry.reportEvent(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { actionTypeId, isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase, @@ -464,7 +456,6 @@ export const langChainExecute = async ({ // tracked here: https://github.com/elastic/security-team/issues/7363 assistantStreamingEnabled: request.body.subAction !== 'invokeAI' && actionTypeId === '.gen-ai', }); - return response.ok(result); }; @@ -479,7 +470,6 @@ export interface CreateOrUpdateConversationWithParams { actionsClient: PublicMethodsOf; newMessages?: Array>; model?: string; - authenticatedUser: AuthenticatedUser; responseLanguage?: string; } export const createOrUpdateConversationWithUserInput = async ({ @@ -493,7 +483,6 @@ export const createOrUpdateConversationWithUserInput = async ({ actionsClient, newMessages, model, - authenticatedUser, responseLanguage, }: CreateOrUpdateConversationWithParams) => { if (!conversationId) { @@ -528,7 +517,6 @@ export const createOrUpdateConversationWithUserInput = async ({ return updateConversationWithUserInput({ actionsClient, actionTypeId, - authenticatedUser, connectorId, conversationId, conversationsDataClient, @@ -547,10 +535,8 @@ export interface UpdateConversationWithParams { actionTypeId: string; connectorId: string; actionsClient: PublicMethodsOf; - newMessages?: Array>; + newMessages?: Array>; model?: string; - authenticatedUser: AuthenticatedUser; - responseLanguage?: string; } export const updateConversationWithUserInput = async ({ logger, @@ -562,32 +548,7 @@ export const updateConversationWithUserInput = async ({ actionsClient, newMessages, model, - authenticatedUser, - responseLanguage, }: UpdateConversationWithParams) => { - if (!conversationId) { - if (newMessages && newMessages.length > 0) { - const title = await generateTitleForNewChatConversation({ - message: newMessages[0], - actionsClient, - actionTypeId, - connectorId, - logger, - responseLanguage, - model, - }); - if (title) { - return conversationsDataClient.createConversation({ - conversation: { - title, - messages: newMessages, - replacements, - }, - }); - } - } - return; - } const conversation = await conversationsDataClient?.getConversation({ id: conversationId, }); @@ -603,8 +564,6 @@ export const updateConversationWithUserInput = async ({ })); const lastMessage = newMessages?.[0] ?? messages?.[0]; - console.log(conversation?.title) - console.log(lastMessage) if (conversation?.title === NEW_CHAT && lastMessage) { const title = await generateTitleForNewChatConversation({ diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 0504779d5d57e..a1d89928efc8e 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -17,14 +17,18 @@ import { Replacements, } from '@kbn/elastic-assistant-common'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; -import { - INVOKE_ASSISTANT_ERROR_EVENT, -} from '../lib/telemetry/event_based_telemetry'; +import { INVOKE_ASSISTANT_ERROR_EVENT } from '../lib/telemetry/event_based_telemetry'; import { POST_ACTIONS_CONNECTOR_EXECUTE } from '../../common/constants'; import { buildResponse } from '../lib/build_response'; import { ElasticAssistantRequestHandlerContext, GetElser } from '../types'; -import { appendAssistantMessageToConversation, langChainExecute, nonLangChainExecute, updateConversationWithUserInput } from './helpers'; - +import { + DEFAULT_PLUGIN_NAME, + appendAssistantMessageToConversation, + getPluginNameFromRequest, + langChainExecute, + nonLangChainExecute, + updateConversationWithUserInput, +} from './helpers'; export const postActionsConnectorExecuteRoute = ( router: IRouter, @@ -72,7 +76,7 @@ export const postActionsConnectorExecuteRoute = ( }; let messages; - let newMessage: Pick | undefined; + let newMessage: Pick | undefined; const conversationId = request.body.conversationId; const actionTypeId = request.body.actionTypeId; const connectorId = decodeURIComponent(request.params.connectorId); @@ -82,7 +86,6 @@ export const postActionsConnectorExecuteRoute = ( newMessage = { content: request.body.message, role: 'user', - timestamp: new Date().toISOString(), }; } @@ -93,11 +96,20 @@ export const postActionsConnectorExecuteRoute = ( const conversationsDataClient = await assistantContext.getAIAssistantConversationsDataClient(); - if (conversationId && conversationsDataClient) { + // Fetch any tools registered by the request's originating plugin + const pluginName = getPluginNameFromRequest({ + request, + defaultPluginName: DEFAULT_PLUGIN_NAME, + logger, + }); + const enableKnowledgeBaseByDefault = + assistantContext.getRegisteredFeatures(pluginName).assistantKnowledgeBaseByDefault; + + // TODO: remove non-graph persistance when KB will be enabled by default + if (!enableKnowledgeBaseByDefault && conversationId && conversationsDataClient) { const updatedConversation = await updateConversationWithUserInput({ actionsClient, actionTypeId, - authenticatedUser, connectorId, conversationId, conversationsDataClient, @@ -155,12 +167,12 @@ export const postActionsConnectorExecuteRoute = ( abortSignal, actionsClient, actionTypeId, - assistantContext, connectorId, + conversationId, context, getElser, logger, - messages: messages ?? [], + messages: (enableKnowledgeBaseByDefault && newMessage ? [newMessage] : messages) ?? [], onLlmResponse, onNewReplacements, replacements: latestReplacements, diff --git a/x-pack/plugins/search_playground/server/lib/get_chat_params.ts b/x-pack/plugins/search_playground/server/lib/get_chat_params.ts index a481309b9277d..9740988c3e1e2 100644 --- a/x-pack/plugins/search_playground/server/lib/get_chat_params.ts +++ b/x-pack/plugins/search_playground/server/lib/get_chat_params.ts @@ -46,9 +46,8 @@ export const getChatParams = async ( switch (connector.actionTypeId) { case OPENAI_CONNECTOR_ID: chatModel = new ActionsClientChatOpenAI({ - actions, + actionsClient, logger, - request, connectorId, model, traceId: uuidv4(), @@ -67,9 +66,8 @@ export const getChatParams = async ( case BEDROCK_CONNECTOR_ID: const llmType = 'bedrock'; chatModel = new ActionsClientLlm({ - actions, + actionsClient, logger, - request, connectorId, model, traceId: uuidv4(), From 4c1f4655bfe5b6c0bba05c9a3365a4b146b9e240 Mon Sep 17 00:00:00 2001 From: YulNaumenko Date: Sun, 23 Jun 2024 18:43:02 -0700 Subject: [PATCH 06/23] removed timestamp --- .../chat/post_chat_complete_route.gen.ts | 8 +---- .../chat/post_chat_complete_route.schema.yaml | 4 --- .../graphs/default_assistant_graph/helpers.ts | 6 ++++ .../server/routes/chat/chat_complete_route.ts | 36 +++++++++---------- .../server/routes/helpers.ts | 23 ++++++++---- .../routes/post_actions_connector_execute.ts | 34 +++++++++--------- .../server/routes/register_routes.ts | 2 +- 7 files changed, 58 insertions(+), 55 deletions(-) diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts index acc984d06caaf..d6430e0b9fc89 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts @@ -5,8 +5,6 @@ * 2.0. */ -import { z } from 'zod'; - /* * NOTICE: Do not edit this file manually. * This file is automatically generated by the OpenAPI Generator, @kbn/openapi-generator. @@ -16,7 +14,7 @@ import { z } from 'zod'; * version: 2023-10-31 */ -import { NonEmptyString } from '../common_attributes.gen'; +import { z } from 'zod'; export type RootContext = z.infer; export const RootContext = z.literal('security'); @@ -45,10 +43,6 @@ export const ChatMessage = z.object({ * Message role. */ role: ChatMessageRole, - /** - * The timestamp message was sent or received. - */ - '@timestamp': NonEmptyString, /** * ECS objects array to attach to the context of the message. */ diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml index a87f0d8b447c3..adbbf44e2ac73 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml @@ -62,7 +62,6 @@ components: type: object description: AI assistant message. required: - - '@timestamp' - 'role' properties: content: @@ -71,9 +70,6 @@ components: role: $ref: '#/components/schemas/ChatMessageRole' description: Message role. - '@timestamp': - $ref: '../common_attributes.schema.yaml#/components/schemas/NonEmptyString' - description: The timestamp message was sent or received. data: description: ECS objects array to attach to the context of the message. type: array diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 383b3e9f5cee8..23d1f6710160e 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -109,6 +109,12 @@ export const streamGraph = async ({ } } } + } else if (event.event === 'on_llm_end') { + const generations = event.data.output?.generations[0]; + if (generations && generations[0]?.generationInfo.finish_reason === 'stop') { + console.log(`good.... ${JSON.stringify(finalMessage)}`); + handleStreamEnd(finalMessage); + } } await processEvent(); diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index 78183adb16e23..be2bf873f089e 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -65,7 +65,6 @@ export const chatCompleteRoute = ( const ctx = await context.resolve(['core', 'elasticAssistant', 'licensing']); const logger: Logger = ctx.elasticAssistant.logger; const telemetry = ctx.elasticAssistant.telemetry; - let onLlmResponse; const license = ctx.licensing.license; if (!hasAIAssistantLicense(license)) { @@ -135,7 +134,6 @@ export const chatCompleteRoute = ( const transformedMessage = { role: m.role, content, - timestamp: m['@timestamp'], }; return transformedMessage; }); @@ -174,25 +172,25 @@ export const chatCompleteRoute = ( role: c.role, content: c.content, })); - - onLlmResponse = async ( - content: string, - traceData: Message['traceData'] = {}, - isError = false - ): Promise => { - if (updatedConversation && conversationsDataClient) { - await appendAssistantMessageToConversation({ - conversation: updatedConversation, - conversationsDataClient, - messageContent: content, - replacements: latestReplacements, - isError, - traceData, - }); - } - }; } + const onLlmResponse = async ( + content: string, + traceData: Message['traceData'] = {}, + isError = false + ): Promise => { + if (conversationId && conversationsDataClient) { + await appendAssistantMessageToConversation({ + conversationId, + conversationsDataClient, + messageContent: content, + replacements: latestReplacements, + isError, + traceData, + }); + } + }; + return await langChainExecute({ abortSignal, actionsClient, diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index 62306d305b7c6..ce7b9d0e0bca0 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -217,7 +217,7 @@ export interface AppendAssistantMessageToConversationParams { conversationsDataClient: AIAssistantConversationsDataClient; messageContent: string; replacements: Replacements; - conversation: ConversationResponse; + conversationId: string; isError?: boolean; traceData?: Message['traceData']; } @@ -225,11 +225,16 @@ export const appendAssistantMessageToConversation = async ({ conversationsDataClient, messageContent, replacements, - conversation, + conversationId, isError = false, traceData = {}, }: AppendAssistantMessageToConversationParams) => { - await conversationsDataClient?.appendConversationMessages({ + const conversation = await conversationsDataClient.getConversation({ id: conversationId }); + if (!conversation) { + return; + } + + await conversationsDataClient.appendConversationMessages({ existingConversation: conversation, messages: [ getMessageFromRawResponse({ @@ -449,8 +454,8 @@ export const langChainExecute = async ({ telemetry.reportEvent(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { actionTypeId, - isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase, - isEnabledRAGAlerts: request.body.isEnabledRAGAlerts, + isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase ?? true, + isEnabledRAGAlerts: request.body.isEnabledRAGAlerts ?? true, model: request.body.model, // TODO rm actionTypeId check when llmClass for bedrock streaming is implemented // tracked here: https://github.com/elastic/security-team/issues/7363 @@ -468,7 +473,7 @@ export interface CreateOrUpdateConversationWithParams { actionTypeId: string; connectorId: string; actionsClient: PublicMethodsOf; - newMessages?: Array>; + newMessages?: Array>; model?: string; responseLanguage?: string; } @@ -500,7 +505,11 @@ export const createOrUpdateConversationWithUserInput = async ({ return conversationsDataClient.createConversation({ conversation: { title, - messages: newMessages, + messages: newMessages.map((m) => ({ + content: m.content, + role: m.role, + timestamp: new Date().toISOString(), + })), replacements, apiConfig: { connectorId, diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index a1d89928efc8e..eb86d9798c1e2 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -128,25 +128,25 @@ export const postActionsConnectorExecuteRoute = ( role: c.role, content: c.content, })); - - onLlmResponse = async ( - content: string, - traceData: Message['traceData'] = {}, - isError = false - ): Promise => { - if (updatedConversation && conversationsDataClient) { - await appendAssistantMessageToConversation({ - conversation: updatedConversation, - conversationsDataClient, - messageContent: content, - replacements: latestReplacements, - isError, - traceData, - }); - } - }; } + onLlmResponse = async ( + content: string, + traceData: Message['traceData'] = {}, + isError = false + ): Promise => { + if (conversationsDataClient && conversationId) { + await appendAssistantMessageToConversation({ + conversationId, + conversationsDataClient, + messageContent: content, + replacements: latestReplacements, + isError, + traceData, + }); + } + }; + if (!request.body.isEnabledKnowledgeBase && !request.body.isEnabledRAGAlerts) { // if not langchain, call execute action directly and return the response: return await nonLangChainExecute({ diff --git a/x-pack/plugins/elastic_assistant/server/routes/register_routes.ts b/x-pack/plugins/elastic_assistant/server/routes/register_routes.ts index 81d9ec2911061..4df24a78a9e74 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/register_routes.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/register_routes.ts @@ -39,7 +39,7 @@ export const registerRoutes = ( ) => { /** PUBLIC */ // Chat - chatCompleteRoute(router); + chatCompleteRoute(router, getElserId); /** INTERNAL */ // Capabilities From e48145003182df516305cd2eb2ab2144c611d557 Mon Sep 17 00:00:00 2001 From: YulNaumenko Date: Sun, 23 Jun 2024 21:32:06 -0700 Subject: [PATCH 07/23] fixed streaming --- .../graphs/default_assistant_graph/index.ts | 1 - .../server/routes/chat/chat_complete_route.ts | 83 +++++++++++-------- .../server/routes/helpers.ts | 14 ++-- .../routes/post_actions_connector_execute.ts | 2 + 4 files changed, 59 insertions(+), 41 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 0c3a0cccb9f07..a7be83f79a627 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -80,7 +80,6 @@ export const callAssistantGraph: AgentExecutor = async ({ ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) : undefined; - const messages = langChainMessages.slice(0, -1); // all but the last message const latestMessage = langChainMessages.slice(-1); // the last message const modelExists = await esStore.isModelInstalled(); diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index be2bf873f089e..dfe1cce165596 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -100,44 +100,54 @@ export const chatCompleteRoute = ( // get the actions plugin start contract from the request context: const actions = (await context.elasticAssistant).actions; const actionsClient = await actions.getActionsClientWithRequest(request); - const actionTypeId = - (await actionsClient.getAllSystemConnectors()).find((c) => c.id === connectorId) - ?.actionTypeId ?? '.gen-ai'; - - if (request.body.messages) { - // replacements - const anonymizationFieldsRes = - await anonymizationFieldsDataClient?.findDocuments({ - perPage: 1000, - page: 1, - }); + const connectors = await actionsClient.getBulk({ ids: [connectorId] }); + const actionTypeId = connectors.length > 0 ? connectors[0].actionTypeId : '.gen-ai'; + + // replacements + const anonymizationFieldsRes = + await anonymizationFieldsDataClient?.findDocuments({ + perPage: 1000, + page: 1, + }); - const anonymizationFields = anonymizationFieldsRes - ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) - : undefined; - - // anonymize messages before sending to LLM - messages = request.body.messages.map((m) => { - let content = m.content ?? ''; - if (m.data && m.data.length > 0) { - const anonymizedData = transformRawData({ - anonymizationFields, - currentReplacements: latestReplacements, - getAnonymizedValue, - onNewReplacements, - rawData: m.data as unknown as Record, + let anonymizationFields = anonymizationFieldsRes + ? transformESSearchToAnonymizationFields(anonymizationFieldsRes.data) + : undefined; + + // anonymize messages before sending to LLM + messages = request.body.messages.map((m) => { + let content = m.content ?? ''; + if (m.data && m.data.length > 0) { + // includes/anonymize fields from the messages data + if (m.fields_to_anonymize && m.fields_to_anonymize.length > 0) { + anonymizationFields = anonymizationFields?.map((a) => { + if (m.fields_to_anonymize?.includes(a.field)) { + return { + ...a, + allowed: true, + anonymized: true, + }; + } + return a; }); - const wr = `${SYSTEM_PROMPT_CONTEXT_NON_I18N(anonymizedData)}\n`; - - content = `${wr}\n${m.content}`; } - const transformedMessage = { - role: m.role, - content, - }; - return transformedMessage; - }); - } + const anonymizedData = transformRawData({ + anonymizationFields, + currentReplacements: latestReplacements, + getAnonymizedValue, + onNewReplacements, + rawData: m.data as unknown as Record, + }); + const wr = `${SYSTEM_PROMPT_CONTEXT_NON_I18N(anonymizedData)}\n`; + + content = `${wr}\n${m.content}`; + } + const transformedMessage = { + role: m.role, + content, + }; + return transformedMessage; + }); // Fetch any tools registered by the request's originating plugin const pluginName = getPluginNameFromRequest({ @@ -193,9 +203,12 @@ export const chatCompleteRoute = ( return await langChainExecute({ abortSignal, + isEnabledKnowledgeBase: true, + isStream: false, actionsClient, actionTypeId, connectorId, + conversationId, context, getElser, logger, diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index ce7b9d0e0bca0..5c2aca29ed84a 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -320,6 +320,8 @@ export const nonLangChainExecute = async ({ export interface LangChainExecuteParams { messages: Array>; replacements: Replacements; + isEnabledKnowledgeBase: boolean; + isStream?: boolean; onNewReplacements: (newReplacements: Replacements) => void; abortSignal: AbortSignal; telemetry: AnalyticsServiceSetup; @@ -344,6 +346,7 @@ export const langChainExecute = async ({ messages, replacements, onNewReplacements, + isEnabledKnowledgeBase, abortSignal, telemetry, actionTypeId, @@ -357,10 +360,11 @@ export const langChainExecute = async ({ getElser, response, responseLanguage, + isStream = true, }: LangChainExecuteParams) => { // TODO: Add `traceId` to actions request when calling via langchain logger.debug( - `Executing via langchain, isEnabledKnowledgeBase: ${request.body.isEnabledKnowledgeBase}, isEnabledRAGAlerts: ${request.body.isEnabledRAGAlerts}` + `Executing via langchain, isEnabledKnowledgeBase: ${isEnabledKnowledgeBase}, isEnabledRAGAlerts: ${request.body.isEnabledRAGAlerts}` ); // Fetch any tools registered by the request's originating plugin const pluginName = getPluginNameFromRequest({ @@ -418,13 +422,13 @@ export const langChainExecute = async ({ dataClients, alertsIndexPattern: request.body.alertsIndexPattern, actionsClient, - isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase ?? false, + isEnabledKnowledgeBase, assistantTools, conversationId, connectorId, esClient, esStore, - isStream: request.body.subAction !== 'invokeAI', + isStream, llmType: getLlmType(actionTypeId), langChainMessages, logger, @@ -454,12 +458,12 @@ export const langChainExecute = async ({ telemetry.reportEvent(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { actionTypeId, - isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase ?? true, + isEnabledKnowledgeBase, isEnabledRAGAlerts: request.body.isEnabledRAGAlerts ?? true, model: request.body.model, // TODO rm actionTypeId check when llmClass for bedrock streaming is implemented // tracked here: https://github.com/elastic/security-team/issues/7363 - assistantStreamingEnabled: request.body.subAction !== 'invokeAI' && actionTypeId === '.gen-ai', + assistantStreamingEnabled: isStream && actionTypeId === '.gen-ai', }); return response.ok(result); }; diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index eb86d9798c1e2..e89a9fdaca7f5 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -165,6 +165,8 @@ export const postActionsConnectorExecuteRoute = ( return await langChainExecute({ abortSignal, + isStream: request.body.subAction !== 'invokeAI', + isEnabledKnowledgeBase: request.body.isEnabledKnowledgeBase ?? false, actionsClient, actionTypeId, connectorId, From 178762b1f28fac7b6b89f3bb9892f6fa8f8bc31b Mon Sep 17 00:00:00 2001 From: YulNaumenko Date: Sun, 23 Jun 2024 21:38:35 -0700 Subject: [PATCH 08/23] renamed api --- .../kbn-elastic-assistant-common/constants.ts | 2 +- .../routes/chat/chat_complete_route.test.ts | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/x-pack/packages/kbn-elastic-assistant-common/constants.ts b/x-pack/packages/kbn-elastic-assistant-common/constants.ts index d0b5887851e71..96af59095ab87 100755 --- a/x-pack/packages/kbn-elastic-assistant-common/constants.ts +++ b/x-pack/packages/kbn-elastic-assistant-common/constants.ts @@ -7,7 +7,7 @@ export const ELASTIC_AI_ASSISTANT_INTERNAL_API_VERSION = '1'; -export const ELASTIC_AI_ASSISTANT_URL = '/api/elastic_assistant'; +export const ELASTIC_AI_ASSISTANT_URL = '/api/security_ai_assistant'; export const ELASTIC_AI_ASSISTANT_INTERNAL_URL = '/internal/elastic_assistant'; export const ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL = `${ELASTIC_AI_ASSISTANT_INTERNAL_URL}/current_user/conversations`; diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts index 5b2737436fa8c..c46e4d598c351 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts @@ -8,7 +8,6 @@ import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks'; import { requestContextMock } from '../../__mocks__/request_context'; import { serverMock } from '../../__mocks__/server'; -import { createConversationRoute } from './create_route'; import { getBasicEmptySearchResponse, getEmptyFindResult } from '../../__mocks__/response'; import { getCreateConversationRequest, requestMock } from '../../__mocks__/request'; import { @@ -16,10 +15,11 @@ import { getConversationMock, getQueryConversationParams, } from '../../__mocks__/conversations_schema.mock'; -import { ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL } from '@kbn/elastic-assistant-common'; import { AuthenticatedUser } from '@kbn/security-plugin-types-common'; +import { ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL } from '@kbn/elastic-assistant-common'; +import { chatCompleteRoute } from './chat_complete_route'; -describe('Create conversation route', () => { +describe('Chat complete route', () => { let server: ReturnType; let { clients, context } = requestContextMock.createTools(); const mockUser1 = { @@ -45,7 +45,8 @@ describe('Create conversation route', () => { elasticsearchClientMock.createSuccessTransportRequestPromise(getBasicEmptySearchResponse()) ); context.elasticAssistant.getCurrentUser.mockReturnValue(mockUser1); - createConversationRoute(server.router); + const mockGetElser = jest.fn().mockResolvedValue('.elser_model_2'); + chatCompleteRoute(server.router, mockGetElser); }); describe('status codes', () => { @@ -90,7 +91,7 @@ describe('Create conversation route', () => { test('disallows unknown title', async () => { const request = requestMock.create({ method: 'post', - path: ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL, + path: ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL, body: { ...getCreateConversationSchemaMock(), title: true, @@ -112,7 +113,7 @@ describe('Create conversation route', () => { test('is successful', async () => { const request = requestMock.create({ method: 'post', - path: ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL, + path: ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL, body: { ...getCreateConversationSchemaMock(), messages: [defaultMessage], @@ -128,7 +129,7 @@ describe('Create conversation route', () => { const request = requestMock.create({ method: 'post', - path: ELASTIC_AI_ASSISTANT_CONVERSATIONS_URL, + path: ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL, body: { ...getCreateConversationSchemaMock(), messages: [wrongMessage], From f74c072ddacdc76fe23309beaf72c8bbab050b66 Mon Sep 17 00:00:00 2001 From: YulNaumenko Date: Mon, 24 Jun 2024 06:29:37 -0700 Subject: [PATCH 09/23] changed data --- .../impl/schemas/chat/post_chat_complete_route.gen.ts | 4 ++-- .../impl/schemas/chat/post_chat_complete_route.schema.yaml | 6 ++---- .../server/routes/chat/chat_complete_route.ts | 7 +++++-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts index d6430e0b9fc89..25ba5a68030b3 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts @@ -44,9 +44,9 @@ export const ChatMessage = z.object({ */ role: ChatMessageRole, /** - * ECS objects array to attach to the context of the message. + * ECS object to attach to the context of the message. */ - data: z.array(MessageData).optional(), + data: MessageData.optional(), fields_to_anonymize: z.array(z.string()).optional(), }); diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml index adbbf44e2ac73..c8f783b5b5ab8 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml @@ -71,10 +71,8 @@ components: $ref: '#/components/schemas/ChatMessageRole' description: Message role. data: - description: ECS objects array to attach to the context of the message. - type: array - items: - $ref: '#/components/schemas/MessageData' + description: ECS object to attach to the context of the message. + $ref: '#/components/schemas/MessageData' fields_to_anonymize: type: array items: diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index dfe1cce165596..145842cc7619c 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -117,7 +117,7 @@ export const chatCompleteRoute = ( // anonymize messages before sending to LLM messages = request.body.messages.map((m) => { let content = m.content ?? ''; - if (m.data && m.data.length > 0) { + if (m.data) { // includes/anonymize fields from the messages data if (m.fields_to_anonymize && m.fields_to_anonymize.length > 0) { anonymizationFields = anonymizationFields?.map((a) => { @@ -136,7 +136,10 @@ export const chatCompleteRoute = ( currentReplacements: latestReplacements, getAnonymizedValue, onNewReplacements, - rawData: m.data as unknown as Record, + rawData: Object.keys(m.data).reduce( + (obj, key) => ({ ...obj, key: [m.data ? m.data[key] : ''] }), + {} + ), }); const wr = `${SYSTEM_PROMPT_CONTEXT_NON_I18N(anonymizedData)}\n`; From fabcea8cf8ff2da6d1bd938e3009fd3c30488ed8 Mon Sep 17 00:00:00 2001 From: kibanamachine <42973632+kibanamachine@users.noreply.github.com> Date: Mon, 24 Jun 2024 13:41:46 +0000 Subject: [PATCH 10/23] [CI] Auto-commit changed files from 'node scripts/lint_ts_projects --fix' --- x-pack/packages/kbn-langchain/tsconfig.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/packages/kbn-langchain/tsconfig.json b/x-pack/packages/kbn-langchain/tsconfig.json index 92dc5ebd33911..949aca47794ec 100644 --- a/x-pack/packages/kbn-langchain/tsconfig.json +++ b/x-pack/packages/kbn-langchain/tsconfig.json @@ -18,6 +18,6 @@ "@kbn/logging", "@kbn/actions-plugin", "@kbn/logging-mocks", - "@kbn/core-http-server" + "@kbn/utility-types" ] } From 004d5c8bc3ecc299b4943a88a35a21606b9d4aa1 Mon Sep 17 00:00:00 2001 From: YulNaumenko Date: Thu, 27 Jun 2024 15:36:59 -0700 Subject: [PATCH 11/23] fixed tests --- .../language_models/chat_openai.test.ts | 38 +- .../server/language_models/llm.test.ts | 43 +- .../language_models/simple_chat_model.test.ts | 49 +- .../graphs/default_assistant_graph/helpers.ts | 3 +- .../graphs/default_assistant_graph/index.ts | 2 +- .../nodes/get_persisted_conversation.ts | 3 + .../server/routes/attack_discovery/helpers.ts | 10 +- .../attack_discovery/post_attack_discovery.ts | 3 +- .../routes/chat/chat_complete_route.test.ts | 871 +++++++++++++++--- .../routes/post_actions_connector_execute.ts | 11 +- 10 files changed, 796 insertions(+), 237 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_openai.test.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_openai.test.ts index b73db19853982..92fd210fd7c53 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_openai.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_openai.test.ts @@ -7,10 +7,10 @@ import type OpenAI from 'openai'; import { Stream } from 'openai/streaming'; -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; import { loggerMock } from '@kbn/logging-mocks'; +import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; -import { ActionsClientChatOpenAI, ActionsClientChatOpenAIParams } from './chat_openai'; +import { ActionsClientChatOpenAI } from './chat_openai'; import { mockActionResponse, mockChatCompletion } from './mocks'; const connectorId = 'mock-connector-id'; @@ -19,11 +19,8 @@ const mockExecute = jest.fn(); const mockLogger = loggerMock.create(); -const mockActions = { - getActionsClientWithRequest: jest.fn().mockImplementation(() => ({ - execute: mockExecute, - })), -} as unknown as ActionsPluginStart; +const actionsClient = actionsClientMock.create(); + const chunk = { object: 'chat.completion.chunk', choices: [ @@ -40,27 +37,13 @@ export async function* asyncGenerator() { yield chunk; } const mockStreamExecute = jest.fn(); -const mockStreamActions = { - getActionsClientWithRequest: jest.fn().mockImplementation(() => ({ - execute: mockStreamExecute, - })), -} as unknown as ActionsPluginStart; const prompt = 'Do you know my name?'; const { signal } = new AbortController(); -const mockRequest = { - params: { connectorId }, - body: { - message: prompt, - subAction: 'invokeAI', - isEnabledKnowledgeBase: true, - }, -} as ActionsClientChatOpenAIParams['request']; - const defaultArgs = { - actionsClient: mockActions, + actionsClient, connectorId, logger: mockLogger, streaming: false, @@ -118,7 +101,7 @@ describe('ActionsClientChatOpenAI', () => { const actionsClientChatOpenAI = new ActionsClientChatOpenAI({ ...defaultArgs, streaming: true, - actions: mockStreamActions, + actionsClient, }); const result: AsyncIterable = @@ -177,16 +160,11 @@ describe('ActionsClientChatOpenAI', () => { serviceMessage: 'action-result-service-message', status: 'error', // <-- error status })); - - const badActions = { - getActionsClientWithRequest: jest.fn().mockImplementation(() => ({ - execute: hasErrorStatus, - })), - } as unknown as ActionsPluginStart; + actionsClient.execute.mockRejectedValueOnce(hasErrorStatus); const actionsClientChatOpenAI = new ActionsClientChatOpenAI({ ...defaultArgs, - actions: badActions, + actionsClient, }); expect(actionsClientChatOpenAI.completionWithRetry(defaultNonStreamingArgs)) diff --git a/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts b/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts index db04d3e5d7810..a5d92f89a93e5 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts @@ -5,15 +5,15 @@ * 2.0. */ -import { KibanaRequest } from '@kbn/core/server'; -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; import { loggerMock } from '@kbn/logging-mocks'; +import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; import { ActionsClientLlm } from './llm'; import { mockActionResponse } from './mocks'; const connectorId = 'mock-connector-id'; +const actionsClient = actionsClientMock.create(); const mockExecute = jest.fn().mockImplementation(() => ({ data: mockActionResponse, status: 'ok', @@ -21,23 +21,8 @@ const mockExecute = jest.fn().mockImplementation(() => ({ const mockLogger = loggerMock.create(); -const mockActions = { - getActionsClientWithRequest: jest.fn().mockImplementation(() => ({ - execute: mockExecute, - })), -} as unknown as ActionsPluginStart; - const prompt = 'Do you know my name?'; -const mockRequest: KibanaRequest = { - params: { connectorId }, - body: { - message: prompt, - subAction: 'invokeAI', - isEnabledKnowledgeBase: true, - }, -} as KibanaRequest; - describe('ActionsClientLlm', () => { beforeEach(() => { jest.clearAllMocks(); @@ -46,7 +31,7 @@ describe('ActionsClientLlm', () => { describe('getActionResultData', () => { it('returns the expected data', async () => { const actionsClientLlm = new ActionsClientLlm({ - actionsClient: mockActions, + actionsClient, connectorId, logger: mockLogger, }); @@ -60,10 +45,9 @@ describe('ActionsClientLlm', () => { describe('_llmType', () => { it('returns the expected LLM type', () => { const actionsClientLlm = new ActionsClientLlm({ - actions: mockActions, + actionsClient, connectorId, logger: mockLogger, - request: mockRequest, }); expect(actionsClientLlm._llmType()).toEqual('ActionsClientLlm'); @@ -71,11 +55,10 @@ describe('ActionsClientLlm', () => { it('returns the expected LLM type when overridden', () => { const actionsClientLlm = new ActionsClientLlm({ - actions: mockActions, + actionsClient, connectorId, llmType: 'special-llm-type', logger: mockLogger, - request: mockRequest, }); expect(actionsClientLlm._llmType()).toEqual('special-llm-type'); @@ -85,10 +68,9 @@ describe('ActionsClientLlm', () => { describe('_call', () => { it('returns the expected content when _call is invoked', async () => { const actionsClientLlm = new ActionsClientLlm({ - actions: mockActions, + actionsClient, connectorId, logger: mockLogger, - request: mockRequest, }); const result = await actionsClientLlm._call(prompt); @@ -103,17 +85,11 @@ describe('ActionsClientLlm', () => { status: 'error', // <-- error status })); - const badActions = { - getActionsClientWithRequest: jest.fn().mockImplementation(() => ({ - execute: hasErrorStatus, - })), - } as unknown as ActionsPluginStart; - + actionsClient.execute.mockRejectedValueOnce(hasErrorStatus); const actionsClientLlm = new ActionsClientLlm({ - actions: badActions, + actionsClient, connectorId, logger: mockLogger, - request: mockRequest, }); await expect(actionsClientLlm._call(prompt)).rejects.toThrowError( @@ -130,10 +106,9 @@ describe('ActionsClientLlm', () => { })); const actionsClientLlm = new ActionsClientLlm({ - actions: mockActions, + actionsClient, connectorId, logger: mockLogger, - request: mockRequest, }); await expect(actionsClientLlm._call(prompt)).rejects.toThrowError( diff --git a/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.test.ts b/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.test.ts index 504f6c9159f8c..40c7e19cfdd80 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.test.ts @@ -6,10 +6,10 @@ */ import { PassThrough } from 'stream'; -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; import { loggerMock } from '@kbn/logging-mocks'; +import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; -import { ActionsClientSimpleChatModel, CustomChatModelInput } from './simple_chat_model'; +import { ActionsClientSimpleChatModel } from './simple_chat_model'; import { mockActionResponse } from './mocks'; import { BaseMessage } from '@langchain/core/messages'; import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'; @@ -19,26 +19,15 @@ import { parseGeminiStream } from '../utils/gemini'; const connectorId = 'mock-connector-id'; const mockExecute = jest.fn(); +const actionsClient = actionsClientMock.create(); const mockLogger = loggerMock.create(); -const mockActions = { - getActionsClientWithRequest: jest.fn().mockImplementation(() => ({ - execute: mockExecute, - })), -} as unknown as ActionsPluginStart; - const mockStreamExecute = jest.fn().mockImplementation(() => ({ data: new PassThrough(), status: 'ok', })); -const mockStreamActions = { - getActionsClientWithRequest: jest.fn().mockImplementation(() => ({ - execute: mockStreamExecute, - })), -} as unknown as ActionsPluginStart; -const prompt = 'Do you know my name?'; const callMessages = [ { lc_serializable: true, @@ -78,17 +67,8 @@ const callRunManager = { handleLLMNewToken, } as unknown as CallbackManagerForLLMRun; -const mockRequest: CustomChatModelInput['request'] = { - params: { connectorId }, - body: { - message: prompt, - subAction: 'invokeAI', - isEnabledKnowledgeBase: true, - }, -} as CustomChatModelInput['request']; - const defaultArgs = { - actionsClient: mockActions, + actionsClient, connectorId, logger: mockLogger, streaming: false, @@ -158,15 +138,11 @@ describe('ActionsClientSimpleChatModel', () => { status: 'error', // <-- error status })); - const badActions = { - getActionsClientWithRequest: jest.fn().mockImplementation(() => ({ - execute: hasErrorStatus, - })), - } as unknown as ActionsPluginStart; + actionsClient.execute.mockRejectedValueOnce(hasErrorStatus); const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({ ...defaultArgs, - actions: badActions, + actionsClient, }); await expect( @@ -220,9 +196,10 @@ describe('ActionsClientSimpleChatModel', () => { (parseGeminiStream as jest.Mock).mockResolvedValue(mockActionResponse.message); }); it('returns the expected content when _call is invoked with streaming and llmType is Bedrock', async () => { + actionsClient.execute.mockImplementationOnce(mockStreamExecute); const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({ ...defaultArgs, - actions: mockStreamActions, + actionsClient, llmType: 'bedrock', streaming: true, maxTokens: 333, @@ -247,9 +224,10 @@ describe('ActionsClientSimpleChatModel', () => { expect(result).toEqual(mockActionResponse.message); }); it('returns the expected content when _call is invoked with streaming and llmType is Gemini', async () => { + actionsClient.execute.mockImplementationOnce(mockStreamExecute); const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({ ...defaultArgs, - actions: mockStreamActions, + actionsClient, llmType: 'gemini', streaming: true, maxTokens: 333, @@ -282,9 +260,11 @@ describe('ActionsClientSimpleChatModel', () => { handleToken(`, "action_input": "`); handleToken('token6'); }); + actionsClient.execute.mockImplementationOnce(mockStreamExecute); + const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({ ...defaultArgs, - actions: mockStreamActions, + actionsClient, llmType: 'bedrock', streaming: true, }); @@ -302,9 +282,10 @@ describe('ActionsClientSimpleChatModel', () => { handleToken('"'); handleToken('token7'); }); + actionsClient.execute.mockImplementationOnce(mockStreamExecute); const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel({ ...defaultArgs, - actions: mockStreamActions, + actionsClient, llmType: 'bedrock', streaming: true, }); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 23d1f6710160e..83c118b130546 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -112,12 +112,11 @@ export const streamGraph = async ({ } else if (event.event === 'on_llm_end') { const generations = event.data.output?.generations[0]; if (generations && generations[0]?.generationInfo.finish_reason === 'stop') { - console.log(`good.... ${JSON.stringify(finalMessage)}`); handleStreamEnd(finalMessage); } } - await processEvent(); + processEvent(); } catch (err) { // if I throw an error here, it crashes the server. Not sure how to get around that. // If I put await on this function the error works properly, but when there is not an error diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index a7be83f79a627..517ac10479461 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -134,7 +134,7 @@ export const callAssistantGraph: AgentExecutor = async ({ responseLanguage, replacements, }); - const inputs = { input: latestMessage[0].content as string }; + const inputs = { input: latestMessage[0]?.content as string }; if (isStream) { return streamGraph({ apmTracer, assistantGraph, inputs, logger, onLlmResponse, request }); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/get_persisted_conversation.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/get_persisted_conversation.ts index 29bc67097a178..6dbf284e462c4 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/get_persisted_conversation.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/get_persisted_conversation.ts @@ -30,6 +30,7 @@ export const getPersistedConversation = async ({ conversation: undefined, messages: [], chatTitle: '', + input: state.input, }; } @@ -40,6 +41,7 @@ export const getPersistedConversation = async ({ conversation: undefined, messages: [], chatTitle: '', + input: state.input, }; } @@ -50,5 +52,6 @@ export const getPersistedConversation = async ({ conversation, messages, chatTitle: conversation.title, + input: !state.input ? conversation.messages?.slice(-1)[0].content : state.input, }; }; diff --git a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.ts index 9dca7ee46cbda..73a82b448cd4b 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.ts @@ -23,9 +23,10 @@ import { ActionsClientLlm } from '@kbn/langchain/server'; import { Moment } from 'moment'; import { transformError } from '@kbn/securitysolution-es-utils'; -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import type { ActionsClient } from '@kbn/actions-plugin/server'; import moment from 'moment/moment'; import { uniq } from 'lodash/fp'; +import { PublicMethodsOf } from '@kbn/utility-types'; import { getLangSmithTracer } from '../evaluate/utils'; import { getLlmType } from '../utils'; import type { GetRegisteredTools } from '../../services/app_context'; @@ -52,7 +53,7 @@ export const REQUIRED_FOR_ATTACK_DISCOVERY: AnonymizationFieldResponse[] = [ ]; export const getAssistantToolParams = ({ - actions, + actionsClient, alertsIndexPattern, anonymizationFields, apiConfig, @@ -67,7 +68,7 @@ export const getAssistantToolParams = ({ request, size, }: { - actions: ActionsPluginStart; + actionsClient: PublicMethodsOf; alertsIndexPattern: string; anonymizationFields?: AnonymizationFieldResponse[]; apiConfig: ApiConfig; @@ -98,11 +99,10 @@ export const getAssistantToolParams = ({ }; const llm = new ActionsClientLlm({ - actions, + actionsClient, connectorId: apiConfig.connectorId, llmType: getLlmType(apiConfig.actionTypeId), logger, - request, temperature: 0, // zero temperature for attack discovery, because we want structured JSON output timeout: connectorTimeout, traceOptions, diff --git a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.ts b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.ts index 8ff2cd72ee36c..b9c680dde3d1d 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.ts @@ -70,6 +70,7 @@ export const postAttackDiscoveryRoute = ( try { // get the actions plugin start contract from the request context: const actions = (await context.elasticAssistant).actions; + const actionsClient = await actions.getActionsClientWithRequest(request); const dataClient = await assistantContext.getAttackDiscoveryDataClient(); const authenticatedUser = assistantContext.getCurrentUser(); if (authenticatedUser == null) { @@ -120,7 +121,7 @@ export const postAttackDiscoveryRoute = ( } const assistantToolParams = getAssistantToolParams({ - actions, + actionsClient, alertsIndexPattern, anonymizationFields, apiConfig, diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts index c46e4d598c351..0ebe03980ec8d 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts @@ -5,140 +5,761 @@ * 2.0. */ -import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks'; -import { requestContextMock } from '../../__mocks__/request_context'; -import { serverMock } from '../../__mocks__/server'; -import { getBasicEmptySearchResponse, getEmptyFindResult } from '../../__mocks__/response'; -import { getCreateConversationRequest, requestMock } from '../../__mocks__/request'; +import { ElasticsearchClient, IRouter, KibanaRequest, Logger } from '@kbn/core/server'; +import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import { BaseMessage } from '@langchain/core/messages'; +import { NEVER } from 'rxjs'; +import { mockActionResponse } from '../../__mocks__/action_result_data'; +import { ElasticAssistantRequestHandlerContext } from '../../types'; +import { elasticsearchServiceMock } from '@kbn/core-elasticsearch-server-mocks'; +import { loggingSystemMock } from '@kbn/core-logging-server-mocks'; +import { coreMock } from '@kbn/core/server/mocks'; import { - getCreateConversationSchemaMock, - getConversationMock, - getQueryConversationParams, -} from '../../__mocks__/conversations_schema.mock'; -import { AuthenticatedUser } from '@kbn/security-plugin-types-common'; -import { ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL } from '@kbn/elastic-assistant-common'; + INVOKE_ASSISTANT_ERROR_EVENT, + INVOKE_ASSISTANT_SUCCESS_EVENT, +} from '../../lib/telemetry/event_based_telemetry'; +import { PassThrough } from 'stream'; +import { getConversationResponseMock } from '../../ai_assistant_data_clients/conversations/update_conversation.test'; +import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; +import { getFindAnonymizationFieldsResultWithSingleHit } from '../../__mocks__/response'; +import { defaultAssistantFeatures } from '@kbn/elastic-assistant-common'; import { chatCompleteRoute } from './chat_complete_route'; -describe('Chat complete route', () => { - let server: ReturnType; - let { clients, context } = requestContextMock.createTools(); - const mockUser1 = { - username: 'my_username', - authentication_realm: { - type: 'my_realm_type', - name: 'my_realm_name', +const actionsClient = actionsClientMock.create(); +jest.mock('../../lib/build_response', () => ({ + buildResponse: jest.fn().mockImplementation((x) => x), +})); +jest.mock('../../lib/executor', () => ({ + executeAction: jest.fn().mockImplementation(async ({ connectorId }) => { + if (connectorId === 'mock-connector-id') { + return { + connector_id: 'mock-connector-id', + data: mockActionResponse, + status: 'ok', + }; + } else { + throw new Error('simulated error'); + } + }), +})); +const mockStream = jest.fn().mockImplementation(() => new PassThrough()); +jest.mock('../../lib/langchain/execute_custom_llm_chain', () => ({ + callAgentExecutor: jest.fn().mockImplementation( + async ({ + connectorId, + isStream, + }: { + actions: ActionsPluginStart; + connectorId: string; + esClient: ElasticsearchClient; + langChainMessages: BaseMessage[]; + logger: Logger; + isStream: boolean; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + request: KibanaRequest; + }) => { + if (!isStream && connectorId === 'mock-connector-id') { + return { + body: { + connector_id: 'mock-connector-id', + data: mockActionResponse, + status: 'ok', + }, + headers: { 'content-type': 'application/json' }, + }; + } else if (isStream && connectorId === 'mock-connector-id') { + return { + body: mockStream, + headers: { + 'Cache-Control': 'no-cache', + Connection: 'keep-alive', + 'Transfer-Encoding': 'chunked', + 'X-Accel-Buffering': 'no', + 'X-Content-Type-Options': 'nosniff', + }, + }; + } else { + throw new Error('simulated error'); + } + } + ), +})); +const existingConversation = getConversationResponseMock(); +const reportEvent = jest.fn(); +const appendConversationMessages = jest.fn(); +const mockContext = { + elasticAssistant: { + actions: { + getActionsClientWithRequest: jest.fn().mockResolvedValue(actionsClient), }, - } as AuthenticatedUser; + getRegisteredTools: jest.fn(() => []), + getRegisteredFeatures: jest.fn(() => defaultAssistantFeatures), + logger: loggingSystemMock.createLogger(), + telemetry: { ...coreMock.createSetup().analytics, reportEvent }, + getCurrentUser: () => ({ + username: 'user', + email: 'email', + fullName: 'full name', + roles: ['user-role'], + enabled: true, + authentication_realm: { name: 'native1', type: 'native' }, + lookup_realm: { name: 'native1', type: 'native' }, + authentication_provider: { type: 'basic', name: 'basic1' }, + authentication_type: 'realm', + elastic_cloud_user: false, + metadata: { _reserved: false }, + }), + getAIAssistantConversationsDataClient: jest.fn().mockResolvedValue({ + getConversation: jest.fn().mockResolvedValue(existingConversation), + updateConversation: jest.fn().mockResolvedValue(existingConversation), + appendConversationMessages: + appendConversationMessages.mockResolvedValue(existingConversation), + }), + getAIAssistantAnonymizationFieldsDataClient: jest.fn().mockResolvedValue({ + findDocuments: jest.fn().mockResolvedValue(getFindAnonymizationFieldsResultWithSingleHit()), + }), + }, + core: { + elasticsearch: { + client: elasticsearchServiceMock.createScopedClusterClient(), + }, + savedObjects: coreMock.createRequestHandlerContext().savedObjects, + }, +}; - beforeEach(() => { - server = serverMock.create(); - ({ clients, context } = requestContextMock.createTools()); - - clients.elasticAssistant.getAIAssistantConversationsDataClient.findDocuments.mockResolvedValue( - Promise.resolve(getEmptyFindResult()) - ); // no current conversations - clients.elasticAssistant.getAIAssistantConversationsDataClient.createConversation.mockResolvedValue( - getConversationMock(getQueryConversationParams()) - ); // creation succeeds - - context.core.elasticsearch.client.asCurrentUser.search.mockResolvedValue( - elasticsearchClientMock.createSuccessTransportRequestPromise(getBasicEmptySearchResponse()) - ); - context.elasticAssistant.getCurrentUser.mockReturnValue(mockUser1); - const mockGetElser = jest.fn().mockResolvedValue('.elser_model_2'); - chatCompleteRoute(server.router, mockGetElser); - }); - - describe('status codes', () => { - test('returns 200 with a conversation created via AIAssistantConversationsDataClient', async () => { - const response = await server.inject( - getCreateConversationRequest(), - requestContextMock.convertContext(context) - ); - expect(response.status).toEqual(200); - }); - - test('returns 401 Unauthorized when request context getCurrentUser is not defined', async () => { - context.elasticAssistant.getCurrentUser.mockReturnValueOnce(null); - const response = await server.inject( - getCreateConversationRequest(), - requestContextMock.convertContext(context) - ); - expect(response.status).toEqual(401); - }); - }); - - describe('unhappy paths', () => { - test('catches error if creation throws', async () => { - clients.elasticAssistant.getAIAssistantConversationsDataClient.createConversation.mockImplementation( - async () => { - throw new Error('Test error'); - } - ); - const response = await server.inject( - getCreateConversationRequest(), - requestContextMock.convertContext(context) - ); - expect(response.status).toEqual(500); - expect(response.body).toEqual({ - message: 'Test error', - status_code: 500, - }); - }); - }); - - describe('request validation', () => { - test('disallows unknown title', async () => { - const request = requestMock.create({ - method: 'post', - path: ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL, - body: { - ...getCreateConversationSchemaMock(), - title: true, - }, - }); - const result = server.validate(request); - - expect(result.badRequest).toHaveBeenCalled(); - }); - }); - describe('conversation containing messages', () => { - const getMessage = (role: string = 'user') => ({ - role, - content: 'test content', - timestamp: '2019-12-13T16:40:33.400Z', - }); - const defaultMessage = getMessage(); - - test('is successful', async () => { - const request = requestMock.create({ - method: 'post', - path: ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL, - body: { - ...getCreateConversationSchemaMock(), - messages: [defaultMessage], +const mockRequest = { + params: { connectorId: 'mock-connector-id' }, + body: { + connectorId: 'my-gen-ai', + persist: true, + messages: [ + { + role: 'user', + content: + "Evaluate the event from the context and format your output neatly in markdown syntax for my Elastic Security case.\nAdd your description, recommended actions and bulleted triage steps. Use the MITRE ATT&CK data provided to add more context and recommendations from MITRE, and hyperlink to the relevant pages on MITRE's website. Be sure to include the user and host risk score data from the context. Your response should include steps that point to Elastic Security specific features, including endpoint response actions, the Elastic Agent OSQuery manager integration (with example osquery queries), timelines and entity analytics and link to all the relevant Elastic Security documentation.", + data: { + 'event.category': 'process', + 'process.pid': 69516, + 'host.os.version': 14.5, + 'host.os.name': 'macOS', + 'host.name': 'Yuliias-MBP', + 'process.name': 'biomesyncd', + 'user.name': 'yuliianaumenko', + 'process.working_directory': '/', + 'event.module': 'system', + 'process.executable': '/usr/libexec/biomesyncd', + 'process.args': '/usr/libexec/biomesyncd', + message: 'Process biomesyncd (PID: 69516) by user yuliianaumenko STOPPED', }, - }); + }, + ], + }, + events: { + aborted$: NEVER, + }, +}; - const response = await server.inject(request, requestContextMock.convertContext(context)); - expect(response.status).toEqual(200); - }); +const mockResponse = { + ok: jest.fn().mockImplementation((x) => x), + error: jest.fn().mockImplementation((x) => x), +}; - test('fails when provided with an unsupported message role', async () => { - const wrongMessage = getMessage('test_thing'); +describe('chatCompleteRoute', () => { + const mockGetElser = jest.fn().mockResolvedValue('.elser_model_2'); - const request = requestMock.create({ - method: 'post', - path: ELASTIC_AI_ASSISTANT_CHAT_COMPLETE_URL, - body: { - ...getCreateConversationSchemaMock(), - messages: [wrongMessage], + beforeEach(() => { + jest.clearAllMocks(); + actionsClient.getBulk.mockResolvedValue([ + { + id: '1', + isPreconfigured: false, + isSystemAction: false, + isDeprecated: false, + name: 'my name', + actionTypeId: '.gen-ai', + isMissingSecrets: false, + config: { + a: true, + b: true, + c: true, }, - }); - const result = await server.validate(request); - expect(result.badRequest).toHaveBeenCalledWith( - `messages.0.role: Invalid enum value. Expected 'system' | 'user' | 'assistant', received 'test_thing'` - ); - }); + }, + ]); + }); + + it('returns the expected response when isEnabledKnowledgeBase=false', async () => { + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + const result = await handler( + mockContext, + { + ...mockRequest, + body: { + ...mockRequest.body, + isEnabledKnowledgeBase: false, + }, + }, + mockResponse + ); + + expect(result).toEqual({ + body: { + connector_id: 'mock-connector-id', + data: mockActionResponse, + status: 'ok', + }, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('returns the expected response when isEnabledKnowledgeBase=true', async () => { + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + const result = await handler(mockContext, mockRequest, mockResponse); + + expect(result).toEqual({ + body: { + connector_id: 'mock-connector-id', + data: mockActionResponse, + status: 'ok', + }, + headers: { 'content-type': 'application/json' }, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('returns the expected error when executeCustomLlmChain fails', async () => { + const requestWithBadConnectorId = { + ...mockRequest, + params: { connectorId: 'bad-connector-id' }, + }; + + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + const result = await handler(mockContext, requestWithBadConnectorId, mockResponse); + + expect(result).toEqual({ + body: 'simulated error', + statusCode: 500, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('reports success events to telemetry - kb on, RAG alerts off', async () => { + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler(mockContext, mockRequest, mockResponse); + + expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { + isEnabledKnowledgeBase: true, + isEnabledRAGAlerts: false, + actionTypeId: '.gen-ai', + model: 'gpt-4', + assistantStreamingEnabled: false, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('reports success events to telemetry - kb on, RAG alerts on', async () => { + const ragRequest = { + ...mockRequest, + body: { + ...mockRequest.body, + anonymizationFields: [ + { id: '@timestamp', field: '@timestamp', allowed: true, anonymized: false }, + { id: 'host.name', field: 'host.name', allowed: true, anonymized: true }, + ], + replacements: [], + isEnabledRAGAlerts: true, + }, + }; + + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler(mockContext, ragRequest, mockResponse); + + expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { + isEnabledKnowledgeBase: true, + isEnabledRAGAlerts: true, + actionTypeId: '.gen-ai', + model: 'gpt-4', + assistantStreamingEnabled: false, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('reports success events to telemetry - kb off, RAG alerts on', async () => { + const req = { + ...mockRequest, + body: { + ...mockRequest.body, + isEnabledKnowledgeBase: false, + anonymizationFields: [ + { id: '@timestamp', field: '@timestamp', allowed: true, anonymized: false }, + { id: 'host.name', field: 'host.name', allowed: true, anonymized: true }, + ], + replacements: [], + isEnabledRAGAlerts: true, + }, + }; + + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler(mockContext, req, mockResponse); + + expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { + actionTypeId: '.gen-ai', + model: 'gpt-4', + assistantStreamingEnabled: false, + isEnabledKnowledgeBase: false, + isEnabledRAGAlerts: true, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('reports success events to telemetry - kb off, RAG alerts off', async () => { + const req = { + ...mockRequest, + body: { + ...mockRequest.body, + isEnabledKnowledgeBase: false, + }, + }; + + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler(mockContext, req, mockResponse); + + expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { + actionTypeId: '.gen-ai', + model: 'gpt-4', + assistantStreamingEnabled: false, + isEnabledKnowledgeBase: false, + isEnabledRAGAlerts: false, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('reports error events to telemetry - kb on, RAG alerts off', async () => { + const requestWithBadConnectorId = { + ...mockRequest, + params: { connectorId: 'bad-connector-id' }, + }; + + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler(mockContext, requestWithBadConnectorId, mockResponse); + + expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_ERROR_EVENT.eventType, { + errorMessage: 'simulated error', + isEnabledKnowledgeBase: true, + isEnabledRAGAlerts: false, + actionTypeId: '.gen-ai', + model: 'gpt-4', + assistantStreamingEnabled: false, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('reports error events to telemetry - kb on, RAG alerts on', async () => { + const badRequest = { + ...mockRequest, + params: { connectorId: 'bad-connector-id' }, + body: { + ...mockRequest.body, + isEnabledRAGAlerts: true, + }, + }; + + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler(mockContext, badRequest, mockResponse); + + expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_ERROR_EVENT.eventType, { + errorMessage: 'simulated error', + isEnabledKnowledgeBase: true, + isEnabledRAGAlerts: true, + actionTypeId: '.gen-ai', + model: 'gpt-4', + assistantStreamingEnabled: false, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('reports error events to telemetry - kb off, RAG alerts on', async () => { + const badRequest = { + ...mockRequest, + params: { connectorId: 'bad-connector-id' }, + body: { + ...mockRequest.body, + isEnabledKnowledgeBase: false, + anonymizationFields: [ + { id: '@timestamp', field: '@timestamp', allowed: true, anonymized: false }, + { id: 'host.name', field: 'host.name', allowed: true, anonymized: true }, + ], + replacements: [], + isEnabledRAGAlerts: true, + }, + }; + + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler(mockContext, badRequest, mockResponse); + + expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_ERROR_EVENT.eventType, { + errorMessage: 'simulated error', + isEnabledKnowledgeBase: false, + isEnabledRAGAlerts: true, + actionTypeId: '.gen-ai', + model: 'gpt-4', + assistantStreamingEnabled: false, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('Adds error to conversation history', async () => { + const badRequest = { + ...mockRequest, + params: { connectorId: 'bad-connector-id' }, + body: { + ...mockRequest.body, + conversationId: '99999', + }, + }; + + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler(mockContext, badRequest, mockResponse); + expect(appendConversationMessages.mock.calls[1][0].messages[0]).toEqual( + expect.objectContaining({ + content: 'simulated error', + isError: true, + role: 'assistant', + }) + ); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('reports error events to telemetry - kb off, RAG alerts off', async () => { + const badRequest = { + ...mockRequest, + params: { connectorId: 'bad-connector-id' }, + body: { + ...mockRequest.body, + isEnabledKnowledgeBase: false, + }, + }; + + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler(mockContext, badRequest, mockResponse); + + expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_ERROR_EVENT.eventType, { + errorMessage: 'simulated error', + isEnabledKnowledgeBase: false, + isEnabledRAGAlerts: false, + actionTypeId: '.gen-ai', + model: 'gpt-4', + assistantStreamingEnabled: false, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('returns the expected response when subAction=invokeStream and actionTypeId=.gen-ai', async () => { + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + const result = await handler( + mockContext, + { + ...mockRequest, + body: { + ...mockRequest.body, + subAction: 'invokeStream', + actionTypeId: '.gen-ai', + }, + }, + mockResponse + ); + + expect(result).toEqual({ + body: mockStream, + headers: { + 'Cache-Control': 'no-cache', + Connection: 'keep-alive', + 'Transfer-Encoding': 'chunked', + 'X-Accel-Buffering': 'no', + 'X-Content-Type-Options': 'nosniff', + }, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('returns the expected response when subAction=invokeStream and actionTypeId=.bedrock', async () => { + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + const result = await handler( + mockContext, + { + ...mockRequest, + body: { + ...mockRequest.body, + subAction: 'invokeStream', + actionTypeId: '.bedrock', + }, + }, + mockResponse + ); + + expect(result).toEqual({ + body: mockStream, + headers: { + 'Cache-Control': 'no-cache', + Connection: 'keep-alive', + 'Transfer-Encoding': 'chunked', + 'X-Accel-Buffering': 'no', + 'X-Content-Type-Options': 'nosniff', + }, + }); + }), + }; + }), + }, + }; + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('returns the expected response when subAction=invokeAI and actionTypeId=.gen-ai', async () => { + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + const result = await handler( + mockContext, + { + ...mockRequest, + body: { + ...mockRequest.body, + subAction: 'invokeAI', + actionTypeId: '.gen-ai', + }, + }, + mockResponse + ); + + expect(result).toEqual({ + body: { connector_id: 'mock-connector-id', data: mockActionResponse, status: 'ok' }, + headers: { + 'content-type': 'application/json', + }, + }); + }), + }; + }), + }, + }; + + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('returns the expected response when subAction=invokeAI and actionTypeId=.bedrock', async () => { + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + const result = await handler( + mockContext, + { + ...mockRequest, + body: { + ...mockRequest.body, + subAction: 'invokeAI', + actionTypeId: '.bedrock', + }, + }, + mockResponse + ); + + expect(result).toEqual({ + body: { connector_id: 'mock-connector-id', data: mockActionResponse, status: 'ok' }, + headers: { + 'content-type': 'application/json', + }, + }); + }), + }; + }), + }, + }; + await chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); }); }); diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index e89a9fdaca7f5..800239f030664 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -102,11 +102,12 @@ export const postActionsConnectorExecuteRoute = ( defaultPluginName: DEFAULT_PLUGIN_NAME, logger, }); - const enableKnowledgeBaseByDefault = - assistantContext.getRegisteredFeatures(pluginName).assistantKnowledgeBaseByDefault; + const isGraphAvailable = + assistantContext.getRegisteredFeatures(pluginName).assistantKnowledgeBaseByDefault && + request.body.isEnabledKnowledgeBase; // TODO: remove non-graph persistance when KB will be enabled by default - if (!enableKnowledgeBaseByDefault && conversationId && conversationsDataClient) { + if (!isGraphAvailable && conversationId && conversationsDataClient) { const updatedConversation = await updateConversationWithUserInput({ actionsClient, actionTypeId, @@ -147,7 +148,7 @@ export const postActionsConnectorExecuteRoute = ( } }; - if (!request.body.isEnabledKnowledgeBase && !request.body.isEnabledRAGAlerts) { + if (!isGraphAvailable && !request.body.isEnabledRAGAlerts) { // if not langchain, call execute action directly and return the response: return await nonLangChainExecute({ abortSignal, @@ -174,7 +175,7 @@ export const postActionsConnectorExecuteRoute = ( context, getElser, logger, - messages: (enableKnowledgeBaseByDefault && newMessage ? [newMessage] : messages) ?? [], + messages: (isGraphAvailable && newMessage ? [newMessage] : messages) ?? [], onLlmResponse, onNewReplacements, replacements: latestReplacements, From 0fdb096e9e7544c8659b2b771261f6962b547ca3 Mon Sep 17 00:00:00 2001 From: YulNaumenko Date: Thu, 27 Jun 2024 18:01:19 -0700 Subject: [PATCH 12/23] fixed lint errors --- .../server/lib/executor.test.ts | 90 +++++++------------ .../execute_custom_llm_chain/index.test.ts | 19 ++-- .../execute_custom_llm_chain/index.ts | 2 +- .../routes/attack_discovery/helpers.test.ts | 14 +-- .../server/routes/helpers.ts | 72 ++++++++++++++- .../server/routes/categorization_routes.ts | 3 +- .../server/routes/ecs_routes.ts | 3 +- .../server/routes/related_routes.ts | 3 +- 8 files changed, 123 insertions(+), 83 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/executor.test.ts b/x-pack/plugins/elastic_assistant/server/lib/executor.test.ts index 4e9cb7ffaec67..d2d0031b5b485 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/executor.test.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/executor.test.ts @@ -13,18 +13,10 @@ import { executeAction, Props } from './executor'; import { PassThrough } from 'stream'; -import { KibanaRequest } from '@kbn/core-http-server'; -import { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; -import { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common'; import { loggerMock } from '@kbn/logging-mocks'; import * as ParseStream from './parse_stream'; -const request = { - body: { - subAction: 'invokeAI', - message: 'hello', - }, -} as KibanaRequest; + const onLlmResponse = jest.fn(async () => {}); // We need it to be a promise, or it'll crash because of missing `.catch` const connectorId = 'testConnectorId'; const mockLogger = loggerMock.create(); @@ -96,83 +88,69 @@ describe('executeAction', () => { }); }); - it('should throw an error if the actions plugin fails to retrieve the actions client', async () => { - const actions = { - getActionsClientWithRequest: jest - .fn() - .mockRejectedValue(new Error('Failed to retrieve actions client')), - } as unknown as Props['actions']; - - await expect(executeAction({ ...testProps, actions })).rejects.toThrowError( - 'Failed to retrieve actions client' - ); - }); - it('should throw an error if the actions client fails to execute the action', async () => { - const actions = { - getActionsClientWithRequest: jest.fn().mockResolvedValue({ - execute: jest.fn().mockRejectedValue(new Error('Failed to execute action')), - }), - } as unknown as Props['actions']; + const actionsClient = actionsClientMock.create(); + actionsClient.execute.mockRejectedValue(new Error('Failed to execute action')); + testProps.actionsClient = actionsClient; - await expect(executeAction({ ...testProps, actions })).rejects.toThrowError( + await expect(executeAction({ ...testProps, actionsClient })).rejects.toThrowError( 'Failed to execute action' ); }); it('should throw an error when the response from the actions framework is null or undefined', async () => { - const actions = { - getActionsClientWithRequest: jest.fn().mockResolvedValue({ - execute: jest.fn().mockResolvedValue({ - data: null, - }), - }), - } as unknown as Props['actions']; + const actionsClient = actionsClientMock.create(); + actionsClient.execute.mockImplementationOnce( + jest.fn().mockResolvedValue({ + data: null, + }) + ); + testProps.actionsClient = actionsClient; try { - await executeAction({ ...testProps, actions }); + await executeAction({ ...testProps, actionsClient }); } catch (e) { expect(e.message).toBe('Action result status is error: result is not streamable'); } }); it('should throw an error if action result status is "error"', async () => { - const actions = { - getActionsClientWithRequest: jest.fn().mockResolvedValue({ - execute: jest.fn().mockResolvedValue({ - status: 'error', - message: 'Error message', - serviceMessage: 'Service error message', - }), - }), - } as unknown as ActionsPluginStart; + const actionsClient = actionsClientMock.create(); + actionsClient.execute.mockImplementationOnce( + jest.fn().mockResolvedValue({ + status: 'error', + message: 'Error message', + serviceMessage: 'Service error message', + }) + ); + testProps.actionsClient = actionsClient; await expect( executeAction({ ...testProps, - actions, + actionsClient, connectorId: '12345', }) ).rejects.toThrowError('Action result status is error: Error message - Service error message'); }); it('should throw an error if content of response data is not a string or streamable', async () => { - const actions = { - getActionsClientWithRequest: jest.fn().mockResolvedValue({ - execute: jest.fn().mockResolvedValue({ - status: 'ok', - data: { - message: 12345, - }, - }), - }), - } as unknown as ActionsPluginStart; + const actionsClient = actionsClientMock.create(); + actionsClient.execute.mockImplementationOnce( + jest.fn().mockResolvedValue({ + status: 'ok', + data: { + message: 12345, + }, + }) + ); + testProps.actionsClient = actionsClient; await expect( executeAction({ ...testProps, - actions, + actionsClient, connectorId: '12345', }) ).rejects.toThrowError('Action result status is error: result is not streamable'); diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts index e1e8cdc50eee0..6f05edbed007b 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts @@ -5,10 +5,11 @@ * 2.0. */ -import { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; import { elasticsearchServiceMock } from '@kbn/core-elasticsearch-server-mocks'; import { coreMock } from '@kbn/core/server/mocks'; import { KibanaRequest } from '@kbn/core/server'; +import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; + import { loggerMock } from '@kbn/logging-mocks'; import { initializeAgentExecutorWithOptions } from 'langchain/agents'; @@ -84,7 +85,7 @@ const mockRequest: KibanaRequest = { body: {} } as K any // eslint-disable-line @typescript-eslint/no-explicit-any >; -const mockActions: ActionsPluginStart = {} as ActionsPluginStart; +const actionsClient = actionsClientMock.create(); const mockLogger = loggerMock.create(); const mockTelemetry = coreMock.createSetup().analytics; const esClientMock = elasticsearchServiceMock.createScopedClusterClient().asCurrentUser; @@ -95,7 +96,7 @@ const esStoreMock = new ElasticsearchStore( mockTelemetry ); const defaultProps: AgentExecutorParams = { - actions: mockActions, + actionsClient, isEnabledKnowledgeBase: true, connectorId: mockConnectorId, esClient: esClientMock, @@ -151,11 +152,10 @@ describe('callAgentExecutor', () => { await callAgentExecutor(defaultProps); expect(ActionsClientChatOpenAI).toHaveBeenCalledWith({ - actions: mockActions, + actionsClient, connectorId: mockConnectorId, logger: mockLogger, maxRetries: 0, - request: mockRequest, streaming: false, temperature: 0.2, llmType: 'openai', @@ -189,11 +189,10 @@ describe('callAgentExecutor', () => { await callAgentExecutor({ ...defaultProps, isStream: true }); expect(ActionsClientChatOpenAI).toHaveBeenCalledWith({ - actions: mockActions, + actionsClient, connectorId: mockConnectorId, logger: mockLogger, maxRetries: 0, - request: mockRequest, streaming: true, temperature: 0.2, llmType: 'openai', @@ -213,11 +212,10 @@ describe('callAgentExecutor', () => { await callAgentExecutor(bedrockProps); expect(ActionsClientSimpleChatModel).toHaveBeenCalledWith({ - actions: mockActions, + actionsClient, connectorId: mockConnectorId, logger: mockLogger, maxRetries: 0, - request: mockRequest, streaming: false, temperature: 0, llmType: 'bedrock', @@ -253,11 +251,10 @@ describe('callAgentExecutor', () => { await callAgentExecutor({ ...bedrockProps, isStream: true }); expect(ActionsClientSimpleChatModel).toHaveBeenCalledWith({ - actions: mockActions, + actionsClient, connectorId: mockConnectorId, logger: mockLogger, maxRetries: 0, - request: mockRequest, streaming: true, temperature: 0, llmType: 'bedrock', diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts index 0a01a70b99216..b6a624b368d82 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts @@ -17,9 +17,9 @@ import { ActionsClientChatOpenAI, ActionsClientSimpleChatModel, } from '@kbn/langchain/server'; +import { MessagesPlaceholder } from '@langchain/core/prompts'; import { EsAnonymizationFieldsSchema } from '../../../ai_assistant_data_clients/anonymization_fields/types'; import { transformESSearchToAnonymizationFields } from '../../../ai_assistant_data_clients/anonymization_fields/helpers'; -import { MessagesPlaceholder } from '@langchain/core/prompts'; import { AgentExecutor } from '../executors/types'; import { APMTracer } from '../tracers/apm_tracer'; import { AssistantToolParams } from '../../../types'; diff --git a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.test.ts b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.test.ts index 7f4baec88e60e..60c20419a7bdb 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.test.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/helpers.test.ts @@ -7,6 +7,8 @@ import { AuthenticatedUser } from '@kbn/core-security-common'; import moment from 'moment'; +import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; + import { REQUIRED_FOR_ATTACK_DISCOVERY, addGenerationInterval, @@ -20,7 +22,6 @@ import { import { ActionsClientLlm } from '@kbn/langchain/server'; import { AttackDiscoveryDataClient } from '../../ai_assistant_data_clients/attack_discovery'; import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants'; -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; import { elasticsearchServiceMock } from '@kbn/core-elasticsearch-server-mocks'; import { loggerMock } from '@kbn/logging-mocks'; import { KibanaRequest } from '@kbn/core-http-server'; @@ -87,7 +88,6 @@ const mockApiConfig = { const mockCurrentAd = transformESSearchToAttackDiscovery(getAttackDiscoverySearchEsMock())[0]; -const mockActions: ActionsPluginStart = {} as ActionsPluginStart; // eslint-disable-next-line @typescript-eslint/no-explicit-any const mockRequest: KibanaRequest = {} as unknown as KibanaRequest< unknown, @@ -114,14 +114,14 @@ describe('helpers', () => { describe('getAssistantToolParams', () => { const alertsIndexPattern = '.alerts-security.alerts-default'; const esClient = elasticsearchClientMock.createElasticsearchClient(); + const actionsClient = actionsClientMock.create(); const langChainTimeout = 1000; const latestReplacements = {}; const llm = new ActionsClientLlm({ - actions: mockActions, + actionsClient, connectorId: 'test-connecter-id', llmType: 'bedrock', logger: mockLogger, - request: mockRequest, temperature: 0, timeout: 580000, }); @@ -129,7 +129,7 @@ describe('helpers', () => { const size = 20; const mockParams = { - actions: {} as unknown as ActionsPluginStart, + actionsClient, alertsIndexPattern: 'alerts-*', anonymizationFields: [{ id: '1', field: 'field1', allowed: true, anonymized: true }], apiConfig: mockApiConfig, @@ -170,7 +170,7 @@ describe('helpers', () => { ]; const result = getAssistantToolParams({ - actions: mockParams.actions, + actionsClient, alertsIndexPattern, apiConfig: mockApiConfig, anonymizationFields, @@ -205,7 +205,7 @@ describe('helpers', () => { const anonymizationFields = undefined; const result = getAssistantToolParams({ - actions: mockParams.actions, + actionsClient, alertsIndexPattern, apiConfig: mockApiConfig, anonymizationFields, diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index 5c2aca29ed84a..d4bafa8685ec7 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -7,6 +7,7 @@ import { AnalyticsServiceSetup, + IKibanaResponse, KibanaRequest, KibanaResponseFactory, Logger, @@ -23,12 +24,13 @@ import { } from '@kbn/elastic-assistant-common'; import { ILicense } from '@kbn/licensing-plugin/server'; import { i18n } from '@kbn/i18n'; -import { PublicMethodsOf } from '@kbn/utility-types'; +import { AwaitedProperties, PublicMethodsOf } from '@kbn/utility-types'; import { ActionsClient } from '@kbn/actions-plugin/server'; +import { AssistantFeatureKey } from '@kbn/elastic-assistant-common/impl/capabilities'; import { MINIMUM_AI_ASSISTANT_LICENSE } from '../../common/constants'; import { ESQL_RESOURCE, KNOWLEDGE_BASE_INDEX_PATTERN } from './knowledge_base/constants'; import { callAgentExecutor } from '../lib/langchain/execute_custom_llm_chain'; -import { getLlmType } from './utils'; +import { buildResponse, getLlmType } from './utils'; import { AgentExecutorParams, AssistantDataClients, @@ -608,3 +610,69 @@ export const updateConversationWithUserInput = async ({ } return updatedConversation; }; + +interface PerformChecksParams { + authenticatedUser?: boolean; + capability?: AssistantFeatureKey; + context: AwaitedProperties< + Pick + >; + license?: boolean; + request: KibanaRequest; + response: KibanaResponseFactory; +} + +/** + * Helper to perform checks for authenticated user, capability, and license. Perform all or one + * of the checks by providing relevant optional params. Check order is license, authenticated user, + * then capability. + * + * @param authenticatedUser - Whether to check for an authenticated user + * @param capability - Specific capability to check if enabled, e.g. `assistantModelEvaluation` + * @param context - Route context + * @param license - Whether to check for a valid license + * @param request - Route KibanaRequest + * @param response - Route KibanaResponseFactory + */ +export const performChecks = ({ + authenticatedUser, + capability, + context, + license, + request, + response, +}: PerformChecksParams): IKibanaResponse | undefined => { + const assistantResponse = buildResponse(response); + + if (license) { + if (!hasAIAssistantLicense(context.licensing.license)) { + return response.forbidden({ + body: { + message: UPGRADE_LICENSE_MESSAGE, + }, + }); + } + } + + if (authenticatedUser) { + if (context.elasticAssistant.getCurrentUser() == null) { + return assistantResponse.error({ + body: `Authenticated user not found`, + statusCode: 401, + }); + } + } + + if (capability) { + const pluginName = getPluginNameFromRequest({ + request, + defaultPluginName: DEFAULT_PLUGIN_NAME, + }); + const registeredFeatures = context.elasticAssistant.getRegisteredFeatures(pluginName); + if (!registeredFeatures[capability]) { + return response.notFound(); + } + } + + return undefined; +}; diff --git a/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts b/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts index 6654898bd0232..ccdd050641abd 100644 --- a/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts +++ b/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts @@ -63,9 +63,8 @@ export function registerCategorizationRoutes( const llmClass = isOpenAI ? ActionsClientChatOpenAI : ActionsClientSimpleChatModel; const model = new llmClass({ - actions: actionsPlugin, + actionsClient, connectorId: connector.id, - request: req, logger, llmType: isOpenAI ? 'openai' : 'bedrock', model: connector.config?.defaultModel, diff --git a/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts b/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts index ee461b94feba4..a8a5b7a0f8036 100644 --- a/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts +++ b/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts @@ -54,9 +54,8 @@ export function registerEcsRoutes(router: IRouter Date: Thu, 27 Jun 2024 19:23:11 -0700 Subject: [PATCH 13/23] fixed tests --- .../server/language_models/llm.test.ts | 28 +++++++++++-------- .../server/lib/executor.test.ts | 23 +++++---------- .../post_attack_discovery.test.ts | 2 ++ 3 files changed, 26 insertions(+), 27 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts b/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts index a5d92f89a93e5..26c6b38f904c3 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts @@ -18,6 +18,12 @@ const mockExecute = jest.fn().mockImplementation(() => ({ data: mockActionResponse, status: 'ok', })); +actionsClient.execute.mockImplementation( + jest.fn().mockImplementation(() => ({ + data: mockActionResponse, + status: 'ok', + })) +); const mockLogger = loggerMock.create(); @@ -79,13 +85,11 @@ describe('ActionsClientLlm', () => { }); it('rejects with the expected error when the action result status is error', async () => { - const hasErrorStatus = jest.fn().mockImplementation(() => ({ - message: 'action-result-message', - serviceMessage: 'action-result-service-message', - status: 'error', // <-- error status - })); - - actionsClient.execute.mockRejectedValueOnce(hasErrorStatus); + actionsClient.execute.mockImplementation(() => { + throw new Error( + 'ActionsClientLlm: action result status is error: action-result-message - action-result-service-message' + ); + }); const actionsClientLlm = new ActionsClientLlm({ actionsClient, connectorId, @@ -100,10 +104,12 @@ describe('ActionsClientLlm', () => { it('rejects with the expected error the message has invalid content', async () => { const invalidContent = { message: 1234 }; - mockExecute.mockImplementation(() => ({ - data: invalidContent, - status: 'ok', - })); + actionsClient.execute.mockImplementation( + jest.fn().mockResolvedValue({ + data: invalidContent, + status: 'ok', + }) + ); const actionsClientLlm = new ActionsClientLlm({ actionsClient, diff --git a/x-pack/plugins/elastic_assistant/server/lib/executor.test.ts b/x-pack/plugins/elastic_assistant/server/lib/executor.test.ts index d2d0031b5b485..a01ac3d126e59 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/executor.test.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/executor.test.ts @@ -39,15 +39,6 @@ describe('executeAction', () => { jest.clearAllMocks(); }); it('should execute an action and return a StaticResponse when the response from the actions framework is a string', async () => { - /* const actions = { - getActionsClientWithRequest: jest.fn().mockResolvedValue({ - execute: jest.fn().mockResolvedValue({ - data: { - message: 'Test message', - }, - }), - }), - } as unknown as Props['actions'];*/ testProps.actionsClient.execute = jest.fn().mockResolvedValue({ data: { message: 'Test message', @@ -66,13 +57,13 @@ describe('executeAction', () => { it('should execute an action and return a Readable object when the response from the actions framework is a stream', async () => { const readableStream = new PassThrough(); - const actionsClient = { - getActionsClientWithRequest: jest.fn().mockResolvedValue({ - execute: jest.fn().mockResolvedValue({ - data: readableStream, - }), - }), - } as unknown as Props['actionsClient']; + const actionsClient = actionsClientMock.create(); + actionsClient.execute.mockImplementationOnce( + jest.fn().mockResolvedValue({ + status: 'ok', + data: readableStream, + }) + ); const result = await executeAction({ ...testProps, actionsClient }); diff --git a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.test.ts b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.test.ts index 9ecfb5c2af333..cbd3e6063fbd2 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.test.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/attack_discovery/post_attack_discovery.test.ts @@ -10,6 +10,7 @@ import { postAttackDiscoveryRoute } from './post_attack_discovery'; import { serverMock } from '../../__mocks__/server'; import { requestContextMock } from '../../__mocks__/request_context'; import { elasticsearchServiceMock } from '@kbn/core-elasticsearch-server-mocks'; +import { actionsMock } from '@kbn/actions-plugin/server/mocks'; import { AttackDiscoveryDataClient } from '../../ai_assistant_data_clients/attack_discovery'; import { transformESSearchToAttackDiscovery } from '../../ai_assistant_data_clients/attack_discovery/transforms'; import { getAttackDiscoverySearchEsMock } from '../../__mocks__/attack_discovery_schema.mock'; @@ -68,6 +69,7 @@ describe('postAttackDiscoveryRoute', () => { jest.clearAllMocks(); context.elasticAssistant.getCurrentUser.mockReturnValue(mockUser); context.elasticAssistant.getAttackDiscoveryDataClient.mockResolvedValue(mockDataClient); + context.elasticAssistant.actions = actionsMock.createStart(); postAttackDiscoveryRoute(server.router); findAttackDiscoveryByConnectorId.mockResolvedValue(mockCurrentAd); (getAssistantTool as jest.Mock).mockReturnValue({ getTool: jest.fn() }); From 42afcc60f31838b96e90a38dd02208290ee4a9d5 Mon Sep 17 00:00:00 2001 From: YulNaumenko Date: Thu, 27 Jun 2024 21:50:11 -0700 Subject: [PATCH 14/23] fixed tests --- .../language_models/chat_openai.test.ts | 2 ++ .../server/language_models/llm.test.ts | 5 +--- .../language_models/simple_chat_model.test.ts | 28 ++++++++++++------- .../server/lib/get_chat_params.test.ts | 4 +-- 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_openai.test.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_openai.test.ts index 92fd210fd7c53..7d01468b4d6f7 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_openai.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_openai.test.ts @@ -59,6 +59,7 @@ describe('ActionsClientChatOpenAI', () => { data: mockChatCompletion, status: 'ok', })); + actionsClient.execute.mockImplementation(mockExecute); }); describe('_llmType', () => { @@ -98,6 +99,7 @@ describe('ActionsClientChatOpenAI', () => { functions: [jest.fn()], }; it('returns the expected data', async () => { + actionsClient.execute.mockImplementation(mockStreamExecute); const actionsClientChatOpenAI = new ActionsClientChatOpenAI({ ...defaultArgs, streaming: true, diff --git a/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts b/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts index 26c6b38f904c3..aa33bbf7a6d44 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/llm.test.ts @@ -14,10 +14,7 @@ import { mockActionResponse } from './mocks'; const connectorId = 'mock-connector-id'; const actionsClient = actionsClientMock.create(); -const mockExecute = jest.fn().mockImplementation(() => ({ - data: mockActionResponse, - status: 'ok', -})); + actionsClient.execute.mockImplementation( jest.fn().mockImplementation(() => ({ data: mockActionResponse, diff --git a/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.test.ts b/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.test.ts index 40c7e19cfdd80..98da9a4e81b53 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/simple_chat_model.test.ts @@ -79,6 +79,12 @@ jest.mock('../utils/gemini'); describe('ActionsClientSimpleChatModel', () => { beforeEach(() => { jest.clearAllMocks(); + actionsClient.execute.mockImplementation( + jest.fn().mockImplementation(() => ({ + data: mockActionResponse, + status: 'ok', + })) + ); mockExecute.mockImplementation(() => ({ data: mockActionResponse, status: 'ok', @@ -125,18 +131,18 @@ describe('ActionsClientSimpleChatModel', () => { callOptions, callRunManager ); - const subAction = mockExecute.mock.calls[0][0].params.subAction; + const subAction = actionsClient.execute.mock.calls[0][0].params.subAction; expect(subAction).toEqual('invokeAI'); expect(result).toEqual(mockActionResponse.message); }); it('rejects with the expected error when the action result status is error', async () => { - const hasErrorStatus = jest.fn().mockImplementation(() => ({ - message: 'action-result-message', - serviceMessage: 'action-result-service-message', - status: 'error', // <-- error status - })); + const hasErrorStatus = jest.fn().mockImplementation(() => { + throw Error( + 'ActionsClientSimpleChatModel: action result status is error: action-result-message - action-result-service-message' + ); + }); actionsClient.execute.mockRejectedValueOnce(hasErrorStatus); @@ -155,10 +161,12 @@ describe('ActionsClientSimpleChatModel', () => { it('rejects with the expected error the message has invalid content', async () => { const invalidContent = { message: 1234 }; - mockExecute.mockImplementation(() => ({ - data: invalidContent, - status: 'ok', - })); + actionsClient.execute.mockImplementation( + jest.fn().mockResolvedValue({ + data: invalidContent, + status: 'ok', + }) + ); const actionsClientSimpleChatModel = new ActionsClientSimpleChatModel(defaultArgs); diff --git a/x-pack/plugins/search_playground/server/lib/get_chat_params.test.ts b/x-pack/plugins/search_playground/server/lib/get_chat_params.test.ts index 8eb46a2cdb7e4..b05ac4a75c0ad 100644 --- a/x-pack/plugins/search_playground/server/lib/get_chat_params.test.ts +++ b/x-pack/plugins/search_playground/server/lib/get_chat_params.test.ts @@ -39,6 +39,7 @@ describe('getChatParams', () => { const actions = { getActionsClientWithRequest: jest.fn(() => Promise.resolve(mockActionsClient)), } as unknown as ActionsPluginStartContract; + const logger = jest.fn() as unknown as Logger; const request = jest.fn() as unknown as KibanaRequest; @@ -89,11 +90,10 @@ describe('getChatParams', () => { temperature: 0, llmType: 'bedrock', traceId: 'test-uuid', - request: expect.anything(), logger: expect.anything(), model: 'custom-model', connectorId: '2', - actions: expect.anything(), + actionsClient: expect.anything(), }); expect(result.chatPrompt).toContain('How does it work?'); }); From f8bce6c9d116341ba1f7d339ee69320ea7b2ae9c Mon Sep 17 00:00:00 2001 From: YulNaumenko Date: Fri, 28 Jun 2024 12:02:08 -0700 Subject: [PATCH 15/23] fixing more tests --- .../chat/post_chat_complete_route.gen.ts | 1 + .../chat/post_chat_complete_route.schema.yaml | 2 + .../routes/chat/chat_complete_route.test.ts | 130 ++++++++---------- .../server/routes/chat/chat_complete_route.ts | 13 +- .../server/routes/helpers.ts | 8 +- .../post_actions_connector_execute.test.ts | 80 ++++++----- .../routes/post_actions_connector_execute.ts | 7 +- 7 files changed, 123 insertions(+), 118 deletions(-) diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts index 25ba5a68030b3..0b6c3bbe6cbb3 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.gen.ts @@ -54,6 +54,7 @@ export type ChatCompleteProps = z.infer; export const ChatCompleteProps = z.object({ conversationId: z.string().optional(), promptId: z.string().optional(), + isStream: z.boolean().optional(), responseLanguage: z.string().optional(), langSmithProject: z.string().optional(), langSmithApiKey: z.string().optional(), diff --git a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml index c8f783b5b5ab8..21c348251b039 100644 --- a/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml +++ b/x-pack/packages/kbn-elastic-assistant-common/impl/schemas/chat/post_chat_complete_route.schema.yaml @@ -85,6 +85,8 @@ components: type: string promptId: type: string + isStream: + type: boolean responseLanguage: type: string langSmithProject: diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts index 0ebe03980ec8d..a83ab071222b6 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts @@ -6,7 +6,7 @@ */ import { ElasticsearchClient, IRouter, KibanaRequest, Logger } from '@kbn/core/server'; -import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import type { ActionsClient } from '@kbn/actions-plugin/server'; import { BaseMessage } from '@langchain/core/messages'; import { NEVER } from 'rxjs'; import { mockActionResponse } from '../../__mocks__/action_result_data'; @@ -24,6 +24,10 @@ import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/act import { getFindAnonymizationFieldsResultWithSingleHit } from '../../__mocks__/response'; import { defaultAssistantFeatures } from '@kbn/elastic-assistant-common'; import { chatCompleteRoute } from './chat_complete_route'; +import { PublicMethodsOf } from '@kbn/utility-types'; +import { licensingMock } from '@kbn/licensing-plugin/server/mocks'; + +const license = licensingMock.createLicenseMock(); const actionsClient = actionsClientMock.create(); jest.mock('../../lib/build_response', () => ({ @@ -49,7 +53,7 @@ jest.mock('../../lib/langchain/execute_custom_llm_chain', () => ({ connectorId, isStream, }: { - actions: ActionsPluginStart; + actionsClient: PublicMethodsOf; connectorId: string; esClient: ElasticsearchClient; langChainMessages: BaseMessage[]; @@ -88,49 +92,54 @@ const existingConversation = getConversationResponseMock(); const reportEvent = jest.fn(); const appendConversationMessages = jest.fn(); const mockContext = { - elasticAssistant: { - actions: { - getActionsClientWithRequest: jest.fn().mockResolvedValue(actionsClient), + resolve: jest.fn().mockResolvedValue({ + elasticAssistant: { + actions: { + getActionsClientWithRequest: jest.fn().mockResolvedValue(actionsClient), + }, + getRegisteredTools: jest.fn(() => []), + getRegisteredFeatures: jest.fn(() => defaultAssistantFeatures), + logger: loggingSystemMock.createLogger(), + telemetry: { ...coreMock.createSetup().analytics, reportEvent }, + getCurrentUser: () => ({ + username: 'user', + email: 'email', + fullName: 'full name', + roles: ['user-role'], + enabled: true, + authentication_realm: { name: 'native1', type: 'native' }, + lookup_realm: { name: 'native1', type: 'native' }, + authentication_provider: { type: 'basic', name: 'basic1' }, + authentication_type: 'realm', + elastic_cloud_user: false, + metadata: { _reserved: false }, + }), + getAIAssistantConversationsDataClient: jest.fn().mockResolvedValue({ + getConversation: jest.fn().mockResolvedValue(existingConversation), + updateConversation: jest.fn().mockResolvedValue(existingConversation), + appendConversationMessages: + appendConversationMessages.mockResolvedValue(existingConversation), + }), + getAIAssistantAnonymizationFieldsDataClient: jest.fn().mockResolvedValue({ + findDocuments: jest.fn().mockResolvedValue(getFindAnonymizationFieldsResultWithSingleHit()), + }), }, - getRegisteredTools: jest.fn(() => []), - getRegisteredFeatures: jest.fn(() => defaultAssistantFeatures), - logger: loggingSystemMock.createLogger(), - telemetry: { ...coreMock.createSetup().analytics, reportEvent }, - getCurrentUser: () => ({ - username: 'user', - email: 'email', - fullName: 'full name', - roles: ['user-role'], - enabled: true, - authentication_realm: { name: 'native1', type: 'native' }, - lookup_realm: { name: 'native1', type: 'native' }, - authentication_provider: { type: 'basic', name: 'basic1' }, - authentication_type: 'realm', - elastic_cloud_user: false, - metadata: { _reserved: false }, - }), - getAIAssistantConversationsDataClient: jest.fn().mockResolvedValue({ - getConversation: jest.fn().mockResolvedValue(existingConversation), - updateConversation: jest.fn().mockResolvedValue(existingConversation), - appendConversationMessages: - appendConversationMessages.mockResolvedValue(existingConversation), - }), - getAIAssistantAnonymizationFieldsDataClient: jest.fn().mockResolvedValue({ - findDocuments: jest.fn().mockResolvedValue(getFindAnonymizationFieldsResultWithSingleHit()), - }), - }, - core: { - elasticsearch: { - client: elasticsearchServiceMock.createScopedClusterClient(), + core: { + elasticsearch: { + client: elasticsearchServiceMock.createScopedClusterClient(), + }, + savedObjects: coreMock.createRequestHandlerContext().savedObjects, }, - savedObjects: coreMock.createRequestHandlerContext().savedObjects, - }, + licensing: { + ...licensingMock.createRequestHandlerContext({ license }), + license, + }, + }), }; const mockRequest = { - params: { connectorId: 'mock-connector-id' }, body: { - connectorId: 'my-gen-ai', + connectorId: 'mock-connector-id', persist: true, messages: [ { @@ -149,7 +158,6 @@ const mockRequest = { 'event.module': 'system', 'process.executable': '/usr/libexec/biomesyncd', 'process.args': '/usr/libexec/biomesyncd', - message: 'Process biomesyncd (PID: 69516) by user yuliianaumenko STOPPED', }, }, ], @@ -169,6 +177,13 @@ describe('chatCompleteRoute', () => { beforeEach(() => { jest.clearAllMocks(); + license.hasAtLeast.mockReturnValue(true); + actionsClient.execute.mockImplementation( + jest.fn().mockResolvedValue(() => ({ + data: 'mockChatCompletion', + status: 'ok', + })) + ); actionsClient.getBulk.mockResolvedValue([ { id: '1', @@ -187,7 +202,7 @@ describe('chatCompleteRoute', () => { ]); }); - it('returns the expected response when isEnabledKnowledgeBase=false', async () => { + it('returns the expected response when using the existingConversation', async () => { const mockRouter = { versioned: { post: jest.fn().mockImplementation(() => { @@ -199,7 +214,7 @@ describe('chatCompleteRoute', () => { ...mockRequest, body: { ...mockRequest.body, - isEnabledKnowledgeBase: false, + conversationId: existingConversation.id, }, }, mockResponse @@ -211,34 +226,9 @@ describe('chatCompleteRoute', () => { data: mockActionResponse, status: 'ok', }, - }); - }), - }; - }), - }, - }; - - await chatCompleteRoute( - mockRouter as unknown as IRouter, - mockGetElser - ); - }); - - it('returns the expected response when isEnabledKnowledgeBase=true', async () => { - const mockRouter = { - versioned: { - post: jest.fn().mockImplementation(() => { - return { - addVersion: jest.fn().mockImplementation(async (_, handler) => { - const result = await handler(mockContext, mockRequest, mockResponse); - - expect(result).toEqual({ - body: { - connector_id: 'mock-connector-id', - data: mockActionResponse, - status: 'ok', + headers: { + 'content-type': 'application/json', }, - headers: { 'content-type': 'application/json' }, }); }), }; @@ -246,7 +236,7 @@ describe('chatCompleteRoute', () => { }, }; - await chatCompleteRoute( + chatCompleteRoute( mockRouter as unknown as IRouter, mockGetElser ); diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index 145842cc7619c..890746d81dd32 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -19,7 +19,7 @@ import { import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; import { ElasticAssistantPluginRouter, GetElser } from '../../types'; -import { buildResponse } from '../utils'; +import { buildResponse } from '../../lib/build_response'; import { DEFAULT_PLUGIN_NAME, UPGRADE_LICENSE_MESSAGE, @@ -98,7 +98,7 @@ export const chatCompleteRoute = ( }; // get the actions plugin start contract from the request context: - const actions = (await context.elasticAssistant).actions; + const actions = ctx.elasticAssistant.actions; const actionsClient = await actions.getActionsClientWithRequest(request); const connectors = await actionsClient.getBulk({ ids: [connectorId] }); const actionTypeId = connectors.length > 0 ? connectors[0].actionTypeId : '.gen-ai'; @@ -158,9 +158,8 @@ export const chatCompleteRoute = ( defaultPluginName: DEFAULT_PLUGIN_NAME, logger, }); - const enableKnowledgeBaseByDefault = ( - await context.elasticAssistant - ).getRegisteredFeatures(pluginName).assistantKnowledgeBaseByDefault; + const enableKnowledgeBaseByDefault = + ctx.elasticAssistant.getRegisteredFeatures(pluginName).assistantKnowledgeBaseByDefault; // TODO: remove non-graph persistance when KB will be enabled by default if (!enableKnowledgeBaseByDefault && request.body.persist && conversationsDataClient) { const updatedConversation = await createOrUpdateConversationWithUserInput({ @@ -207,12 +206,12 @@ export const chatCompleteRoute = ( return await langChainExecute({ abortSignal, isEnabledKnowledgeBase: true, - isStream: false, + isStream: request.body.isStream ?? false, actionsClient, actionTypeId, connectorId, conversationId, - context, + context: ctx, getElser, logger, messages: messages ?? [], diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index d4bafa8685ec7..fb6629d58a0cf 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -330,7 +330,9 @@ export interface LangChainExecuteParams { actionTypeId: string; connectorId: string; conversationId?: string; - context: ElasticAssistantRequestHandlerContext; + context: AwaitedProperties< + Pick + >; actionsClient: PublicMethodsOf; // eslint-disable-next-line @typescript-eslint/no-explicit-any request: KibanaRequest; @@ -374,13 +376,13 @@ export const langChainExecute = async ({ defaultPluginName: DEFAULT_PLUGIN_NAME, logger, }); - const assistantContext = await context.elasticAssistant; + const assistantContext = context.elasticAssistant; const assistantTools = assistantContext .getRegisteredTools(pluginName) .filter((x) => x.id !== 'attack-discovery'); // We don't (yet) support asking the assistant for NEW attack discoveries from a conversation // get a scoped esClient for assistant memory - const esClient = (await context.core).elasticsearch.client.asCurrentUser; + const esClient = context.core.elasticsearch.client.asCurrentUser; // convert the assistant messages to LangChain messages: const langChainMessages = getLangChainMessages(messages); diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts index 5ee8d8e83c846..91c2cdf18aa9d 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts @@ -24,7 +24,9 @@ import { getConversationResponseMock } from '../ai_assistant_data_clients/conver import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; import { getFindAnonymizationFieldsResultWithSingleHit } from '../__mocks__/response'; import { defaultAssistantFeatures } from '@kbn/elastic-assistant-common'; +import { licensingMock } from '@kbn/licensing-plugin/server/mocks'; +const license = licensingMock.createLicenseMock(); const actionsClient = actionsClientMock.create(); jest.mock('../lib/build_response', () => ({ buildResponse: jest.fn().mockImplementation((x) => x), @@ -88,43 +90,49 @@ const existingConversation = getConversationResponseMock(); const reportEvent = jest.fn(); const appendConversationMessages = jest.fn(); const mockContext = { - elasticAssistant: { - actions: { - getActionsClientWithRequest: jest.fn().mockResolvedValue(actionsClient), + resolve: jest.fn().mockResolvedValue({ + elasticAssistant: { + actions: { + getActionsClientWithRequest: jest.fn().mockResolvedValue(actionsClient), + }, + getRegisteredTools: jest.fn(() => []), + getRegisteredFeatures: jest.fn(() => defaultAssistantFeatures), + logger: loggingSystemMock.createLogger(), + telemetry: { ...coreMock.createSetup().analytics, reportEvent }, + getCurrentUser: () => ({ + username: 'user', + email: 'email', + fullName: 'full name', + roles: ['user-role'], + enabled: true, + authentication_realm: { name: 'native1', type: 'native' }, + lookup_realm: { name: 'native1', type: 'native' }, + authentication_provider: { type: 'basic', name: 'basic1' }, + authentication_type: 'realm', + elastic_cloud_user: false, + metadata: { _reserved: false }, + }), + getAIAssistantConversationsDataClient: jest.fn().mockResolvedValue({ + getConversation: jest.fn().mockResolvedValue(existingConversation), + updateConversation: jest.fn().mockResolvedValue(existingConversation), + appendConversationMessages: + appendConversationMessages.mockResolvedValue(existingConversation), + }), + getAIAssistantAnonymizationFieldsDataClient: jest.fn().mockResolvedValue({ + findDocuments: jest.fn().mockResolvedValue(getFindAnonymizationFieldsResultWithSingleHit()), + }), }, - getRegisteredTools: jest.fn(() => []), - getRegisteredFeatures: jest.fn(() => defaultAssistantFeatures), - logger: loggingSystemMock.createLogger(), - telemetry: { ...coreMock.createSetup().analytics, reportEvent }, - getCurrentUser: () => ({ - username: 'user', - email: 'email', - fullName: 'full name', - roles: ['user-role'], - enabled: true, - authentication_realm: { name: 'native1', type: 'native' }, - lookup_realm: { name: 'native1', type: 'native' }, - authentication_provider: { type: 'basic', name: 'basic1' }, - authentication_type: 'realm', - elastic_cloud_user: false, - metadata: { _reserved: false }, - }), - getAIAssistantConversationsDataClient: jest.fn().mockResolvedValue({ - getConversation: jest.fn().mockResolvedValue(existingConversation), - updateConversation: jest.fn().mockResolvedValue(existingConversation), - appendConversationMessages: - appendConversationMessages.mockResolvedValue(existingConversation), - }), - getAIAssistantAnonymizationFieldsDataClient: jest.fn().mockResolvedValue({ - findDocuments: jest.fn().mockResolvedValue(getFindAnonymizationFieldsResultWithSingleHit()), - }), - }, - core: { - elasticsearch: { - client: elasticsearchServiceMock.createScopedClusterClient(), + core: { + elasticsearch: { + client: elasticsearchServiceMock.createScopedClusterClient(), + }, + savedObjects: coreMock.createRequestHandlerContext().savedObjects, }, - savedObjects: coreMock.createRequestHandlerContext().savedObjects, - }, + licensing: { + ...licensingMock.createRequestHandlerContext({ license }), + license, + }, + }), }; const mockRequest = { @@ -153,6 +161,7 @@ describe('postActionsConnectorExecuteRoute', () => { beforeEach(() => { jest.clearAllMocks(); + license.hasAtLeast.mockReturnValue(true); actionsClient.getBulk.mockResolvedValue([ { id: '1', @@ -195,6 +204,7 @@ describe('postActionsConnectorExecuteRoute', () => { data: mockActionResponse, status: 'ok', }, + headers: { 'content-type': 'application/json' }, }); }), }; diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 800239f030664..2a2ce03f4043c 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -58,7 +58,8 @@ export const postActionsConnectorExecuteRoute = ( const abortSignal = getRequestAbortedSignal(request.events.aborted$); const resp = buildResponse(response); - const assistantContext = await context.elasticAssistant; + const ctx = await context.resolve(['core', 'elasticAssistant', 'licensing']); + const assistantContext = ctx.elasticAssistant; const logger: Logger = assistantContext.logger; const telemetry = assistantContext.telemetry; let onLlmResponse; @@ -90,7 +91,7 @@ export const postActionsConnectorExecuteRoute = ( } // get the actions plugin start contract from the request context: - const actions = (await context.elasticAssistant).actions; + const actions = ctx.elasticAssistant.actions; const actionsClient = await actions.getActionsClientWithRequest(request); const conversationsDataClient = @@ -172,7 +173,7 @@ export const postActionsConnectorExecuteRoute = ( actionTypeId, connectorId, conversationId, - context, + context: ctx, getElser, logger, messages: (isGraphAvailable && newMessage ? [newMessage] : messages) ?? [], From a45db819b82007af8a74e59999989637f863223b Mon Sep 17 00:00:00 2001 From: YulNaumenko Date: Fri, 28 Jun 2024 14:18:49 -0700 Subject: [PATCH 16/23] fixed data and streaming --- .../routes/chat/chat_complete_route.test.ts | 211 +----------------- .../server/routes/chat/chat_complete_route.ts | 6 +- .../routes/post_actions_connector_execute.ts | 2 +- 3 files changed, 12 insertions(+), 207 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts index a83ab071222b6..b37ee305af91e 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts @@ -117,6 +117,7 @@ const mockContext = { getAIAssistantConversationsDataClient: jest.fn().mockResolvedValue({ getConversation: jest.fn().mockResolvedValue(existingConversation), updateConversation: jest.fn().mockResolvedValue(existingConversation), + createConversation: jest.fn().mockResolvedValue(existingConversation), appendConversationMessages: appendConversationMessages.mockResolvedValue(existingConversation), }), @@ -245,7 +246,7 @@ describe('chatCompleteRoute', () => { it('returns the expected error when executeCustomLlmChain fails', async () => { const requestWithBadConnectorId = { ...mockRequest, - params: { connectorId: 'bad-connector-id' }, + connectorId: 'bad-connector-id', }; const mockRouter = { @@ -307,8 +308,6 @@ describe('chatCompleteRoute', () => { { id: '@timestamp', field: '@timestamp', allowed: true, anonymized: false }, { id: 'host.name', field: 'host.name', allowed: true, anonymized: true }, ], - replacements: [], - isEnabledRAGAlerts: true, }, }; @@ -348,8 +347,6 @@ describe('chatCompleteRoute', () => { { id: '@timestamp', field: '@timestamp', allowed: true, anonymized: false }, { id: 'host.name', field: 'host.name', allowed: true, anonymized: true }, ], - replacements: [], - isEnabledRAGAlerts: true, }, }; @@ -417,7 +414,7 @@ describe('chatCompleteRoute', () => { it('reports error events to telemetry - kb on, RAG alerts off', async () => { const requestWithBadConnectorId = { ...mockRequest, - params: { connectorId: 'bad-connector-id' }, + connectorId: 'bad-connector-id', }; const mockRouter = { @@ -447,93 +444,13 @@ describe('chatCompleteRoute', () => { ); }); - it('reports error events to telemetry - kb on, RAG alerts on', async () => { - const badRequest = { - ...mockRequest, - params: { connectorId: 'bad-connector-id' }, - body: { - ...mockRequest.body, - isEnabledRAGAlerts: true, - }, - }; - - const mockRouter = { - versioned: { - post: jest.fn().mockImplementation(() => { - return { - addVersion: jest.fn().mockImplementation(async (_, handler) => { - await handler(mockContext, badRequest, mockResponse); - - expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_ERROR_EVENT.eventType, { - errorMessage: 'simulated error', - isEnabledKnowledgeBase: true, - isEnabledRAGAlerts: true, - actionTypeId: '.gen-ai', - model: 'gpt-4', - assistantStreamingEnabled: false, - }); - }), - }; - }), - }, - }; - - await chatCompleteRoute( - mockRouter as unknown as IRouter, - mockGetElser - ); - }); - - it('reports error events to telemetry - kb off, RAG alerts on', async () => { - const badRequest = { - ...mockRequest, - params: { connectorId: 'bad-connector-id' }, - body: { - ...mockRequest.body, - isEnabledKnowledgeBase: false, - anonymizationFields: [ - { id: '@timestamp', field: '@timestamp', allowed: true, anonymized: false }, - { id: 'host.name', field: 'host.name', allowed: true, anonymized: true }, - ], - replacements: [], - isEnabledRAGAlerts: true, - }, - }; - - const mockRouter = { - versioned: { - post: jest.fn().mockImplementation(() => { - return { - addVersion: jest.fn().mockImplementation(async (_, handler) => { - await handler(mockContext, badRequest, mockResponse); - - expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_ERROR_EVENT.eventType, { - errorMessage: 'simulated error', - isEnabledKnowledgeBase: false, - isEnabledRAGAlerts: true, - actionTypeId: '.gen-ai', - model: 'gpt-4', - assistantStreamingEnabled: false, - }); - }), - }; - }), - }, - }; - - await chatCompleteRoute( - mockRouter as unknown as IRouter, - mockGetElser - ); - }); - it('Adds error to conversation history', async () => { const badRequest = { ...mockRequest, - params: { connectorId: 'bad-connector-id' }, body: { ...mockRequest.body, conversationId: '99999', + connectorId: 'bad-connector-id', }, }; @@ -562,44 +479,7 @@ describe('chatCompleteRoute', () => { ); }); - it('reports error events to telemetry - kb off, RAG alerts off', async () => { - const badRequest = { - ...mockRequest, - params: { connectorId: 'bad-connector-id' }, - body: { - ...mockRequest.body, - isEnabledKnowledgeBase: false, - }, - }; - - const mockRouter = { - versioned: { - post: jest.fn().mockImplementation(() => { - return { - addVersion: jest.fn().mockImplementation(async (_, handler) => { - await handler(mockContext, badRequest, mockResponse); - - expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_ERROR_EVENT.eventType, { - errorMessage: 'simulated error', - isEnabledKnowledgeBase: false, - isEnabledRAGAlerts: false, - actionTypeId: '.gen-ai', - model: 'gpt-4', - assistantStreamingEnabled: false, - }); - }), - }; - }), - }, - }; - - await chatCompleteRoute( - mockRouter as unknown as IRouter, - mockGetElser - ); - }); - - it('returns the expected response when subAction=invokeStream and actionTypeId=.gen-ai', async () => { + it('returns the expected response when isStream=true and actionTypeId=.gen-ai', async () => { const mockRouter = { versioned: { post: jest.fn().mockImplementation(() => { @@ -611,8 +491,7 @@ describe('chatCompleteRoute', () => { ...mockRequest, body: { ...mockRequest.body, - subAction: 'invokeStream', - actionTypeId: '.gen-ai', + isStream: true, }, }, mockResponse @@ -640,7 +519,7 @@ describe('chatCompleteRoute', () => { ); }); - it('returns the expected response when subAction=invokeStream and actionTypeId=.bedrock', async () => { + it('returns the expected response when isStream=true and actionTypeId=.bedrock', async () => { const mockRouter = { versioned: { post: jest.fn().mockImplementation(() => { @@ -652,8 +531,7 @@ describe('chatCompleteRoute', () => { ...mockRequest, body: { ...mockRequest.body, - subAction: 'invokeStream', - actionTypeId: '.bedrock', + isStream: true, }, }, mockResponse @@ -679,77 +557,4 @@ describe('chatCompleteRoute', () => { mockGetElser ); }); - - it('returns the expected response when subAction=invokeAI and actionTypeId=.gen-ai', async () => { - const mockRouter = { - versioned: { - post: jest.fn().mockImplementation(() => { - return { - addVersion: jest.fn().mockImplementation(async (_, handler) => { - const result = await handler( - mockContext, - { - ...mockRequest, - body: { - ...mockRequest.body, - subAction: 'invokeAI', - actionTypeId: '.gen-ai', - }, - }, - mockResponse - ); - - expect(result).toEqual({ - body: { connector_id: 'mock-connector-id', data: mockActionResponse, status: 'ok' }, - headers: { - 'content-type': 'application/json', - }, - }); - }), - }; - }), - }, - }; - - await chatCompleteRoute( - mockRouter as unknown as IRouter, - mockGetElser - ); - }); - - it('returns the expected response when subAction=invokeAI and actionTypeId=.bedrock', async () => { - const mockRouter = { - versioned: { - post: jest.fn().mockImplementation(() => { - return { - addVersion: jest.fn().mockImplementation(async (_, handler) => { - const result = await handler( - mockContext, - { - ...mockRequest, - body: { - ...mockRequest.body, - subAction: 'invokeAI', - actionTypeId: '.bedrock', - }, - }, - mockResponse - ); - - expect(result).toEqual({ - body: { connector_id: 'mock-connector-id', data: mockActionResponse, status: 'ok' }, - headers: { - 'content-type': 'application/json', - }, - }); - }), - }; - }), - }, - }; - await chatCompleteRoute( - mockRouter as unknown as IRouter, - mockGetElser - ); - }); }); diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index 890746d81dd32..dbeb9c93531e9 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -137,12 +137,11 @@ export const chatCompleteRoute = ( getAnonymizedValue, onNewReplacements, rawData: Object.keys(m.data).reduce( - (obj, key) => ({ ...obj, key: [m.data ? m.data[key] : ''] }), + (obj, key) => ({ ...obj, [key]: [m.data ? m.data[key] : ''] }), {} ), }); const wr = `${SYSTEM_PROMPT_CONTEXT_NON_I18N(anonymizedData)}\n`; - content = `${wr}\n${m.content}`; } const transformedMessage = { @@ -175,8 +174,9 @@ export const chatCompleteRoute = ( model: request.body.model, }); if (updatedConversation == null) { - return response.badRequest({ + return assistantResponse.error({ body: `conversation id: "${conversationId}" not updated`, + statusCode: 400, }); } // messages are anonymized by conversationsDataClient diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 2a2ce03f4043c..af095dbb47343 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -149,7 +149,7 @@ export const postActionsConnectorExecuteRoute = ( } }; - if (!isGraphAvailable && !request.body.isEnabledRAGAlerts) { + if (!request.body.isEnabledKnowledgeBase && !request.body.isEnabledRAGAlerts) { // if not langchain, call execute action directly and return the response: return await nonLangChainExecute({ abortSignal, From ff4f1b84943458cde2c53f63a30f523bccb6d2fd Mon Sep 17 00:00:00 2001 From: YulNaumenko Date: Sun, 30 Jun 2024 21:49:41 -0700 Subject: [PATCH 17/23] changes from comments --- .../graphs/default_assistant_graph/graph.ts | 2 + .../nodes/persist_conversation_changes.ts | 9 +++- .../server/routes/chat/chat_complete_route.ts | 42 ++++++++++--------- .../server/routes/helpers.ts | 2 +- .../routes/user_conversations/create_route.ts | 31 +++++++------- .../routes/user_conversations/update_route.ts | 29 ++++++------- 6 files changed, 60 insertions(+), 55 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts index c9be3e0eb0e48..b16f7d9693e5f 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts @@ -60,6 +60,7 @@ export const getDefaultAssistantGraph = ({ logger, responseLanguage, tools, + replacements, }: GetDefaultAssistantGraphParams) => { try { // Default graph state @@ -139,6 +140,7 @@ export const getDefaultAssistantGraph = ({ state, conversationsDataClient: dataClients?.conversationsDataClient, conversationId, + replacements, }); const shouldContinueEdge = (state: AgentState) => shouldContinue({ ...nodeParams, state }); const shouldContinueGenerateTitleEdge = (state: AgentState) => diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/persist_conversation_changes.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/persist_conversation_changes.ts index 39331310e0f0e..a86897e67adbf 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/persist_conversation_changes.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/persist_conversation_changes.ts @@ -5,7 +5,10 @@ * 2.0. */ -import { replaceAnonymizedValuesWithOriginalValues } from '@kbn/elastic-assistant-common'; +import { + Replacements, + replaceAnonymizedValuesWithOriginalValues, +} from '@kbn/elastic-assistant-common'; import { AgentState, NodeParamsBase } from '../types'; import { AIAssistantConversationsDataClient } from '../../../../../ai_assistant_data_clients/conversations'; import { getLangChainMessages } from '../../../helpers'; @@ -14,6 +17,7 @@ export interface PersistConversationChangesParams extends NodeParamsBase { conversationsDataClient?: AIAssistantConversationsDataClient; conversationId?: string; state: AgentState; + replacements?: Replacements; } export const PERSIST_CONVERSATION_CHANGES_NODE = 'persistConversationChanges'; @@ -23,6 +27,7 @@ export const persistConversationChanges = async ({ conversationId, logger, state, + replacements = {}, }: PersistConversationChangesParams) => { logger.debug(`Node state:\n ${JSON.stringify(state, null, 2)}`); @@ -50,7 +55,7 @@ export const persistConversationChanges = async ({ { content: replaceAnonymizedValuesWithOriginalValues({ messageContent: state.input, - replacements: {}, + replacements, }), role: 'user', timestamp: new Date().toISOString(), diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index dbeb9c93531e9..3a5fc2a9d1cf5 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -15,6 +15,7 @@ import { Replacements, transformRawData, getAnonymizedValue, + ConversationResponse, } from '@kbn/elastic-assistant-common'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; @@ -22,12 +23,11 @@ import { ElasticAssistantPluginRouter, GetElser } from '../../types'; import { buildResponse } from '../../lib/build_response'; import { DEFAULT_PLUGIN_NAME, - UPGRADE_LICENSE_MESSAGE, appendAssistantMessageToConversation, createOrUpdateConversationWithUserInput, getPluginNameFromRequest, - hasAIAssistantLicense, langChainExecute, + performChecks, } from '../helpers'; import { transformESSearchToAnonymizationFields } from '../../ai_assistant_data_clients/anonymization_fields/helpers'; import { EsAnonymizationFieldsSchema } from '../../ai_assistant_data_clients/anonymization_fields/types'; @@ -66,20 +66,17 @@ export const chatCompleteRoute = ( const logger: Logger = ctx.elasticAssistant.logger; const telemetry = ctx.elasticAssistant.telemetry; - const license = ctx.licensing.license; - if (!hasAIAssistantLicense(license)) { - return response.forbidden({ - body: { - message: UPGRADE_LICENSE_MESSAGE, - }, - }); - } - const authenticatedUser = ctx.elasticAssistant.getCurrentUser(); - if (authenticatedUser == null) { - return assistantResponse.error({ - body: `Authenticated user not found`, - statusCode: 401, - }); + // Perform license, authenticated user and FF checks + const checkResponse = performChecks({ + authenticatedUser: true, + capability: 'assistantKnowledgeBaseByDefault', + context: ctx, + license: true, + request, + response, + }); + if (checkResponse) { + return checkResponse; } const conversationsDataClient = @@ -151,6 +148,7 @@ export const chatCompleteRoute = ( return transformedMessage; }); + let updatedConversation: ConversationResponse | undefined | null; // Fetch any tools registered by the request's originating plugin const pluginName = getPluginNameFromRequest({ request, @@ -160,8 +158,12 @@ export const chatCompleteRoute = ( const enableKnowledgeBaseByDefault = ctx.elasticAssistant.getRegisteredFeatures(pluginName).assistantKnowledgeBaseByDefault; // TODO: remove non-graph persistance when KB will be enabled by default - if (!enableKnowledgeBaseByDefault && request.body.persist && conversationsDataClient) { - const updatedConversation = await createOrUpdateConversationWithUserInput({ + if ( + (!enableKnowledgeBaseByDefault || (enableKnowledgeBaseByDefault && !conversationId)) && + request.body.persist && + conversationsDataClient + ) { + updatedConversation = await createOrUpdateConversationWithUserInput({ actionsClient, actionTypeId, connectorId, @@ -191,9 +193,9 @@ export const chatCompleteRoute = ( traceData: Message['traceData'] = {}, isError = false ): Promise => { - if (conversationId && conversationsDataClient) { + if (updatedConversation?.id && conversationsDataClient) { await appendAssistantMessageToConversation({ - conversationId, + conversationId: updatedConversation?.id, conversationsDataClient, messageContent: content, replacements: latestReplacements, diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index fb6629d58a0cf..f59d465b64080 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -454,7 +454,7 @@ export const langChainExecute = async ({ // New code path for LangGraph implementation, behind `assistantKnowledgeBaseByDefault` FF let result: StreamResponseWithHeaders | StaticReturnType; - if (enableKnowledgeBaseByDefault) { + if (enableKnowledgeBaseByDefault && request.body.isEnabledKnowledgeBase) { result = await callAssistantGraph(executorParams); } else { result = await callAgentExecutor(executorParams); diff --git a/x-pack/plugins/elastic_assistant/server/routes/user_conversations/create_route.ts b/x-pack/plugins/elastic_assistant/server/routes/user_conversations/create_route.ts index 2d51939eb528b..551d1b2ee5880 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/user_conversations/create_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/user_conversations/create_route.ts @@ -16,7 +16,7 @@ import { import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; import { ElasticAssistantPluginRouter } from '../../types'; import { buildResponse } from '../utils'; -import { UPGRADE_LICENSE_MESSAGE, hasAIAssistantLicense } from '../helpers'; +import { performChecks } from '../helpers'; export const createConversationRoute = (router: ElasticAssistantPluginRouter): void => { router.versioned @@ -41,27 +41,26 @@ export const createConversationRoute = (router: ElasticAssistantPluginRouter): v const assistantResponse = buildResponse(response); try { const ctx = await context.resolve(['core', 'elasticAssistant', 'licensing']); - const license = ctx.licensing.license; - if (!hasAIAssistantLicense(license)) { - return response.forbidden({ - body: { - message: UPGRADE_LICENSE_MESSAGE, - }, - }); - } - const authenticatedUser = ctx.elasticAssistant.getCurrentUser(); - if (authenticatedUser == null) { - return assistantResponse.error({ - body: `Authenticated user not found`, - statusCode: 401, - }); + // Perform license, authenticated user and FF checks + const checkResponse = performChecks({ + authenticatedUser: true, + capability: 'assistantKnowledgeBaseByDefault', + context: ctx, + license: true, + request, + response, + }); + if (checkResponse) { + return checkResponse; } const dataClient = await ctx.elasticAssistant.getAIAssistantConversationsDataClient(); const result = await dataClient?.findDocuments({ perPage: 100, page: 1, - filter: `users:{ id: "${authenticatedUser?.profile_uid}" } AND title:${request.body.title}`, + filter: `users:{ id: "${ + ctx.elasticAssistant.getCurrentUser()?.profile_uid + }" } AND title:${request.body.title}`, fields: ['title'], }); if (result?.data != null && result.total > 0) { diff --git a/x-pack/plugins/elastic_assistant/server/routes/user_conversations/update_route.ts b/x-pack/plugins/elastic_assistant/server/routes/user_conversations/update_route.ts index 046785166a109..c0798a45cb7f7 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/user_conversations/update_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/user_conversations/update_route.ts @@ -19,7 +19,7 @@ import { UpdateConversationRequestParams } from '@kbn/elastic-assistant-common/i import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; import { ElasticAssistantPluginRouter } from '../../types'; import { buildResponse } from '../utils'; -import { UPGRADE_LICENSE_MESSAGE, hasAIAssistantLicense } from '../helpers'; +import { performChecks } from '../helpers'; export const updateConversationRoute = (router: ElasticAssistantPluginRouter) => { router.versioned @@ -45,22 +45,20 @@ export const updateConversationRoute = (router: ElasticAssistantPluginRouter) => const { id } = request.params; try { const ctx = await context.resolve(['core', 'elasticAssistant', 'licensing']); - const license = ctx.licensing.license; - if (!hasAIAssistantLicense(license)) { - return response.forbidden({ - body: { - message: UPGRADE_LICENSE_MESSAGE, - }, - }); - } - const authenticatedUser = ctx.elasticAssistant.getCurrentUser(); - if (authenticatedUser == null) { - return assistantResponse.error({ - body: `Authenticated user not found`, - statusCode: 401, - }); + // Perform license, authenticated user and FF checks + const checkResponse = performChecks({ + authenticatedUser: true, + capability: 'assistantKnowledgeBaseByDefault', + context: ctx, + license: true, + request, + response, + }); + if (checkResponse) { + return checkResponse; } + const dataClient = await ctx.elasticAssistant.getAIAssistantConversationsDataClient(); const existingConversation = await dataClient?.getConversation({ id, authenticatedUser }); @@ -72,7 +70,6 @@ export const updateConversationRoute = (router: ElasticAssistantPluginRouter) => } const conversation = await dataClient?.updateConversation({ conversationUpdateProps: request.body, - authenticatedUser, }); if (conversation == null) { return assistantResponse.error({ From b289452e732f5780876cadb4811f71d8cbc2aa55 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 1 Jul 2024 11:27:26 -0600 Subject: [PATCH 18/23] fix lint --- .../lib/langchain/graphs/default_assistant_graph/helpers.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 83c118b130546..c9565e0e6e4d6 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -134,7 +134,7 @@ export const streamGraph = async ({ }; // Start processing events, do not await! Return `responseWithHeaders` immediately - await processEvent(); + void processEvent(); return responseWithHeaders; }; From 15fc400118da813957d0e20910c5b029f865659a Mon Sep 17 00:00:00 2001 From: kibanamachine <42973632+kibanamachine@users.noreply.github.com> Date: Mon, 1 Jul 2024 18:22:21 +0000 Subject: [PATCH 19/23] [CI] Auto-commit changed files from 'node scripts/eslint --no-cache --fix' --- .../server/routes/categorization_routes.ts | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts b/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts index def8196b74220..2dbdc63210a59 100644 --- a/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts +++ b/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts @@ -64,17 +64,17 @@ export function registerCategorizationRoutes( const isOpenAI = connector.actionTypeId === '.gen-ai'; const llmClass = isOpenAI ? ActionsClientChatOpenAI : ActionsClientSimpleChatModel; - const model = new llmClass({ - actionsClient, - connectorId: connector.id, - logger, - llmType: isOpenAI ? 'openai' : 'bedrock', - model: connector.config?.defaultModel, - temperature: 0.05, - maxTokens: 4096, - signal: abortSignal, - streaming: false, - }); + const model = new llmClass({ + actionsClient, + connectorId: connector.id, + logger, + llmType: isOpenAI ? 'openai' : 'bedrock', + model: connector.config?.defaultModel, + temperature: 0.05, + maxTokens: 4096, + signal: abortSignal, + streaming: false, + }); const graph = await getCategorizationGraph(client, model); const results = await graph.invoke({ From 09e3c2b8b8defa48924869910924742040ac68a3 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 1 Jul 2024 13:56:42 -0600 Subject: [PATCH 20/23] fix streaming --- .../graphs/default_assistant_graph/helpers.ts | 46 ++++++++++--------- .../nodes/run_agent.ts | 6 ++- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index c9565e0e6e4d6..482c89c10e969 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -11,6 +11,7 @@ import { streamFactory, StreamResponseWithHeaders } from '@kbn/ml-response-strea import { transformError } from '@kbn/securitysolution-es-utils'; import type { KibanaRequest } from '@kbn/core-http-server'; import type { ExecuteConnectorRequestBody, TraceData } from '@kbn/elastic-assistant-common'; +import { AGENT_NODE_TAG } from './nodes/run_agent'; import { DEFAULT_ASSISTANT_GRAPH_ID, DefaultAssistantGraph } from './graph'; import type { OnLlmResponse, TraceOptions } from '../../executors/types'; import type { APMTracer } from '../../tracers/apm_tracer'; @@ -91,32 +92,35 @@ export const streamGraph = async ({ if (done) return; const event = value; - if (event.event === 'on_llm_stream') { - const chunk = event.data?.chunk; - // TODO: For Bedrock streaming support, override `handleLLMNewToken` in callbacks, - // TODO: or maybe we can update ActionsClientSimpleChatModel to handle this `on_llm_stream` event - if (event.name === 'ActionsClientChatOpenAI') { - const msg = chunk.message; - - if (msg.tool_call_chunks && msg.tool_call_chunks.length > 0) { - /* empty */ - } else if (!didEnd) { - if (msg.response_metadata.finish_reason === 'stop') { - handleStreamEnd(finalMessage); - } else { - push({ payload: msg.content, type: 'content' }); - finalMessage += msg.content; + // only process events that are part of the agent run + if ((event.tags || []).includes(AGENT_NODE_TAG)) { + if (event.event === 'on_llm_stream') { + const chunk = event.data?.chunk; + // TODO: For Bedrock streaming support, override `handleLLMNewToken` in callbacks, + // TODO: or maybe we can update ActionsClientSimpleChatModel to handle this `on_llm_stream` event + if (event.name === 'ActionsClientChatOpenAI') { + const msg = chunk.message; + + if (msg.tool_call_chunks && msg.tool_call_chunks.length > 0) { + /* empty */ + } else if (!didEnd) { + if (msg.response_metadata.finish_reason === 'stop') { + handleStreamEnd(finalMessage); + } else { + push({ payload: msg.content, type: 'content' }); + finalMessage += msg.content; + } } } - } - } else if (event.event === 'on_llm_end') { - const generations = event.data.output?.generations[0]; - if (generations && generations[0]?.generationInfo.finish_reason === 'stop') { - handleStreamEnd(finalMessage); + } else if (event.event === 'on_llm_end') { + const generations = event.data.output?.generations[0]; + if (generations && generations[0]?.generationInfo.finish_reason === 'stop') { + handleStreamEnd(finalMessage); + } } } - processEvent(); + void processEvent(); } catch (err) { // if I throw an error here, it crashes the server. Not sure how to get around that. // If I put await on this function the error works properly, but when there is not an error diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/run_agent.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/run_agent.ts index b0353bb5d8ec7..0a6ee3b79087c 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/run_agent.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/run_agent.ts @@ -19,6 +19,8 @@ export interface RunAgentParams extends NodeParamsBase { export const AGENT_NODE = 'agent'; +export const AGENT_NODE_TAG = 'agent_run'; + const NO_HISTORY = '[No existing knowledge history]'; /** * Node to run the agent @@ -44,11 +46,11 @@ export const runAgent = async ({ query: '', }); - const agentOutcome = await agentRunnable.invoke( + const agentOutcome = await agentRunnable.withConfig({ tags: [AGENT_NODE_TAG] }).invoke( { ...state, chat_history: state.messages, // TODO: Message de-dupe with ...state spread - knowledge_history: knowledgeHistory?.length ? knowledgeHistory : NO_HISTORY, + knowledge_history: JSON.stringify(knowledgeHistory?.length ? knowledgeHistory : NO_HISTORY), }, config ); From 5aa9c7200d06f7a14528170057b2e9875dfa319d Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 1 Jul 2024 14:49:19 -0600 Subject: [PATCH 21/23] rm FF from chat/convo routes --- .../server/routes/chat/chat_complete_route.ts | 3 +-- .../server/routes/user_conversations/create_route.ts | 3 +-- .../server/routes/user_conversations/update_route.ts | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index 3a5fc2a9d1cf5..80650e1f7dbab 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -66,10 +66,9 @@ export const chatCompleteRoute = ( const logger: Logger = ctx.elasticAssistant.logger; const telemetry = ctx.elasticAssistant.telemetry; - // Perform license, authenticated user and FF checks + // Perform license and authenticated user checks const checkResponse = performChecks({ authenticatedUser: true, - capability: 'assistantKnowledgeBaseByDefault', context: ctx, license: true, request, diff --git a/x-pack/plugins/elastic_assistant/server/routes/user_conversations/create_route.ts b/x-pack/plugins/elastic_assistant/server/routes/user_conversations/create_route.ts index 551d1b2ee5880..9c02586a60b88 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/user_conversations/create_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/user_conversations/create_route.ts @@ -41,10 +41,9 @@ export const createConversationRoute = (router: ElasticAssistantPluginRouter): v const assistantResponse = buildResponse(response); try { const ctx = await context.resolve(['core', 'elasticAssistant', 'licensing']); - // Perform license, authenticated user and FF checks + // Perform license and authenticated user checks const checkResponse = performChecks({ authenticatedUser: true, - capability: 'assistantKnowledgeBaseByDefault', context: ctx, license: true, request, diff --git a/x-pack/plugins/elastic_assistant/server/routes/user_conversations/update_route.ts b/x-pack/plugins/elastic_assistant/server/routes/user_conversations/update_route.ts index c0798a45cb7f7..069c189609f23 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/user_conversations/update_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/user_conversations/update_route.ts @@ -46,10 +46,9 @@ export const updateConversationRoute = (router: ElasticAssistantPluginRouter) => try { const ctx = await context.resolve(['core', 'elasticAssistant', 'licensing']); const authenticatedUser = ctx.elasticAssistant.getCurrentUser(); - // Perform license, authenticated user and FF checks + // Perform license and authenticated user checks const checkResponse = performChecks({ authenticatedUser: true, - capability: 'assistantKnowledgeBaseByDefault', context: ctx, license: true, request, From 263ddc5254760a1072274e85699f6d6d8bd26362 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 1 Jul 2024 17:27:26 -0600 Subject: [PATCH 22/23] fix test --- x-pack/plugins/elastic_assistant/server/routes/helpers.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts index f59d465b64080..aa060e24bc5df 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/helpers.ts @@ -316,6 +316,9 @@ export const nonLangChainExecute = async ({ }); return response.ok({ body: result, + ...(request.body.subAction === 'invokeAI' + ? { headers: { 'content-type': 'application/json' } } + : {}), }); }; From cf52d5ff7e346c0b71ff4f2daaadf6eaed4f6126 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Mon, 1 Jul 2024 19:51:29 -0600 Subject: [PATCH 23/23] fix chat completion route test --- .../routes/chat/chat_complete_route.test.ts | 162 ++++-------------- .../server/routes/chat/chat_complete_route.ts | 17 +- 2 files changed, 44 insertions(+), 135 deletions(-) diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts index b37ee305af91e..a487e56019bd8 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts @@ -33,26 +33,19 @@ const actionsClient = actionsClientMock.create(); jest.mock('../../lib/build_response', () => ({ buildResponse: jest.fn().mockImplementation((x) => x), })); -jest.mock('../../lib/executor', () => ({ - executeAction: jest.fn().mockImplementation(async ({ connectorId }) => { - if (connectorId === 'mock-connector-id') { - return { - connector_id: 'mock-connector-id', - data: mockActionResponse, - status: 'ok', - }; - } else { - throw new Error('simulated error'); - } - }), -})); const mockStream = jest.fn().mockImplementation(() => new PassThrough()); jest.mock('../../lib/langchain/execute_custom_llm_chain', () => ({ callAgentExecutor: jest.fn().mockImplementation( async ({ connectorId, isStream, + onLlmResponse, }: { + onLlmResponse: ( + content: string, + replacements: Record, + isError: boolean + ) => Promise; actionsClient: PublicMethodsOf; connectorId: string; esClient: ElasticsearchClient; @@ -64,25 +57,14 @@ jest.mock('../../lib/langchain/execute_custom_llm_chain', () => ({ }) => { if (!isStream && connectorId === 'mock-connector-id') { return { - body: { - connector_id: 'mock-connector-id', - data: mockActionResponse, - status: 'ok', - }, - headers: { 'content-type': 'application/json' }, + connector_id: 'mock-connector-id', + data: mockActionResponse, + status: 'ok', }; } else if (isStream && connectorId === 'mock-connector-id') { - return { - body: mockStream, - headers: { - 'Cache-Control': 'no-cache', - Connection: 'keep-alive', - 'Transfer-Encoding': 'chunked', - 'X-Accel-Buffering': 'no', - 'X-Content-Type-Options': 'nosniff', - }, - }; + return mockStream; } else { + onLlmResponse('simulated error', {}, true).catch(() => {}); throw new Error('simulated error'); } } @@ -140,8 +122,12 @@ const mockContext = { const mockRequest = { body: { + conversationId: 'mock-conversation-id', connectorId: 'mock-connector-id', persist: true, + isEnabledKnowledgeBase: true, + isEnabledRAGAlerts: false, + model: 'gpt-4', messages: [ { role: 'user', @@ -222,14 +208,9 @@ describe('chatCompleteRoute', () => { ); expect(result).toEqual({ - body: { - connector_id: 'mock-connector-id', - data: mockActionResponse, - status: 'ok', - }, - headers: { - 'content-type': 'application/json', - }, + connector_id: 'mock-connector-id', + data: mockActionResponse, + status: 'ok', }); }), }; @@ -246,7 +227,10 @@ describe('chatCompleteRoute', () => { it('returns the expected error when executeCustomLlmChain fails', async () => { const requestWithBadConnectorId = { ...mockRequest, - connectorId: 'bad-connector-id', + body: { + ...mockRequest.body, + connectorId: 'bad-connector-id', + }, }; const mockRouter = { @@ -304,6 +288,7 @@ describe('chatCompleteRoute', () => { ...mockRequest, body: { ...mockRequest.body, + isEnabledRAGAlerts: true, anonymizationFields: [ { id: '@timestamp', field: '@timestamp', allowed: true, anonymized: false }, { id: 'host.name', field: 'host.name', allowed: true, anonymized: true }, @@ -337,86 +322,15 @@ describe('chatCompleteRoute', () => { ); }); - it('reports success events to telemetry - kb off, RAG alerts on', async () => { - const req = { - ...mockRequest, - body: { - ...mockRequest.body, - isEnabledKnowledgeBase: false, - anonymizationFields: [ - { id: '@timestamp', field: '@timestamp', allowed: true, anonymized: false }, - { id: 'host.name', field: 'host.name', allowed: true, anonymized: true }, - ], - }, - }; - - const mockRouter = { - versioned: { - post: jest.fn().mockImplementation(() => { - return { - addVersion: jest.fn().mockImplementation(async (_, handler) => { - await handler(mockContext, req, mockResponse); - - expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { - actionTypeId: '.gen-ai', - model: 'gpt-4', - assistantStreamingEnabled: false, - isEnabledKnowledgeBase: false, - isEnabledRAGAlerts: true, - }); - }), - }; - }), - }, - }; - - await chatCompleteRoute( - mockRouter as unknown as IRouter, - mockGetElser - ); - }); - - it('reports success events to telemetry - kb off, RAG alerts off', async () => { - const req = { + it('reports error events to telemetry - kb on, RAG alerts off', async () => { + const requestWithBadConnectorId = { ...mockRequest, body: { ...mockRequest.body, - isEnabledKnowledgeBase: false, - }, - }; - - const mockRouter = { - versioned: { - post: jest.fn().mockImplementation(() => { - return { - addVersion: jest.fn().mockImplementation(async (_, handler) => { - await handler(mockContext, req, mockResponse); - - expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_SUCCESS_EVENT.eventType, { - actionTypeId: '.gen-ai', - model: 'gpt-4', - assistantStreamingEnabled: false, - isEnabledKnowledgeBase: false, - isEnabledRAGAlerts: false, - }); - }), - }; - }), + connectorId: 'bad-connector-id', }, }; - await chatCompleteRoute( - mockRouter as unknown as IRouter, - mockGetElser - ); - }); - - it('reports error events to telemetry - kb on, RAG alerts off', async () => { - const requestWithBadConnectorId = { - ...mockRequest, - connectorId: 'bad-connector-id', - }; - const mockRouter = { versioned: { post: jest.fn().mockImplementation(() => { @@ -427,7 +341,7 @@ describe('chatCompleteRoute', () => { expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_ERROR_EVENT.eventType, { errorMessage: 'simulated error', isEnabledKnowledgeBase: true, - isEnabledRAGAlerts: false, + isEnabledRAGAlerts: true, actionTypeId: '.gen-ai', model: 'gpt-4', assistantStreamingEnabled: false, @@ -497,16 +411,7 @@ describe('chatCompleteRoute', () => { mockResponse ); - expect(result).toEqual({ - body: mockStream, - headers: { - 'Cache-Control': 'no-cache', - Connection: 'keep-alive', - 'Transfer-Encoding': 'chunked', - 'X-Accel-Buffering': 'no', - 'X-Content-Type-Options': 'nosniff', - }, - }); + expect(result).toEqual(mockStream); }), }; }), @@ -537,16 +442,7 @@ describe('chatCompleteRoute', () => { mockResponse ); - expect(result).toEqual({ - body: mockStream, - headers: { - 'Cache-Control': 'no-cache', - Connection: 'keep-alive', - 'Transfer-Encoding': 'chunked', - 'X-Accel-Buffering': 'no', - 'X-Content-Type-Options': 'nosniff', - }, - }); + expect(result).toEqual(mockStream); }), }; }), diff --git a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index 80650e1f7dbab..10da330a36c79 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -19,6 +19,7 @@ import { } from '@kbn/elastic-assistant-common'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; +import { INVOKE_ASSISTANT_ERROR_EVENT } from '../../lib/telemetry/event_based_telemetry'; import { ElasticAssistantPluginRouter, GetElser } from '../../types'; import { buildResponse } from '../../lib/build_response'; import { @@ -61,10 +62,12 @@ export const chatCompleteRoute = ( async (context, request, response) => { const abortSignal = getRequestAbortedSignal(request.events.aborted$); const assistantResponse = buildResponse(response); + let telemetry; + let actionTypeId; try { const ctx = await context.resolve(['core', 'elasticAssistant', 'licensing']); const logger: Logger = ctx.elasticAssistant.logger; - const telemetry = ctx.elasticAssistant.telemetry; + telemetry = ctx.elasticAssistant.telemetry; // Perform license and authenticated user checks const checkResponse = performChecks({ @@ -97,7 +100,7 @@ export const chatCompleteRoute = ( const actions = ctx.elasticAssistant.actions; const actionsClient = await actions.getActionsClientWithRequest(request); const connectors = await actionsClient.getBulk({ ids: [connectorId] }); - const actionTypeId = connectors.length > 0 ? connectors[0].actionTypeId : '.gen-ai'; + actionTypeId = connectors.length > 0 ? connectors[0].actionTypeId : '.gen-ai'; // replacements const anonymizationFieldsRes = @@ -226,6 +229,16 @@ export const chatCompleteRoute = ( }); } catch (err) { const error = transformError(err as Error); + telemetry?.reportEvent(INVOKE_ASSISTANT_ERROR_EVENT.eventType, { + actionTypeId: actionTypeId ?? '', + isEnabledKnowledgeBase: true, + isEnabledRAGAlerts: true, + model: request.body.model, + errorMessage: error.message, + // TODO rm actionTypeId check when llmClass for bedrock streaming is implemented + // tracked here: https://github.com/elastic/security-team/issues/7363 + assistantStreamingEnabled: request.body.isStream ?? false, + }); return assistantResponse.error({ body: error.message, statusCode: error.statusCode,