Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33954,7 +33954,6 @@
"xpack.searchPlayground.header.view.chat": "Chat",
"xpack.searchPlayground.header.view.preview": "Aperçu",
"xpack.searchPlayground.header.view.query": "Requête",
"xpack.searchPlayground.inferenceModel": "{name}",
"xpack.searchPlayground.loadConnectorsError": "Erreur lors du chargement des connecteurs. Veuillez vérifier votre configuration et réessayer.",
"xpack.searchPlayground.openAIAzureConnectorTitle": "OpenAI Azure",
"xpack.searchPlayground.openAIAzureModel": "{name} (Azure OpenAI)",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33991,7 +33991,6 @@
"xpack.searchPlayground.header.view.chat": "チャット",
"xpack.searchPlayground.header.view.preview": "プレビュー",
"xpack.searchPlayground.header.view.query": "クエリー",
"xpack.searchPlayground.inferenceModel": "{name}",
"xpack.searchPlayground.loadConnectorsError": "コネクターの読み込みエラーです。構成を確認して、再試行してください。",
"xpack.searchPlayground.openAIAzureConnectorTitle": "OpenAI Azure",
"xpack.searchPlayground.openAIAzureModel": "{name} (Azure OpenAI)",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33978,7 +33978,6 @@
"xpack.searchPlayground.header.view.chat": "聊天",
"xpack.searchPlayground.header.view.preview": "预览",
"xpack.searchPlayground.header.view.query": "查询",
"xpack.searchPlayground.inferenceModel": "{name}",
"xpack.searchPlayground.loadConnectorsError": "加载连接器进出错。请检查您的配置,然后重试。",
"xpack.searchPlayground.openAIAzureConnectorTitle": "OpenAI Azure",
"xpack.searchPlayground.openAIAzureModel": "{name} (Azure OpenAI)",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* 2.0.
*/

import { elasticModelIds } from '@kbn/inference-common';
import { ModelProvider, LLMs } from './types';

export const MODELS: ModelProvider[] = [
Expand Down Expand Up @@ -56,4 +57,10 @@ export const MODELS: ModelProvider[] = [
promptTokenLimit: 2097152,
provider: LLMs.gemini,
},
{
name: 'Elastic Managed LLM',
model: elasticModelIds.RainbowSprinkles,
promptTokenLimit: 200000,
provider: LLMs.inference,
},
];
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ const mapLlmToModels: Record<
icon: string | ((connector: PlaygroundConnector) => string);
getModels: (
connectorName: string,
includeName: boolean
includeName: boolean,
modelId?: string
) => Array<{ label: string; value?: string; promptTokenLimit?: number }>;
}
> = {
Expand Down Expand Up @@ -88,12 +89,11 @@ const mapLlmToModels: Record<
? SERVICE_PROVIDERS[connector.config.provider].icon
: '';
},
getModels: (connectorName) => [
getModels: (connectorName, _, modelId) => [
{
label: i18n.translate('xpack.searchPlayground.inferenceModel', {
defaultMessage: '{name}',
values: { name: connectorName },
}),
label: connectorName,
value: modelId,
promptTokenLimit: MODELS.find((m) => m.model === modelId)?.promptTokenLimit,
},
],
},
Expand Down Expand Up @@ -128,7 +128,13 @@ export const LLMsQuery =
const showConnectorName = Number(mapConnectorTypeToCount?.[connectorType]) > 1;

