diff --git a/x-pack/platform/plugins/private/translations/translations/fr-FR.json b/x-pack/platform/plugins/private/translations/translations/fr-FR.json index 05762fc5ae7e5..37a2b641f356f 100644 --- a/x-pack/platform/plugins/private/translations/translations/fr-FR.json +++ b/x-pack/platform/plugins/private/translations/translations/fr-FR.json @@ -33954,7 +33954,6 @@ "xpack.searchPlayground.header.view.chat": "Chat", "xpack.searchPlayground.header.view.preview": "Aperçu", "xpack.searchPlayground.header.view.query": "Requête", - "xpack.searchPlayground.inferenceModel": "{name}", "xpack.searchPlayground.loadConnectorsError": "Erreur lors du chargement des connecteurs. Veuillez vérifier votre configuration et réessayer.", "xpack.searchPlayground.openAIAzureConnectorTitle": "OpenAI Azure", "xpack.searchPlayground.openAIAzureModel": "{name} (Azure OpenAI)", diff --git a/x-pack/platform/plugins/private/translations/translations/ja-JP.json b/x-pack/platform/plugins/private/translations/translations/ja-JP.json index 3e78b808d4a46..bbe0b4aae46ba 100644 --- a/x-pack/platform/plugins/private/translations/translations/ja-JP.json +++ b/x-pack/platform/plugins/private/translations/translations/ja-JP.json @@ -33991,7 +33991,6 @@ "xpack.searchPlayground.header.view.chat": "チャット", "xpack.searchPlayground.header.view.preview": "プレビュー", "xpack.searchPlayground.header.view.query": "クエリー", - "xpack.searchPlayground.inferenceModel": "{name}", "xpack.searchPlayground.loadConnectorsError": "コネクターの読み込みエラーです。構成を確認して、再試行してください。", "xpack.searchPlayground.openAIAzureConnectorTitle": "OpenAI Azure", "xpack.searchPlayground.openAIAzureModel": "{name} (Azure OpenAI)", diff --git a/x-pack/platform/plugins/private/translations/translations/zh-CN.json b/x-pack/platform/plugins/private/translations/translations/zh-CN.json index 36e387ad8fd6b..068da896b1ddd 100644 --- a/x-pack/platform/plugins/private/translations/translations/zh-CN.json +++ 
b/x-pack/platform/plugins/private/translations/translations/zh-CN.json @@ -33978,7 +33978,6 @@ "xpack.searchPlayground.header.view.chat": "聊天", "xpack.searchPlayground.header.view.preview": "预览", "xpack.searchPlayground.header.view.query": "查询", - "xpack.searchPlayground.inferenceModel": "{name}", "xpack.searchPlayground.loadConnectorsError": "加载连接器进出错。请检查您的配置,然后重试。", "xpack.searchPlayground.openAIAzureConnectorTitle": "OpenAI Azure", "xpack.searchPlayground.openAIAzureModel": "{name} (Azure OpenAI)", diff --git a/x-pack/solutions/search/plugins/search_playground/common/models.ts b/x-pack/solutions/search/plugins/search_playground/common/models.ts index a1fe7dbe928d8..38923768f87a3 100644 --- a/x-pack/solutions/search/plugins/search_playground/common/models.ts +++ b/x-pack/solutions/search/plugins/search_playground/common/models.ts @@ -5,6 +5,7 @@ * 2.0. */ +import { elasticModelIds } from '@kbn/inference-common'; import { ModelProvider, LLMs } from './types'; export const MODELS: ModelProvider[] = [ @@ -56,4 +57,10 @@ export const MODELS: ModelProvider[] = [ promptTokenLimit: 2097152, provider: LLMs.gemini, }, + { + name: 'Elastic Managed LLM', + model: elasticModelIds.RainbowSprinkles, + promptTokenLimit: 200000, + provider: LLMs.inference, + }, ]; diff --git a/x-pack/solutions/search/plugins/search_playground/public/hooks/use_llms_models.ts b/x-pack/solutions/search/plugins/search_playground/public/hooks/use_llms_models.ts index df3e77d2e3307..b4e04b94eaef3 100644 --- a/x-pack/solutions/search/plugins/search_playground/public/hooks/use_llms_models.ts +++ b/x-pack/solutions/search/plugins/search_playground/public/hooks/use_llms_models.ts @@ -29,7 +29,8 @@ const mapLlmToModels: Record< icon: string | ((connector: PlaygroundConnector) => string); getModels: ( connectorName: string, - includeName: boolean + includeName: boolean, + modelId?: string ) => Array<{ label: string; value?: string; promptTokenLimit?: number }>; } > = { @@ -88,12 +89,11 @@ const 
mapLlmToModels: Record< ? SERVICE_PROVIDERS[connector.config.provider].icon : ''; }, - getModels: (connectorName) => [ + getModels: (connectorName, _, modelId) => [ { - label: i18n.translate('xpack.searchPlayground.inferenceModel', { - defaultMessage: '{name}', - values: { name: connectorName }, - }), + label: connectorName, + value: modelId, + promptTokenLimit: MODELS.find((m) => m.model === modelId)?.promptTokenLimit, }, ], }, @@ -128,7 +128,13 @@ export const LLMsQuery = const showConnectorName = Number(mapConnectorTypeToCount?.[connectorType]) > 1; llmParams - .getModels(connector.name, false) + .getModels( + connector.name, + false, + isInferenceActionConnector(connector) + ? connector.config?.providerConfig?.model_id + : undefined + ) .map(({ label, value, promptTokenLimit }) => ({ id: connector?.id + label, name: label, diff --git a/x-pack/solutions/search/plugins/search_playground/public/providers/unsaved_form_provider.tsx b/x-pack/solutions/search/plugins/search_playground/public/providers/unsaved_form_provider.tsx index fc5eb0aaafe07..fce08733a5ad5 100644 --- a/x-pack/solutions/search/plugins/search_playground/public/providers/unsaved_form_provider.tsx +++ b/x-pack/solutions/search/plugins/search_playground/public/providers/unsaved_form_provider.tsx @@ -94,6 +94,7 @@ export const UnsavedFormProvider: React.FC { + if (models.length === 0) return; // don't continue if there are no models const defaultModel = models.find((model) => !model.disabled); const currentModel = form.getValues(PlaygroundFormFields.summarizationModel); diff --git a/x-pack/solutions/search/plugins/search_playground/public/types.ts b/x-pack/solutions/search/plugins/search_playground/public/types.ts index 020a28f4e623a..77a9fef5a6038 100644 --- a/x-pack/solutions/search/plugins/search_playground/public/types.ts +++ b/x-pack/solutions/search/plugins/search_playground/public/types.ts @@ -248,7 +248,13 @@ export interface LLMModel { export type { ActionConnector, 
UserConfiguredActionConnector }; export type InferenceActionConnector = ActionConnector & { - config: { provider: ServiceProviderKeys; inferenceId: string }; + config: { + providerConfig?: { + model_id?: string; + }; + provider: ServiceProviderKeys; + inferenceId: string; + }; }; export type PlaygroundConnector = ActionConnector & { title: string; type: LLMs }; diff --git a/x-pack/solutions/search/plugins/search_playground/server/lib/get_chat_params.test.ts b/x-pack/solutions/search/plugins/search_playground/server/lib/get_chat_params.test.ts index 49d8b5e579a95..c8f3a0c289f94 100644 --- a/x-pack/solutions/search/plugins/search_playground/server/lib/get_chat_params.test.ts +++ b/x-pack/solutions/search/plugins/search_playground/server/lib/get_chat_params.test.ts @@ -17,6 +17,7 @@ import { loggerMock, MockedLogger } from '@kbn/logging-mocks'; import { httpServerMock } from '@kbn/core/server/mocks'; import { PluginStartContract as ActionsPluginStartContract } from '@kbn/actions-plugin/server'; import { inferenceMock } from '@kbn/inference-plugin/server/mocks'; +import { elasticModelIds } from '@kbn/inference-common'; jest.mock('@kbn/langchain/server', () => { const original = jest.requireActual('@kbn/langchain/server'); @@ -236,4 +237,84 @@ describe('getChatParams', () => { }); expect(result.chatPrompt).toContain('How does it work?'); }); + + it('returns the correct params for the EIS connector', async () => { + const mockConnector = { + id: 'elastic-llm', + actionTypeId: INFERENCE_CONNECTOR_ID, + config: { + providerConfig: { + model_id: elasticModelIds.RainbowSprinkles, + }, + }, + }; + mockActionsClient.get.mockResolvedValue(mockConnector); + + const result = await getChatParams( + { + connectorId: 'elastic-llm', + prompt: 'How does it work?', + citations: false, + }, + { actions, request, logger, inference } + ); + + expect(result).toMatchObject({ + connector: mockConnector, + summarizationModel: elasticModelIds.RainbowSprinkles, + }); + + 
expect(Prompt).toHaveBeenCalledWith('How does it work?', { + citations: false, + context: true, + type: 'anthropic', + }); + expect(QuestionRewritePrompt).toHaveBeenCalledWith({ + type: 'anthropic', + }); + expect(inference.getChatModel).toHaveBeenCalledWith({ + request, + connectorId: 'elastic-llm', + chatModelOptions: expect.objectContaining({ + model: elasticModelIds.RainbowSprinkles, + maxRetries: 0, + }), + }); + }); + + it('returns provided model with EIS connector', async () => { + const mockConnector = { + id: 'elastic-llm', + actionTypeId: INFERENCE_CONNECTOR_ID, + config: { + providerConfig: { + model_id: elasticModelIds.RainbowSprinkles, + }, + }, + }; + mockActionsClient.get.mockResolvedValue(mockConnector); + + const result = await getChatParams( + { + connectorId: 'elastic-llm', + model: 'foo-bar', + prompt: 'How does it work?', + citations: false, + }, + { actions, request, logger, inference } + ); + + expect(result).toMatchObject({ + summarizationModel: 'foo-bar', + }); + + expect(inference.getChatModel).toHaveBeenCalledWith({ + request, + connectorId: 'elastic-llm', + chatModelOptions: expect.objectContaining({ + model: 'foo-bar', + maxRetries: 0, + }), + }); + }); }); diff --git a/x-pack/solutions/search/plugins/search_playground/server/lib/get_chat_params.ts b/x-pack/solutions/search/plugins/search_playground/server/lib/get_chat_params.ts index 07f8ee544382f..a4529d2535df6 100644 --- a/x-pack/solutions/search/plugins/search_playground/server/lib/get_chat_params.ts +++ b/x-pack/solutions/search/plugins/search_playground/server/lib/get_chat_params.ts @@ -15,7 +15,9 @@ import { BaseLanguageModel } from '@langchain/core/language_models/base'; import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; import { getDefaultArguments } from '@kbn/langchain/server'; import type { InferenceServerStart } from '@kbn/inference-plugin/server'; + import { Prompt, QuestionRewritePrompt } from '../../common/prompt'; +import { 
isEISConnector } from '../utils/eis'; export const getChatParams = async ( { @@ -39,32 +41,42 @@ export const getChatParams = async ( chatPrompt: string; questionRewritePrompt: string; connector: Connector; + summarizationModel?: string; }> => { + let summarizationModel = model; const actionsClient = await actions.getActionsClientWithRequest(request); const connector = await actionsClient.get({ id: connectorId }); let llmType: string; let modelType: 'openai' | 'anthropic' | 'gemini'; - switch (connector.actionTypeId) { - case INFERENCE_CONNECTOR_ID: - llmType = 'inference'; - modelType = 'openai'; - break; - case OPENAI_CONNECTOR_ID: - llmType = 'openai'; - modelType = 'openai'; - break; - case BEDROCK_CONNECTOR_ID: - llmType = 'bedrock'; - modelType = 'anthropic'; - break; - case GEMINI_CONNECTOR_ID: - llmType = 'gemini'; - modelType = 'gemini'; - break; - default: - throw new Error(`Invalid connector type: ${connector.actionTypeId}`); + if (isEISConnector(connector)) { + llmType = 'bedrock'; + modelType = 'anthropic'; + if (!summarizationModel && connector.config?.providerConfig?.model_id) { + summarizationModel = connector.config?.providerConfig?.model_id; + } + } else { + switch (connector.actionTypeId) { + case INFERENCE_CONNECTOR_ID: + llmType = 'inference'; + modelType = 'openai'; + break; + case OPENAI_CONNECTOR_ID: + llmType = 'openai'; + modelType = 'openai'; + break; + case BEDROCK_CONNECTOR_ID: + llmType = 'bedrock'; + modelType = 'anthropic'; + break; + case GEMINI_CONNECTOR_ID: + llmType = 'gemini'; + modelType = 'gemini'; + break; + default: + throw new Error(`Invalid connector type: ${connector.actionTypeId}`); + } } const chatPrompt = Prompt(prompt, { @@ -81,7 +93,7 @@ export const getChatParams = async ( request, connectorId, chatModelOptions: { - model: model || connector?.config?.defaultModel, + model: summarizationModel || connector?.config?.defaultModel, temperature: getDefaultArguments(llmType).temperature, // prevents the agent from retrying 
on failure // failure could be due to bad connector, we should deliver that result to the client asap @@ -90,5 +102,11 @@ export const getChatParams = async ( }, }); - return { chatModel, chatPrompt, questionRewritePrompt, connector }; + return { + chatModel, + chatPrompt, + questionRewritePrompt, + connector, + summarizationModel: summarizationModel || connector?.config?.defaultModel, + }; }; diff --git a/x-pack/solutions/search/plugins/search_playground/server/routes.ts b/x-pack/solutions/search/plugins/search_playground/server/routes.ts index 0e7adeee9ae9c..fb56094cb1494 100644 --- a/x-pack/solutions/search/plugins/search_playground/server/routes.ts +++ b/x-pack/solutions/search/plugins/search_playground/server/routes.ts @@ -125,15 +125,16 @@ export function defineRoutes(routeOptions: DefineRoutesOptions) { es_client: client.asCurrentUser, }); const { messages, data } = request.body; - const { chatModel, chatPrompt, questionRewritePrompt, connector } = await getChatParams( - { - connectorId: data.connector_id, - model: data.summarization_model, - citations: data.citations, - prompt: data.prompt, - }, - { actions, inference, logger, request } - ); + const { chatModel, chatPrompt, questionRewritePrompt, connector, summarizationModel } = + await getChatParams( + { + connectorId: data.connector_id, + model: data.summarization_model, + citations: data.citations, + prompt: data.prompt, + }, + { actions, inference, logger, request } + ); let sourceFields: ElasticsearchRetrieverContentField; @@ -144,7 +145,7 @@ export function defineRoutes(routeOptions: DefineRoutesOptions) { throw Error(e); } - const model = MODELS.find((m) => m.model === data.summarization_model); + const model = MODELS.find((m) => m.model === summarizationModel); const modelPromptLimit = model?.promptTokenLimit; const chain = ConversationalChain({ @@ -167,7 +168,7 @@ export function defineRoutes(routeOptions: DefineRoutesOptions) { connectorType: connector.actionTypeId + 
(connector.config?.apiProvider ? `-${connector.config.apiProvider}` : ''), - model: data.summarization_model ?? '', + model: summarizationModel ?? '', isCitationsEnabled: data.citations, }); diff --git a/x-pack/solutions/search/plugins/search_playground/server/utils/eis.ts b/x-pack/solutions/search/plugins/search_playground/server/utils/eis.ts new file mode 100644 index 0000000000000..e3579a07374b7 --- /dev/null +++ b/x-pack/solutions/search/plugins/search_playground/server/utils/eis.ts @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { INFERENCE_CONNECTOR_ID } from '@kbn/stack-connectors-plugin/common/inference/constants'; +import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; +import { elasticModelIds } from '@kbn/inference-common'; + +export const isEISConnector = (connector: Connector) => { + if (connector.actionTypeId !== INFERENCE_CONNECTOR_ID) return false; + const modelId = connector.config?.providerConfig?.model_id ?? undefined; + if (modelId === elasticModelIds.RainbowSprinkles) { + return true; + } + return false; +};