diff --git a/x-pack/platform/packages/shared/ai-infra/product-doc-common/src/artifact.ts b/x-pack/platform/packages/shared/ai-infra/product-doc-common/src/artifact.ts index d67138a8d39f0..42f2fb89136ca 100644 --- a/x-pack/platform/packages/shared/ai-infra/product-doc-common/src/artifact.ts +++ b/x-pack/platform/packages/shared/ai-infra/product-doc-common/src/artifact.ts @@ -5,6 +5,7 @@ * 2.0. */ +import { isImpliedDefaultElserInferenceId } from './is_default_inference_endpoint'; import { type ProductName, DocumentationProduct } from './product'; const allowedProductNames: ProductName[] = Object.values(DocumentationProduct); @@ -24,7 +25,7 @@ export const getArtifactName = ({ }): string => { const ext = excludeExtension ? '' : '.zip'; return `kb-product-doc-${productName}-${productVersion}${ - inferenceId && inferenceId !== DEFAULT_ELSER ? `--${inferenceId}` : '' + inferenceId && !isImpliedDefaultElserInferenceId(inferenceId) ? `--${inferenceId}` : '' }${ext}`.toLowerCase(); }; diff --git a/x-pack/platform/packages/shared/ai-infra/product-doc-common/src/is_default_inference_endpoint.ts b/x-pack/platform/packages/shared/ai-infra/product-doc-common/src/is_default_inference_endpoint.ts index a2c0bfc01d7a3..c7f6de079e2af 100644 --- a/x-pack/platform/packages/shared/ai-infra/product-doc-common/src/is_default_inference_endpoint.ts +++ b/x-pack/platform/packages/shared/ai-infra/product-doc-common/src/is_default_inference_endpoint.ts @@ -17,6 +17,7 @@ export const isImpliedDefaultElserInferenceId = (inferenceId: string | null | un inferenceId === null || inferenceId === undefined || inferenceId === defaultInferenceEndpoints.ELSER || - inferenceId === defaultInferenceEndpoints.ELSER_IN_EIS_INFERENCE_ID + inferenceId === defaultInferenceEndpoints.ELSER_IN_EIS_INFERENCE_ID || + (typeof inferenceId === 'string' && inferenceId.toLowerCase().includes('elser')) ); }; diff --git a/x-pack/platform/plugins/shared/ai_infra/llm_tasks/README.md 
b/x-pack/platform/plugins/shared/ai_infra/llm_tasks/README.md index e019d456cd65a..5c5bacf1029c6 100644 --- a/x-pack/platform/plugins/shared/ai_infra/llm_tasks/README.md +++ b/x-pack/platform/plugins/shared/ai_infra/llm_tasks/README.md @@ -17,14 +17,14 @@ context. That API receive the inbound request as parameter. -Example: +Example — by default, it will check against the default ELSER model: ```ts -if (await llmTasksStart.retrieveDocumentationAvailable({ request })) { +if (await llmTasksStart.retrieveDocumentationAvailable({ inferenceId })) { // task is available } else { // task is not available } ``` ### Executing the task @@ -37,6 +36,7 @@ const result = await llmTasksStart.retrieveDocumentation({ searchTerm: "How to create a space in Kibana?", request, connectorId: 'my-connector-id', + inferenceId: 'my-inference-id', }); const { success, documents } = result; diff --git a/x-pack/platform/plugins/shared/ai_infra/llm_tasks/server/plugin.ts b/x-pack/platform/plugins/shared/ai_infra/llm_tasks/server/plugin.ts index 50d9e2e340fab..14090d90e4387 100644 --- a/x-pack/platform/plugins/shared/ai_infra/llm_tasks/server/plugin.ts +++ b/x-pack/platform/plugins/shared/ai_infra/llm_tasks/server/plugin.ts @@ -7,7 +7,6 @@ import type { Logger } from '@kbn/logging'; import type { CoreSetup, CoreStart, Plugin, PluginInitializerContext } from '@kbn/core/server'; -import { defaultInferenceEndpoints } from '@kbn/inference-common'; import type { LlmTasksConfig } from './config'; import type { LlmTasksPluginSetup, @@ -41,9 +40,9 @@ export class LlmTasksPlugin start(core: CoreStart, startDependencies: PluginStartDependencies): LlmTasksPluginStart { const { inference, productDocBase } = startDependencies; return { - retrieveDocumentationAvailable: async () => { + retrieveDocumentationAvailable: async (options: { inferenceId: string }) => { const docBaseStatus = await startDependencies.productDocBase.management.getStatus({ - inferenceId: defaultInferenceEndpoints.ELSER, + inferenceId:
options.inferenceId, }); return docBaseStatus.status === 'installed'; }, diff --git a/x-pack/platform/plugins/shared/ai_infra/llm_tasks/server/types.ts b/x-pack/platform/plugins/shared/ai_infra/llm_tasks/server/types.ts index d550e4398b509..5905070d4027d 100644 --- a/x-pack/platform/plugins/shared/ai_infra/llm_tasks/server/types.ts +++ b/x-pack/platform/plugins/shared/ai_infra/llm_tasks/server/types.ts @@ -32,7 +32,7 @@ export interface LlmTasksPluginStart { * are respected. Can be used to check if the task can be registered * as LLM tool for example. */ - retrieveDocumentationAvailable: () => Promise; + retrieveDocumentationAvailable: (options: { inferenceId: string }) => Promise; /** * Perform the `retrieveDocumentation` task. * diff --git a/x-pack/platform/plugins/shared/ai_infra/product_doc_base/server/services/package_installer/package_installer.ts b/x-pack/platform/plugins/shared/ai_infra/product_doc_base/server/services/package_installer/package_installer.ts index e9f23f30626c1..0332ef5606530 100644 --- a/x-pack/platform/plugins/shared/ai_infra/product_doc_base/server/services/package_installer/package_installer.ts +++ b/x-pack/platform/plugins/shared/ai_infra/product_doc_base/server/services/package_installer/package_installer.ts @@ -17,6 +17,7 @@ import { defaultInferenceEndpoints } from '@kbn/inference-common'; import { cloneDeep } from 'lodash'; import type { InferenceInferenceEndpointInfo } from '@elastic/elasticsearch/lib/api/types'; import { i18n } from '@kbn/i18n'; +import { isImpliedDefaultElserInferenceId } from '@kbn/product-doc-common/src/is_default_inference_endpoint'; import type { ProductDocInstallClient } from '../doc_install_status'; import { downloadToDisk, @@ -179,7 +180,7 @@ export class PackageInstaller { inferenceId, }); - if (customInference && customInference?.inference_id !== this.elserInferenceId) { + if (customInference && !isImpliedDefaultElserInferenceId(customInference?.inference_id)) { if (customInference?.task_type !== 
'text_embedding') { throw new Error( `Inference [${inferenceId}]'s task type ${customInference?.task_type} is not supported. Please use a model with task type 'text_embedding'.` @@ -191,7 +192,7 @@ export class PackageInstaller { }); } - if (!customInference || customInference?.inference_id === this.elserInferenceId) { + if (!customInference || isImpliedDefaultElserInferenceId(customInference?.inference_id)) { await ensureDefaultElserDeployed({ client: this.esClient, }); diff --git a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/documentation.ts b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/documentation.ts index a8a897c28fcc8..d5adafb7f7902 100644 --- a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/documentation.ts +++ b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/server/functions/documentation.ts @@ -18,7 +18,11 @@ export async function registerDocumentationFunction({ resources, pluginsStart: { llmTasks }, }: FunctionRegistrationParameters) { - const isProductDocAvailable = (await llmTasks.retrieveDocumentationAvailable()) ?? false; + const esClient = (await resources.context.core).elasticsearch.client; + const inferenceId = + (await getInferenceIdFromWriteIndex(esClient)) ?? defaultInferenceEndpoints.ELSER; + const isProductDocAvailable = + (await llmTasks.retrieveDocumentationAvailable({ inferenceId })) ?? false; if (isProductDocAvailable) { functions.registerInstruction(({ availableFunctionNames }) => { @@ -66,11 +70,6 @@ export async function registerDocumentationFunction({ } as const, }, async ({ arguments: { query, product }, connectorId, simulateFunctionCalling }) => { - const esClient = (await resources.context.core).elasticsearch.client; - - const inferenceId = - (await getInferenceIdFromWriteIndex(esClient)) ?? 
defaultInferenceEndpoints.ELSER; - const response = await llmTasks!.retrieveDocumentation({ searchTerm: query, products: product ? [product] : undefined, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index ea289abb14d1e..09dd6bccebb3d 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -22,6 +22,7 @@ import { } from '@kbn/elastic-assistant-common'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; +import { defaultInferenceEndpoints } from '@kbn/inference-common'; import { INVOKE_ASSISTANT_ERROR_EVENT } from '../../lib/telemetry/event_based_telemetry'; import { ElasticAssistantPluginRouter } from '../../types'; import { buildResponse } from '../../lib/build_response'; @@ -86,7 +87,9 @@ export const chatCompleteRoute = ( telemetry = ctx.elasticAssistant.telemetry; const inference = ctx.elasticAssistant.inference; const productDocsAvailable = - (await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable()) ?? false; + (await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable({ + inferenceId: defaultInferenceEndpoints.ELSER, + })) ?? 
false; // Perform license and authenticated user checks const checkResponse = await performChecks({ diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index 4087ed23ec85b..ea2e5a9497b30 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -29,6 +29,7 @@ import { getDefaultArguments } from '@kbn/langchain/server'; import { StructuredTool } from '@langchain/core/tools'; import { AgentFinish } from 'langchain/agents'; import { omit } from 'lodash/fp'; +import { defaultInferenceEndpoints } from '@kbn/inference-common'; import { localToolPrompts, promptGroupId as toolsGroupId } from '../../lib/prompt/tool_prompts'; import { promptGroupId } from '../../lib/prompt/local_prompt_object'; import { getFormattedTime, getModelOrOss } from '../../lib/prompt/helpers'; @@ -173,7 +174,9 @@ export const postEvaluateRoute = ( const inference = ctx.elasticAssistant.inference; const productDocsAvailable = - (await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable()) ?? false; + (await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable({ + inferenceId: defaultInferenceEndpoints.ELSER, + })) ?? 
false; const { featureFlags } = await context.core; const inferenceChatModelDisabled = await featureFlags.getBooleanValue( diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 70354c7050f0b..39a66c6fc4170 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -22,6 +22,7 @@ import { INFERENCE_CHAT_MODEL_DISABLED_FEATURE_FLAG, } from '@kbn/elastic-assistant-common'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; +import { defaultInferenceEndpoints } from '@kbn/inference-common'; import { getPrompt } from '../lib/prompt'; import { INVOKE_ASSISTANT_ERROR_EVENT } from '../lib/telemetry/event_based_telemetry'; import { buildResponse } from '../lib/build_response'; @@ -124,7 +125,9 @@ export const postActionsConnectorExecuteRoute = ( const inference = ctx.elasticAssistant.inference; const savedObjectsClient = ctx.elasticAssistant.savedObjectsClient; const productDocsAvailable = - (await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable()) ?? false; + (await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable({ + inferenceId: defaultInferenceEndpoints.ELSER, + })) ?? false; const actionsClient = await actions.getActionsClientWithRequest(request); const connectors = await actionsClient.getBulk({ ids: [connectorId] }); const connector = connectors.length > 0 ? 
connectors[0] : undefined; diff --git a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/retrieve_elastic_doc.spec.ts b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/retrieve_elastic_doc.spec.ts index 83ff52fa717f5..b68ba07ce9f61 100644 --- a/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/retrieve_elastic_doc.spec.ts +++ b/x-pack/test/api_integration/deployment_agnostic/apis/observability/ai_assistant/complete/functions/retrieve_elastic_doc.spec.ts @@ -10,12 +10,15 @@ import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream'; import { ChatCompletionMessageParam } from 'openai/resources'; import { last } from 'lodash'; import { MessageAddEvent, MessageRole } from '@kbn/observability-ai-assistant-plugin/common'; +import { TINY_ELSER_INFERENCE_ID } from '../../utils/model_and_inference'; import { LlmProxy, createLlmProxy } from '../../utils/create_llm_proxy'; import { chatComplete } from '../../utils/conversation'; import type { DeploymentAgnosticFtrProviderContext } from '../../../../../ftr_provider_context'; import { installProductDoc, uninstallProductDoc } from '../../utils/product_doc_base'; - -const DEFAULT_INFERENCE_ID = '.elser-2-elasticsearch'; +import { + deployTinyElserAndSetupKb, + teardownTinyElserModelAndInferenceEndpoint, +} from '../../utils/model_and_inference'; export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext) { const log = getService('log'); @@ -86,14 +89,19 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon let llmProxy: LlmProxy; let connectorId: string; let messageAddedEvents: MessageAddEvent[]; - let firstRequestBody: ChatCompletionStreamParams; - let secondRequestBody: ChatCompletionStreamParams; + let toolCallRequestBody: ChatCompletionStreamParams; + let userPromptRequestBody: ChatCompletionStreamParams; 
+ before(async () => { llmProxy = await createLlmProxy(log); connectorId = await observabilityAIAssistantAPIClient.createProxyActionConnector({ port: llmProxy.getPort(), }); - await installProductDoc(supertest, DEFAULT_INFERENCE_ID); + await deployTinyElserAndSetupKb(getService); + + await installProductDoc(supertest, TINY_ELSER_INFERENCE_ID); + + void llmProxy.interceptQueryRewrite('This is a rewritten user prompt.'); void llmProxy.interceptWithFunctionRequest({ name: 'retrieve_elastic_doc', @@ -113,20 +121,21 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon })); await llmProxy.waitForAllInterceptorsToHaveBeenCalled(); - firstRequestBody = llmProxy.interceptedRequests[0].requestBody; - secondRequestBody = llmProxy.interceptedRequests[1].requestBody; + toolCallRequestBody = llmProxy.interceptedRequests[1].requestBody; + userPromptRequestBody = llmProxy.interceptedRequests[2].requestBody; }); after(async () => { - await uninstallProductDoc(supertest, DEFAULT_INFERENCE_ID); + await uninstallProductDoc(supertest, TINY_ELSER_INFERENCE_ID); llmProxy.close(); await observabilityAIAssistantAPIClient.deleteActionConnector({ actionId: connectorId, }); + await teardownTinyElserModelAndInferenceEndpoint(getService); }); - it('makes 2 requests to the LLM', () => { - expect(llmProxy.interceptedRequests.length).to.be(2); + it('makes 3 requests to the LLM', () => { + expect(llmProxy.interceptedRequests.length).to.be(3); }); it('emits 5 messageAdded events', () => { @@ -134,29 +143,27 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon }); describe('The first request', () => { - it('contains the retrieve_elastic_doc function', () => { - expect(firstRequestBody.tools?.map((t) => t.function.name)).to.contain( + it('enables the LLM to call `retrieve_elastic_doc`', () => { + expect(toolCallRequestBody.tool_choice).to.be('auto'); + expect(toolCallRequestBody.tools?.map((t) => t.function.name)).to.contain( 
'retrieve_elastic_doc' ); }); - - it('leaves the LLM to choose the correct tool by leave tool_choice as auto and passes tools', () => { - expect(firstRequestBody.tool_choice).to.be('auto'); - expect(firstRequestBody.tools?.length).to.not.be(0); - }); }); describe('The second request - Sending the user prompt', () => { let lastMessage: ChatCompletionMessageParam; let parsedContent: { documents: Array<{ title: string; content: string; url: string }> }; + before(() => { - lastMessage = last(secondRequestBody.messages) as ChatCompletionMessageParam; + lastMessage = last(userPromptRequestBody.messages) as ChatCompletionMessageParam; parsedContent = JSON.parse(lastMessage.content as string); }); + it('includes the retrieve_elastic_doc function call', () => { - expect(secondRequestBody.messages[4].role).to.be(MessageRole.Assistant); + expect(userPromptRequestBody.messages[4].role).to.be(MessageRole.Assistant); // @ts-expect-error - expect(secondRequestBody.messages[4].tool_calls[0].function.name).to.be( + expect(userPromptRequestBody.messages[4].tool_calls[0].function.name).to.be( 'retrieve_elastic_doc' ); }); @@ -166,9 +173,10 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon // @ts-expect-error expect(lastMessage?.tool_call_id).to.equal( // @ts-expect-error - secondRequestBody.messages[4].tool_calls[0].id + userPromptRequestBody.messages[4].tool_calls[0].id ); }); + it('sends the retrieved documents from Elastic docs to the LLM', () => { expect(lastMessage.content).to.be.a('string'); });