llmParams
.getModels(connector.name, false)
.getModels(
connector.name,
false,
isInferenceActionConnector(connector)
? connector.config?.providerConfig?.model_id
: undefined
)
.map(({ label, value, promptTokenLimit }) => ({
id: connector?.id + label,
name: label,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ export const UnsavedFormProvider: React.FC<React.PropsWithChildren<UnsavedFormPr
}, [form, storage, setLocalSessionDebounce]);

useEffect(() => {
if (models.length === 0) return; // don't continue if there are no models
const defaultModel = models.find((model) => !model.disabled);
const currentModel = form.getValues(PlaygroundFormFields.summarizationModel);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,13 @@ export interface LLMModel {

export type { ActionConnector, UserConfiguredActionConnector };
export type InferenceActionConnector = ActionConnector & {
config: { provider: ServiceProviderKeys; inferenceId: string };
config: {
providerConfig?: {
model_id?: string;
};
provider: ServiceProviderKeys;
inferenceId: string;
};
};
export type PlaygroundConnector = ActionConnector & { title: string; type: LLMs };

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import { loggerMock, MockedLogger } from '@kbn/logging-mocks';
import { httpServerMock } from '@kbn/core/server/mocks';
import { PluginStartContract as ActionsPluginStartContract } from '@kbn/actions-plugin/server';
import { inferenceMock } from '@kbn/inference-plugin/server/mocks';
import { elasticModelIds } from '@kbn/inference-common';

jest.mock('@kbn/langchain/server', () => {
const original = jest.requireActual('@kbn/langchain/server');
Expand Down Expand Up @@ -236,4 +237,84 @@ describe('getChatParams', () => {
});
expect(result.chatPrompt).toContain('How does it work?');
});

// Verifies that for an EIS (Elastic Inference Service) connector with no
// explicit model requested, getChatParams falls back to the model_id found in
// the connector's providerConfig and builds anthropic-style prompts
// (see the isEISConnector branch in getChatParams).
it('returns the correct params for the EIS connector', async () => {
// Inference connector whose providerConfig carries the Elastic managed model id.
const mockConnector = {
id: 'elastic-llm',
actionTypeId: INFERENCE_CONNECTOR_ID,
config: {
providerConfig: {
model_id: elasticModelIds.RainbowSprinkles,
},
},
};
mockActionsClient.get.mockResolvedValue(mockConnector);

// Note: no `model` is passed here, so the summarization model must be
// derived from the connector configuration.
const result = await getChatParams(
{
connectorId: 'elastic-llm',
prompt: 'How does it work?',
citations: false,
},
{ actions, request, logger, inference }
);

// The connector's configured model id is surfaced as summarizationModel.
expect(result).toMatchObject({
connector: mockConnector,
summarizationModel: elasticModelIds.RainbowSprinkles,
});

// EIS connectors are prompted with the anthropic prompt flavor.
expect(Prompt).toHaveBeenCalledWith('How does it work?', {
citations: false,
context: true,
type: 'anthropic',
});
expect(QuestionRewritePrompt).toHaveBeenCalledWith({
type: 'anthropic',
});
// The chat model is created with the connector-derived model and retries
// disabled so connector failures surface to the client immediately.
expect(inference.getChatModel).toHaveBeenCalledWith({
request,
connectorId: 'elastic-llm',
chatModelOptions: expect.objectContaining({
model: elasticModelIds.RainbowSprinkles,
maxRetries: 0,
}),
});
});

// Verifies that an explicitly requested model takes precedence over the
// model_id configured on the EIS connector: the caller-supplied `model`
// must win when both are present.
it('it returns provided model with EIS connector', async () => {
// Connector config still advertises the Elastic managed model id…
const mockConnector = {
id: 'elastic-llm',
actionTypeId: INFERENCE_CONNECTOR_ID,
config: {
providerConfig: {
model_id: elasticModelIds.RainbowSprinkles,
},
},
};
mockActionsClient.get.mockResolvedValue(mockConnector);

// …but the request explicitly asks for a different model ('foo-bar').
const result = await getChatParams(
{
connectorId: 'elastic-llm',
model: 'foo-bar',
prompt: 'How does it work?',
citations: false,
},
{ actions, request, logger, inference }
);

// The explicitly provided model is the one reported back…
expect(result).toMatchObject({
summarizationModel: 'foo-bar',
});

// …and the one the chat model is instantiated with.
expect(inference.getChatModel).toHaveBeenCalledWith({
request,
connectorId: 'elastic-llm',
chatModelOptions: expect.objectContaining({
model: 'foo-bar',
maxRetries: 0,
}),
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ import { BaseLanguageModel } from '@langchain/core/language_models/base';
import type { Connector } from '@kbn/actions-plugin/server/application/connector/types';
import { getDefaultArguments } from '@kbn/langchain/server';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';

import { Prompt, QuestionRewritePrompt } from '../../common/prompt';
import { isEISConnector } from '../utils/eis';

export const getChatParams = async (
{
Expand All @@ -39,32 +41,42 @@ export const getChatParams = async (
chatPrompt: string;
questionRewritePrompt: string;
connector: Connector;
summarizationModel?: string;
}> => {
let summarizationModel = model;
const actionsClient = await actions.getActionsClientWithRequest(request);
const connector = await actionsClient.get({ id: connectorId });

let llmType: string;
let modelType: 'openai' | 'anthropic' | 'gemini';

switch (connector.actionTypeId) {
case INFERENCE_CONNECTOR_ID:
llmType = 'inference';
modelType = 'openai';
break;
case OPENAI_CONNECTOR_ID:
llmType = 'openai';
modelType = 'openai';
break;
case BEDROCK_CONNECTOR_ID:
llmType = 'bedrock';
modelType = 'anthropic';
break;
case GEMINI_CONNECTOR_ID:
llmType = 'gemini';
modelType = 'gemini';
break;
default:
throw new Error(`Invalid connector type: ${connector.actionTypeId}`);
if (isEISConnector(connector)) {
llmType = 'bedrock';
modelType = 'anthropic';
if (!summarizationModel && connector.config?.providerConfig?.model_id) {
summarizationModel = connector.config?.providerConfig?.model_id;
}
} else {
switch (connector.actionTypeId) {
case INFERENCE_CONNECTOR_ID:
llmType = 'inference';
modelType = 'openai';
break;
case OPENAI_CONNECTOR_ID:
llmType = 'openai';
modelType = 'openai';
break;
case BEDROCK_CONNECTOR_ID:
llmType = 'bedrock';
modelType = 'anthropic';
break;
case GEMINI_CONNECTOR_ID:
llmType = 'gemini';
modelType = 'gemini';
break;
default:
throw new Error(`Invalid connector type: ${connector.actionTypeId}`);
}
}

const chatPrompt = Prompt(prompt, {
Expand All @@ -81,7 +93,7 @@ export const getChatParams = async (
request,
connectorId,
chatModelOptions: {
model: model || connector?.config?.defaultModel,
model: summarizationModel || connector?.config?.defaultModel,
temperature: getDefaultArguments(llmType).temperature,
// prevents the agent from retrying on failure
// failure could be due to bad connector, we should deliver that result to the client asap
Expand All @@ -90,5 +102,11 @@ export const getChatParams = async (
},
});

return { chatModel, chatPrompt, questionRewritePrompt, connector };
return {
chatModel,
chatPrompt,
questionRewritePrompt,
connector,
summarizationModel: summarizationModel || connector?.config?.defaultModel,
};
};
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,16 @@ export function defineRoutes(routeOptions: DefineRoutesOptions) {
es_client: client.asCurrentUser,
});
const { messages, data } = request.body;
const { chatModel, chatPrompt, questionRewritePrompt, connector } = await getChatParams(
{
connectorId: data.connector_id,
model: data.summarization_model,
citations: data.citations,
prompt: data.prompt,
},
{ actions, inference, logger, request }
);
const { chatModel, chatPrompt, questionRewritePrompt, connector, summarizationModel } =
await getChatParams(
{
connectorId: data.connector_id,
model: data.summarization_model,
citations: data.citations,
prompt: data.prompt,
},
{ actions, inference, logger, request }
);

let sourceFields: ElasticsearchRetrieverContentField;

Expand All @@ -144,7 +145,7 @@ export function defineRoutes(routeOptions: DefineRoutesOptions) {
throw Error(e);
}

const model = MODELS.find((m) => m.model === data.summarization_model);
const model = MODELS.find((m) => m.model === summarizationModel);
const modelPromptLimit = model?.promptTokenLimit;

const chain = ConversationalChain({
Expand All @@ -167,7 +168,7 @@ export function defineRoutes(routeOptions: DefineRoutesOptions) {
connectorType:
connector.actionTypeId +
(connector.config?.apiProvider ? `-${connector.config.apiProvider}` : ''),
model: data.summarization_model ?? '',
model: summarizationModel ?? '',
isCitationsEnabled: data.citations,
});

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { INFERENCE_CONNECTOR_ID } from '@kbn/stack-connectors-plugin/common/inference/constants';
import type { Connector } from '@kbn/actions-plugin/server/application/connector/types';
import { elasticModelIds } from '@kbn/inference-common';

/**
 * Checks whether a connector is the Elastic Inference Service (EIS) connector:
 * an inference-type connector whose provider config points at the Elastic
 * managed "Rainbow Sprinkles" model.
 */
export const isEISConnector = (connector: Connector) => {
  if (connector.actionTypeId !== INFERENCE_CONNECTOR_ID) {
    return false;
  }
  // Optional chaining covers connectors without a providerConfig; a missing
  // model_id simply fails the equality check below.
  return connector.config?.providerConfig?.model_id === elasticModelIds.RainbowSprinkles;
};