diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant_context/index.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant_context/index.tsx
index 9ac817e03973a..78b29f30ab8fa 100644
--- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant_context/index.tsx
+++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant_context/index.tsx
@@ -70,6 +70,7 @@ export interface AssistantProviderProps {
   children: React.ReactNode;
   getComments: GetAssistantMessages;
   http: HttpSetup;
+  inferenceEnabled?: boolean;
   baseConversations: Record<string, Conversation>;
   nameSpace?: string;
   navigateToApp: (appId: string, options?: NavigateToAppOptions | undefined) => Promise<void>;
@@ -102,6 +103,7 @@ export interface UseAssistantContext {
   currentUserAvatar?: UserAvatar;
   getComments: GetAssistantMessages;
   http: HttpSetup;
+  inferenceEnabled: boolean;
   knowledgeBase: KnowledgeBaseConfig;
   getLastConversationId: (conversationTitle?: string) => string;
   promptContexts: Record<string, PromptContext>;
@@ -144,6 +146,7 @@ export const AssistantProvider: React.FC<AssistantProviderProps> = ({
   children,
   getComments,
   http,
+  inferenceEnabled = false,
   baseConversations,
   navigateToApp,
   nameSpace = DEFAULT_ASSISTANT_NAMESPACE,
@@ -276,6 +279,7 @@ export const AssistantProvider: React.FC<AssistantProviderProps> = ({
       docLinks,
       getComments,
       http,
+      inferenceEnabled,
       knowledgeBase: {
         ...DEFAULT_KNOWLEDGE_BASE_SETTINGS,
         ...localStorageKnowledgeBase,
@@ -317,6 +321,7 @@ export const AssistantProvider: React.FC<AssistantProviderProps> = ({
       docLinks,
       getComments,
       http,
+      inferenceEnabled,
       localStorageKnowledgeBase,
       promptContexts,
       navigateToApp,
diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector/index.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector/index.tsx
index 3127ab7fe3911..29f0c7ef10b7a 100644
--- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector/index.tsx
+++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector/index.tsx
@@ -97,12 +97,10 @@ export const ConnectorSelector: React.FC<Props> = React.memo(
     const connectorOptions = useMemo(
       () =>
         (aiConnectors ?? []).map((connector) => {
-          const connectorTypeTitle =
-            getGenAiConfig(connector)?.apiProvider ??
-            getActionTypeTitle(actionTypeRegistry.get(connector.actionTypeId));
           const connectorDetails = connector.isPreconfigured
             ? i18n.PRECONFIGURED_CONNECTOR
-            : connectorTypeTitle;
+            : getGenAiConfig(connector)?.apiProvider ??
+              getActionTypeTitle(actionTypeRegistry.get(connector.actionTypeId));
           const attackDiscoveryStats =
             stats !== null
               ? stats.statsPerConnector.find((s) => s.connectorId === connector.id) ?? null
diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector_inline/action_type_selector_modal.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector_inline/action_type_selector_modal.tsx
index 090b5c01d125c..818d729b69f03 100644
--- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector_inline/action_type_selector_modal.tsx
+++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector_inline/action_type_selector_modal.tsx
@@ -29,7 +29,7 @@ interface Props {
   actionTypeSelectorInline: boolean;
 }
 const itemClassName = css`
-  inline-size: 220px;
+  inline-size: 150px;

   .euiKeyPadMenuItem__label {
     white-space: nowrap;
diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/helpers.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/helpers.tsx
index 99550f1cafe75..63f6b3867ba7d 100644
--- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/helpers.tsx
+++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/helpers.tsx
@@ -68,10 +68,11 @@ export const getConnectorTypeTitle = (
   if (!connector) {
     return null;
   }
-  const connectorTypeTitle =
-    getGenAiConfig(connector)?.apiProvider ??
-    getActionTypeTitle(actionTypeRegistry.get(connector.actionTypeId));
-  const actionType = connector.isPreconfigured ? PRECONFIGURED_CONNECTOR : connectorTypeTitle;
+
+  const actionType = connector.isPreconfigured
+    ? PRECONFIGURED_CONNECTOR
+    : getGenAiConfig(connector)?.apiProvider ??
+      getActionTypeTitle(actionTypeRegistry.get(connector.actionTypeId));

   return actionType;
 };
diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_action_types/index.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_action_types/index.tsx
index 81c8c2a4ea7e4..8b34db331666e 100644
--- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_action_types/index.tsx
+++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_action_types/index.tsx
@@ -41,18 +41,12 @@ export const useLoadActionTypes = ({
         featureId: GenerativeAIForSecurityConnectorFeatureId,
       });

-      const actionTypeKey = {
-        bedrock: '.bedrock',
-        openai: '.gen-ai',
-        gemini: '.gemini',
-      };
+      // TODO add .inference once all the providers support unified completion
+      const actionTypes = ['.bedrock', '.gen-ai', '.gemini'];

-      const sortedData = queryResult
-        .filter((p) =>
-          [actionTypeKey.bedrock, actionTypeKey.openai, actionTypeKey.gemini].includes(p.id)
-        )
+      return queryResult
+        .filter((p) => actionTypes.includes(p.id))
         .sort((a, b) => a.name.localeCompare(b.name));
-      return sortedData;
     },
     {
       retry: false,
diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_connectors/index.test.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_connectors/index.test.tsx
index 685d01c988e0d..ff6df23779646 100644
--- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_connectors/index.test.tsx
+++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_connectors/index.test.tsx
@@ -8,6 +8,8 @@ import { waitFor, renderHook } from '@testing-library/react';

 import { useLoadConnectors, Props } from '.';
 import { mockConnectors } from '../../mock/connectors';
+import { TestProviders } from '../../mock/test_providers/test_providers';
+import React, { ReactNode } from 'react';

 const mockConnectorsAndExtras = [
   ...mockConnectors,
@@ -45,17 +47,6 @@ const loadConnectorsResult = mockConnectors.map((c) => ({
   isSystemAction: false,
 }));

-jest.mock('@tanstack/react-query', () => ({
-  useQuery: jest.fn().mockImplementation(async (queryKey, fn, opts) => {
-    try {
-      const res = await fn();
-      return Promise.resolve(res);
-    } catch (e) {
-      opts.onError(e);
-    }
-  }),
-}));
-
 const http = {
   get: jest.fn().mockResolvedValue(connectorsApiResponse),
 };
@@ -63,24 +54,56 @@ const toasts = {
   addError: jest.fn(),
 };
 const defaultProps = { http, toasts } as unknown as Props;
+
+const createWrapper = (inferenceEnabled = false) => {
+  // eslint-disable-next-line react/display-name
+  return ({ children }: { children: ReactNode }) => (
+    <TestProviders providerContext={{ inferenceEnabled }}>{children}</TestProviders>
+  );
+};
+
 describe('useLoadConnectors', () => {
   beforeEach(() => {
     jest.clearAllMocks();
   });
   it('should call api to load action types', async () => {
-    renderHook(() => useLoadConnectors(defaultProps));
+    renderHook(() => useLoadConnectors(defaultProps), {
+      wrapper: TestProviders,
+    });
     await waitFor(() => {
       expect(defaultProps.http.get).toHaveBeenCalledWith('/api/actions/connectors');
       expect(toasts.addError).not.toHaveBeenCalled();
     });
   });

-  it('should return sorted action types, removing isMissingSecrets and wrong action type ids', async () => {
-    const { result } = renderHook(() => useLoadConnectors(defaultProps));
+  it('should return sorted action types, removing isMissingSecrets and wrong action type ids, excluding .inference results', async () => {
+    const { result } = renderHook(() => useLoadConnectors(defaultProps), {
+      wrapper: TestProviders,
+    });
+    await waitFor(() => {
+      expect(result.current.data).toStrictEqual(
+        loadConnectorsResult
+          .filter((c) => c.actionTypeId !== '.inference')
+          // @ts-ignore ts does not like config, but we define it in the mock data
+          .map((c) => ({ ...c, apiProvider: c.config.apiProvider }))
+      );
+    });
+  });
+
+  it('includes preconfigured .inference results when inferenceEnabled is true', async () => {
+    const { result } = renderHook(() => useLoadConnectors(defaultProps), {
+      wrapper: createWrapper(true),
+    });
     await waitFor(() => {
-      expect(result.current).resolves.toStrictEqual(
-        // @ts-ignore ts does not like config, but we define it in the mock data
-        loadConnectorsResult.map((c) => ({ ...c, apiProvider: c.config.apiProvider }))
+      expect(result.current.data).toStrictEqual(
+        mockConnectors
+          .filter(
+            (c) =>
+              c.actionTypeId !== '.inference' ||
+              (c.actionTypeId === '.inference' && c.isPreconfigured)
+          )
+          // @ts-ignore ts does not like config, but we define it in the mock data
+          .map((c) => ({ ...c, referencedByCount: 0, apiProvider: c?.config?.apiProvider }))
       );
     });
   });
@@ -88,7 +111,9 @@ describe('useLoadConnectors', () => {
     const mockHttp = {
       get: jest.fn().mockRejectedValue(new Error('this is an error')),
     } as unknown as Props['http'];
-    renderHook(() => useLoadConnectors({ ...defaultProps, http: mockHttp }));
+    renderHook(() => useLoadConnectors({ ...defaultProps, http: mockHttp }), {
+      wrapper: TestProviders,
+    });
     await waitFor(() => expect(toasts.addError).toHaveBeenCalled());
   });
 });
diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_connectors/index.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_connectors/index.tsx
index 293993e82fde6..b54537eb3439f 100644
--- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_connectors/index.tsx
+++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_connectors/index.tsx
@@ -13,6 +13,7 @@ import type { IHttpFetchError } from '@kbn/core-http-browser';
 import { HttpSetup } from '@kbn/core-http-browser';
 import { IToasts } from '@kbn/core-notifications-browser';
 import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants';
+import { useAssistantContext } from '../../assistant_context';
 import { AIConnector } from '../connector_selector';
 import * as i18n from '../translations';
@@ -27,16 +28,17 @@ export interface Props {
   toasts?: IToasts;
 }

-const actionTypeKey = {
-  bedrock: '.bedrock',
-  openai: '.gen-ai',
-  gemini: '.gemini',
-};
+const actionTypes = ['.bedrock', '.gen-ai', '.gemini'];

 export const useLoadConnectors = ({
   http,
   toasts,
 }: Props): UseQueryResult<AIConnector[], IHttpFetchError> => {
+  const { inferenceEnabled } = useAssistantContext();
+  // derive the allowed types locally so the module-level array is not mutated on every render
+  const allowedActionTypes = inferenceEnabled ? [...actionTypes, '.inference'] : actionTypes;
+
   return useQuery(
     QUERY_KEY,
     async () => {
@@ -45,9 +47,9 @@ export const useLoadConnectors = ({
         (acc: AIConnector[], connector) => [
           ...acc,
           ...(!connector.isMissingSecrets &&
-          [actionTypeKey.bedrock, actionTypeKey.openai, actionTypeKey.gemini].includes(
-            connector.actionTypeId
-          )
+          allowedActionTypes.includes(connector.actionTypeId) &&
+          // only include preconfigured .inference connectors
+          (connector.actionTypeId !== '.inference' || connector.isPreconfigured)
             ? [
                 {
                   ...connector,
diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/mock/connectors.ts b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/mock/connectors.ts
index 6f72a89205251..1735da8a29b7e 100644
--- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/mock/connectors.ts
+++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/mock/connectors.ts
@@ -71,4 +71,26 @@ export const mockConnectors: AIConnector[] = [
       apiProvider: 'OpenAI',
     },
   },
+  {
+    id: 'c29c28a0-20fe-11ee-9386-a1f4d42ec542',
+    name: 'Regular Inference Connector',
+    isMissingSecrets: false,
+    actionTypeId: '.inference',
+    secrets: {},
+    isPreconfigured: false,
+    isDeprecated: false,
+    isSystemAction: false,
+    config: {
+      apiProvider: 'OpenAI',
+    },
+  },
+  {
+    id: 'c29c28a0-20fe-11ee-9396-a1f4d42ec542',
+    name: 'Preconfigured Inference Connector',
+    isMissingSecrets: false,
+    actionTypeId: '.inference',
+    isPreconfigured: true,
+    isDeprecated: false,
+    isSystemAction: false,
+  },
 ];
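To make the gating above concrete, here is a minimal sketch of the predicate the hook applies to each connector; the `isIncluded` helper is hypothetical, written only to illustrate how the two new mocks are treated:

```ts
// Hypothetical helper mirroring the hook's filter (illustration only, not part of the change).
const isIncluded = (
  c: { actionTypeId: string; isPreconfigured: boolean; isMissingSecrets?: boolean },
  allowed: string[]
): boolean =>
  !c.isMissingSecrets &&
  allowed.includes(c.actionTypeId) &&
  // only preconfigured .inference connectors pass
  (c.actionTypeId !== '.inference' || c.isPreconfigured);

const allowed = ['.bedrock', '.gen-ai', '.gemini', '.inference'];
isIncluded({ actionTypeId: '.inference', isPreconfigured: false }, allowed); // false ("Regular" mock)
isIncluded({ actionTypeId: '.inference', isPreconfigured: true }, allowed); // true ("Preconfigured" mock)
isIncluded({ actionTypeId: '.gen-ai', isPreconfigured: false }, allowed); // true
```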
diff --git a/x-pack/platform/packages/shared/kbn-langchain/server/language_models/chat_openai.test.ts b/x-pack/platform/packages/shared/kbn-langchain/server/language_models/chat_openai.test.ts
index 6dda607f3d192..606394029ab66 100644
--- a/x-pack/platform/packages/shared/kbn-langchain/server/language_models/chat_openai.test.ts
+++ b/x-pack/platform/packages/shared/kbn-langchain/server/language_models/chat_openai.test.ts
@@ -79,105 +79,205 @@ describe('ActionsClientChatOpenAI', () => {
     });
   });

-  describe('completionWithRetry streaming: true', () => {
-    beforeEach(() => {
-      jest.clearAllMocks();
-      mockStreamExecute.mockImplementation(() => ({
-        data: {
-          consumerStream: asyncGenerator() as unknown as Stream,
-          tokenCountStream: asyncGenerator() as unknown as Stream,
-        },
-        status: 'ok',
-      }));
-    });
-    const defaultStreamingArgs: OpenAI.ChatCompletionCreateParamsStreaming = {
-      messages: [{ content: prompt, role: 'user' }],
-      stream: true,
-      model: 'gpt-4o',
-      n: 99,
-      stop: ['a stop sequence'],
-      functions: [jest.fn()],
-    };
-    it('returns the expected data', async () => {
-      actionsClient.execute.mockImplementation(mockStreamExecute);
-      const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
-        ...defaultArgs,
-        streaming: true,
-        actionsClient,
+  describe('OpenAI', () => {
+    describe('completionWithRetry streaming: true', () => {
+      beforeEach(() => {
+        jest.clearAllMocks();
+        mockStreamExecute.mockImplementation(() => ({
+          data: {
+            consumerStream: asyncGenerator() as unknown as Stream,
+            tokenCountStream: asyncGenerator() as unknown as Stream,
+          },
+          status: 'ok',
+        }));
       });
+      const defaultStreamingArgs: OpenAI.ChatCompletionCreateParamsStreaming = {
+        messages: [{ content: prompt, role: 'user' }],
+        stream: true,
+        model: 'gpt-4o',
+        n: 99,
+        stop: ['a stop sequence'],
+        tools: [{ function: jest.fn(), type: 'function' }],
+      };
+      it('returns the expected data', async () => {
+        actionsClient.execute.mockImplementation(mockStreamExecute);
+        const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
+          ...defaultArgs,
+          streaming: true,
+          actionsClient,
+        });

-      const result: AsyncIterable<OpenAI.ChatCompletionChunk> =
-        await actionsClientChatOpenAI.completionWithRetry(defaultStreamingArgs);
-      expect(mockStreamExecute).toHaveBeenCalledWith({
-        actionId: connectorId,
-        params: {
-          subActionParams: {
-            model: 'gpt-4o',
-            messages: [{ role: 'user', content: 'Do you know my name?' }],
-            signal,
-            timeout: 999999,
-            n: defaultStreamingArgs.n,
-            stop: defaultStreamingArgs.stop,
-            functions: defaultStreamingArgs.functions,
-            temperature: 0.2,
+        const result: AsyncIterable<OpenAI.ChatCompletionChunk> =
+          await actionsClientChatOpenAI.completionWithRetry(defaultStreamingArgs);
+        expect(mockStreamExecute).toHaveBeenCalledWith({
+          actionId: connectorId,
+          params: {
+            subActionParams: {
+              model: 'gpt-4o',
+              messages: [{ role: 'user', content: 'Do you know my name?' }],
+              signal,
+              timeout: 999999,
+              n: defaultStreamingArgs.n,
+              stop: defaultStreamingArgs.stop,
+              tools: defaultStreamingArgs.tools,
+              temperature: 0.2,
+            },
+            subAction: 'invokeAsyncIterator',
           },
-          subAction: 'invokeAsyncIterator',
-        },
-        signal,
+          signal,
+        });
+        expect(result).toEqual(asyncGenerator());
       });
-      expect(result).toEqual(asyncGenerator());
     });
-  });

-  describe('completionWithRetry streaming: false', () => {
-    const defaultNonStreamingArgs: OpenAI.ChatCompletionCreateParamsNonStreaming = {
-      messages: [{ content: prompt, role: 'user' }],
-      stream: false,
-      model: 'gpt-4o',
-    };
-    it('returns the expected data', async () => {
-      const actionsClientChatOpenAI = new ActionsClientChatOpenAI(defaultArgs);
+    describe('completionWithRetry streaming: false', () => {
+      const defaultNonStreamingArgs: OpenAI.ChatCompletionCreateParamsNonStreaming = {
+        messages: [{ content: prompt, role: 'user' }],
+        stream: false,
+        model: 'gpt-4o',
+      };
+      it('returns the expected data', async () => {
+        const actionsClientChatOpenAI = new ActionsClientChatOpenAI(defaultArgs);

-      const result: OpenAI.ChatCompletion = await actionsClientChatOpenAI.completionWithRetry(
-        defaultNonStreamingArgs
-      );
-      expect(mockExecute).toHaveBeenCalledWith({
-        actionId: connectorId,
-        params: {
-          subActionParams: {
-            body: '{"temperature":0.2,"model":"gpt-4o","messages":[{"role":"user","content":"Do you know my name?"}]}',
-            signal,
-            timeout: 999999,
+        const result: OpenAI.ChatCompletion = await actionsClientChatOpenAI.completionWithRetry(
+          defaultNonStreamingArgs
+        );
+        expect(mockExecute).toHaveBeenCalledWith({
+          actionId: connectorId,
+          params: {
+            subActionParams: {
+              body: '{"temperature":0.2,"model":"gpt-4o","messages":[{"role":"user","content":"Do you know my name?"}]}',
+              signal,
+              timeout: 999999,
+            },
+            subAction: 'run',
           },
-          subAction: 'run',
-        },
-        signal,
+          signal,
+        });
+        expect(result.choices[0].message.content).toEqual(mockActionResponse.message);
+      });
+
+      it('rejects with the expected error when the action result status is error', async () => {
+        const hasErrorStatus = jest.fn().mockImplementation(() => ({
+          message: 'action-result-message',
+          serviceMessage: 'action-result-service-message',
+          status: 'error', // <-- error status
+        }));
+        actionsClient.execute.mockRejectedValueOnce(hasErrorStatus);
+
+        const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
+          ...defaultArgs,
+          actionsClient,
+        });
+
+        expect(actionsClientChatOpenAI.completionWithRetry(defaultNonStreamingArgs))
+          .rejects.toThrowError(
+            'ActionsClientChatOpenAI: action result status is error: action-result-message - action-result-service-message'
+          )
+          .catch(() => {
+            /* ...handle/report the error (or just suppress it, if that's appropriate
+               [which it sometimes, though rarely, is])...
+            */
+          });
       });
-      expect(result.choices[0].message.content).toEqual(mockActionResponse.message);
     });
+  });
+
+  describe('Inference', () => {
+    describe('completionWithRetry streaming: true', () => {
+      beforeEach(() => {
+        jest.clearAllMocks();
+        mockStreamExecute.mockImplementation(() => ({
+          data: {
+            consumerStream: asyncGenerator() as unknown as Stream,
+            tokenCountStream: asyncGenerator() as unknown as Stream,
+          },
+          status: 'ok',
+        }));
+      });
+      const defaultStreamingArgs: OpenAI.ChatCompletionCreateParamsStreaming = {
+        messages: [{ content: prompt, role: 'user' }],
+        stream: true,
+        model: 'gpt-4o',
+        n: 99,
+        stop: ['a stop sequence'],
+        tools: [{ function: jest.fn(), type: 'function' }],
+      };
+      it('returns the expected data', async () => {
+        actionsClient.execute.mockImplementation(mockStreamExecute);
+        const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
+          ...defaultArgs,
+          llmType: 'inference',
+          streaming: true,
+          actionsClient,
+        });

-    it('rejects with the expected error when the action result status is error', async () => {
-      const hasErrorStatus = jest.fn().mockImplementation(() => ({
-        message: 'action-result-message',
-        serviceMessage: 'action-result-service-message',
-        status: 'error', // <-- error status
-      }));
-      actionsClient.execute.mockRejectedValueOnce(hasErrorStatus);
+        const result: AsyncIterable<OpenAI.ChatCompletionChunk> =
+          await actionsClientChatOpenAI.completionWithRetry(defaultStreamingArgs);
+        expect(mockStreamExecute).toHaveBeenCalledWith({
+          actionId: connectorId,
+          params: {
+            subAction: 'unified_completion_async_iterator',
+            subActionParams: {
+              body: {
+                model: 'gpt-4o',
+                messages: [{ role: 'user', content: 'Do you know my name?' }],

-      const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
-        ...defaultArgs,
-        actionsClient,
+                n: defaultStreamingArgs.n,
+                stop: defaultStreamingArgs.stop,
+                tools: defaultStreamingArgs.tools,
+                temperature: 0.2,
+              },
+              signal,
+            },
+          },
+          signal,
+        });
+        expect(result).toEqual(asyncGenerator());
       });
+    });

-      expect(actionsClientChatOpenAI.completionWithRetry(defaultNonStreamingArgs))
-        .rejects.toThrowError(
-          'ActionsClientChatOpenAI: action result status is error: action-result-message - action-result-service-message'
-        )
-        .catch(() => {
-          /* ...handle/report the error (or just suppress it, if that's appropriate
-             [which it sometimes, though rarely, is])...
-          */
+    describe('completionWithRetry streaming: false', () => {
+      const defaultNonStreamingArgs: OpenAI.ChatCompletionCreateParamsNonStreaming = {
+        messages: [{ content: prompt, role: 'user' }],
+        stream: false,
+        model: 'gpt-4o',
+        n: 99,
+        stop: ['a stop sequence'],
+        tools: [{ function: jest.fn(), type: 'function' }],
+      };
+      it('returns the expected data', async () => {
+        const actionsClientChatOpenAI = new ActionsClientChatOpenAI({
+          ...defaultArgs,
+          llmType: 'inference',
         });
+
+        const result: OpenAI.ChatCompletion = await actionsClientChatOpenAI.completionWithRetry(
+          defaultNonStreamingArgs
+        );
+
+        expect(JSON.stringify(mockExecute.mock.calls[0][0])).toEqual(
+          JSON.stringify({
+            actionId: connectorId,
+            params: {
+              subAction: 'unified_completion',
+              subActionParams: {
+                body: {
+                  temperature: 0.2,
+                  model: 'gpt-4o',
+                  n: 99,
+                  stop: ['a stop sequence'],
+                  tools: [{ function: jest.fn(), type: 'function' }],
+                  messages: [{ role: 'user', content: 'Do you know my name?' }],
+                },
+                signal,
+              },
+            },
+            signal,
+          })
+        );
+        expect(result.choices[0].message.content).toEqual(mockActionResponse.message);
+      });
     });
   });
 });
diff --git a/x-pack/platform/packages/shared/kbn-langchain/server/language_models/chat_openai.ts b/x-pack/platform/packages/shared/kbn-langchain/server/language_models/chat_openai.ts
index f679193c23f92..bbbc419143f31 100644
--- a/x-pack/platform/packages/shared/kbn-langchain/server/language_models/chat_openai.ts
+++ b/x-pack/platform/packages/shared/kbn-langchain/server/language_models/chat_openai.ts
@@ -15,7 +15,11 @@ import { Stream } from 'openai/streaming';
 import type OpenAI from 'openai';
 import { PublicMethodsOf } from '@kbn/utility-types';
 import { DEFAULT_OPEN_AI_MODEL, DEFAULT_TIMEOUT } from './constants';
-import { InvokeAIActionParamsSchema, RunActionParamsSchema } from './types';
+import {
+  InferenceChatCompleteParamsSchema,
+  InvokeAIActionParamsSchema,
+  RunActionParamsSchema,
+} from './types';

 const LLM_TYPE = 'ActionsClientChatOpenAI';

@@ -136,7 +140,7 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
       | OpenAI.ChatCompletionCreateParamsNonStreaming
   ): Promise<AsyncIterable<OpenAI.ChatCompletionChunk> | OpenAI.ChatCompletion> {
     return this.caller.call(async () => {
-      const requestBody = this.formatRequestForActionsClient(completionRequest);
+      const requestBody = this.formatRequestForActionsClient(completionRequest, this.llmType);
       this.#logger.debug(
         () =>
           `${LLM_TYPE}#completionWithRetry ${this.#traceId} assistantMessage:\n${JSON.stringify(
@@ -179,11 +183,15 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
   formatRequestForActionsClient(
     completionRequest:
       | OpenAI.ChatCompletionCreateParamsNonStreaming
-      | OpenAI.ChatCompletionCreateParamsStreaming
+      | OpenAI.ChatCompletionCreateParamsStreaming,
+    llmType: string
   ): {
     actionId: string;
     params: {
-      subActionParams: InvokeAIActionParamsSchema | RunActionParamsSchema;
+      subActionParams:
+        | InvokeAIActionParamsSchema
+        | RunActionParamsSchema
+        | InferenceChatCompleteParamsSchema;
       subAction: string;
     };
     signal?: AbortSignal;
@@ -194,33 +202,48 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
       // security sends this from connectors, it is only missing from preconfigured connectors
      // this should be undefined otherwise so the connector handles the model (stack_connector has access to preconfigured connector model values)
       model: this.model,
-      // ensure we take the messages from the completion request, not the client request
       n: completionRequest.n,
       stop: completionRequest.stop,
-      functions: completionRequest.functions,
+      tools: completionRequest.tools,
+      ...(completionRequest.tool_choice ? { tool_choice: completionRequest.tool_choice } : {}),
+      // deprecated, use tools
+      ...(completionRequest.functions ? { functions: completionRequest?.functions } : {}),
+      // ensure we take the messages from the completion request, not the client request
       messages: completionRequest.messages.map((message) => ({
         role: message.role,
         content: message.content ?? '',
         ...('name' in message ? { name: message?.name } : {}),
-        ...('function_call' in message ? { function_call: message?.function_call } : {}),
         ...('tool_calls' in message ? { tool_calls: message?.tool_calls } : {}),
         ...('tool_call_id' in message ? { tool_call_id: message?.tool_call_id } : {}),
+        // deprecated, use tool_calls
+        ...('function_call' in message ? { function_call: message?.function_call } : {}),
       })),
     };
+    const subAction =
+      llmType === 'inference'
+        ? completionRequest.stream
+          ? 'unified_completion_async_iterator'
+          : 'unified_completion'
+        : // langchain expects stream to be of type AsyncIterator
+        // for non-stream, use `run` instead of `invokeAI` in order to get the entire OpenAI.ChatCompletion response,
+        // which may contain non-content messages like functions
+        completionRequest.stream
+        ? 'invokeAsyncIterator'
+        : 'run';
     // create a new connector request body with the assistant message:
+    const subActionParams = {
+      ...(llmType === 'inference'
+        ? { body }
+        : completionRequest.stream
+        ? { ...body, timeout: this.#timeout ?? DEFAULT_TIMEOUT }
+        : { body: JSON.stringify(body), timeout: this.#timeout ?? DEFAULT_TIMEOUT }),
+      signal: this.#signal,
+    };
     return {
       actionId: this.#connectorId,
       params: {
-        // langchain expects stream to be of type AsyncIterator
-        // for non-stream, use `run` instead of `invokeAI` in order to get the entire OpenAI.ChatCompletion response,
-        // which may contain non-content messages like functions
-        subAction: completionRequest.stream ? 'invokeAsyncIterator' : 'run',
-        subActionParams: {
-          ...(completionRequest.stream ? body : { body: JSON.stringify(body) }),
-          signal: this.#signal,
-          // This timeout is large because LangChain prompts can be complicated and take a long time
-          timeout: this.#timeout ?? DEFAULT_TIMEOUT,
-        },
+        subAction,
+        subActionParams,
       },
       signal: this.#signal,
     };
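For reference, the branch above yields two distinct request shapes. A sketch with illustrative values follows; the connector id and timeout are placeholders, not values taken from this change:

```ts
// llmType === 'inference': the body is passed through as an object under `body`.
const inferenceRequest = {
  actionId: 'connector-id', // placeholder
  params: {
    subAction: 'unified_completion', // 'unified_completion_async_iterator' when streaming
    subActionParams: {
      body: { model: 'gpt-4o', messages: [{ role: 'user', content: 'Hello' }] },
      signal: undefined, // the AbortSignal, when one is set
    },
  },
};

// Any other llmType: non-streaming serializes the body and attaches the timeout.
const openAiRequest = {
  actionId: 'connector-id',
  params: {
    subAction: 'run', // 'invokeAsyncIterator' when streaming (body is spread, not stringified)
    subActionParams: {
      body: JSON.stringify({ model: 'gpt-4o', messages: [{ role: 'user', content: 'Hello' }] }),
      timeout: 180000, // placeholder for DEFAULT_TIMEOUT
      signal: undefined,
    },
  },
};
```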
diff --git a/x-pack/platform/packages/shared/kbn-langchain/server/language_models/llm.test.ts b/x-pack/platform/packages/shared/kbn-langchain/server/language_models/llm.test.ts
index aa33bbf7a6d44..0f37ee77d53b7 100644
--- a/x-pack/platform/packages/shared/kbn-langchain/server/language_models/llm.test.ts
+++ b/x-pack/platform/packages/shared/kbn-langchain/server/language_models/llm.test.ts
@@ -10,18 +10,13 @@ import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';

 import { ActionsClientLlm } from './llm';
 import { mockActionResponse } from './mocks';
+import { getDefaultArguments } from '..';
+import { DEFAULT_TIMEOUT } from './constants';

 const connectorId = 'mock-connector-id';

 const actionsClient = actionsClientMock.create();

-actionsClient.execute.mockImplementation(
-  jest.fn().mockImplementation(() => ({
-    data: mockActionResponse,
-    status: 'ok',
-  }))
-);
-
 const mockLogger = loggerMock.create();

 const prompt = 'Do you know my name?';
@@ -29,20 +24,12 @@ const prompt = 'Do you know my name?';
 describe('ActionsClientLlm', () => {
   beforeEach(() => {
     jest.clearAllMocks();
-  });
-
-  describe('getActionResultData', () => {
-    it('returns the expected data', async () => {
-      const actionsClientLlm = new ActionsClientLlm({
-        actionsClient,
-        connectorId,
-        logger: mockLogger,
-      });
-
-      const result = await actionsClientLlm._call(prompt); // ignore the result
-
-      expect(result).toEqual(mockActionResponse.message);
-    });
+    actionsClient.execute.mockImplementation(
+      jest.fn().mockImplementation(() => ({
+        data: mockActionResponse,
+        status: 'ok',
+      }))
+    );
   });

   describe('_llmType', () => {
@@ -69,6 +56,68 @@ describe('ActionsClientLlm', () => {
   });

   describe('_call', () => {
+    it('executes with the expected arguments when llmType is not inference', async () => {
+      const actionsClientLlm = new ActionsClientLlm({
+        actionsClient,
+        connectorId,
+        logger: mockLogger,
+      });
+      await actionsClientLlm._call(prompt);
+      expect(actionsClient.execute).toHaveBeenCalledWith({
+        actionId: 'mock-connector-id',
+        params: {
+          subAction: 'invokeAI',
+          subActionParams: {
+            messages: [
+              {
+                content: 'Do you know my name?',
+                role: 'user',
+              },
+            ],
+            ...getDefaultArguments(),
+            timeout: DEFAULT_TIMEOUT,
+          },
+        },
+      });
+    });
+    it('executes with the expected arguments when llmType is inference', async () => {
+      actionsClient.execute.mockImplementation(
+        jest.fn().mockImplementation(() => ({
+          data: {
+            choices: [
+              {
+                message: { content: mockActionResponse.message },
+              },
+            ],
+          },
+          status: 'ok',
+        }))
+      );
+      const actionsClientLlm = new ActionsClientLlm({
+        actionsClient,
+        connectorId,
+        logger: mockLogger,
+        llmType: 'inference',
+      });
+      const result = await actionsClientLlm._call(prompt);
+      expect(actionsClient.execute).toHaveBeenCalledWith({
+        actionId: 'mock-connector-id',
+        params: {
+          subAction: 'unified_completion',
+          subActionParams: {
+            body: {
+              messages: [
+                {
+                  content: 'Do you know my name?',
+                  role: 'user',
+                },
+              ],
+            },
+          },
+        },
+      });
+      expect(result).toEqual(mockActionResponse.message);
+    });
     it('returns the expected content when _call is invoked', async () => {
       const actionsClientLlm = new ActionsClientLlm({
         actionsClient,
@@ -77,8 +126,7 @@ describe('ActionsClientLlm', () => {
       });

       const result = await actionsClientLlm._call(prompt);
-
-      expect(result).toEqual('Yes, your name is Andrew. How can I assist you further, Andrew?');
+      expect(result).toEqual(mockActionResponse.message);
     });

     it('rejects with the expected error when the action result status is error', async () => {
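The inference test above stubs the `unified_completion` result in OpenAI chat-completion form. A minimal sketch of the shape that `_call` reads back in the implementation that follows (the message text is illustrative):

```ts
// Shape returned by the connector for subAction 'unified_completion' (sketch).
const actionResult = {
  status: 'ok',
  data: {
    choices: [{ message: { content: 'Yes, your name is Andrew.' } }], // illustrative content
  },
};
// ActionsClientLlm._call returns actionResult.data.choices[0].message.content
```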
diff --git a/x-pack/platform/packages/shared/kbn-langchain/server/language_models/llm.ts b/x-pack/platform/packages/shared/kbn-langchain/server/language_models/llm.ts
index 2a634ccb490cf..787c4e85b1358 100644
--- a/x-pack/platform/packages/shared/kbn-langchain/server/language_models/llm.ts
+++ b/x-pack/platform/packages/shared/kbn-langchain/server/language_models/llm.ts
@@ -89,24 +89,35 @@ export class ActionsClientLlm extends LLM {
         assistantMessage
       )}
       `
     );
+    // create a new connector request body with the assistant message:
     const requestBody = {
       actionId: this.#connectorId,
-      params: {
-        // hard code to non-streaming subaction as this class only supports non-streaming
-        subAction: 'invokeAI',
-        subActionParams: {
-          model: this.model,
-          messages: [assistantMessage], // the assistant message
-          ...getDefaultArguments(this.llmType, this.temperature),
-          // This timeout is large because LangChain prompts can be complicated and take a long time
-          timeout: this.#timeout ?? DEFAULT_TIMEOUT,
-        },
-      },
+      params:
+        this.llmType === 'inference'
+          ? {
+              subAction: 'unified_completion',
+              subActionParams: {
+                body: {
+                  model: this.model,
+                  messages: [assistantMessage], // the assistant message
+                },
+              },
+            }
+          : {
+              // hard code to non-streaming subaction as this class only supports non-streaming
+              subAction: 'invokeAI',
+              subActionParams: {
+                model: this.model,
+                messages: [assistantMessage], // the assistant message
+                ...getDefaultArguments(this.llmType, this.temperature),
+                // This timeout is large because LangChain prompts can be complicated and take a long time
+                timeout: this.#timeout ?? DEFAULT_TIMEOUT,
+              },
+            },
     };

     const actionResult = await this.#actionsClient.execute(requestBody);
-
     if (actionResult.status === 'error') {
       const error = new Error(
         `${LLM_TYPE}: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}`
@@ -117,6 +128,18 @@ export class ActionsClientLlm extends LLM {
       throw error;
     }

+    if (this.llmType === 'inference') {
+      const content = get('data.choices[0].message.content', actionResult);
+
+      if (typeof content !== 'string') {
+        throw new Error(
+          `${LLM_TYPE}: inference content should be a string, but it had an unexpected type: ${typeof content}`
+        );
+      }
+
+      return content; // per the contract of _call, return a string
+    }
+
     const content = get('data.message', actionResult);

     if (typeof content !== 'string') {
diff --git a/x-pack/platform/packages/shared/kbn-langchain/server/language_models/types.ts b/x-pack/platform/packages/shared/kbn-langchain/server/language_models/types.ts
index 35415e8eaf118..69d18d4f1b2a0 100644
--- a/x-pack/platform/packages/shared/kbn-langchain/server/language_models/types.ts
+++ b/x-pack/platform/packages/shared/kbn-langchain/server/language_models/types.ts
@@ -45,6 +45,9 @@ export interface RunActionParamsSchema {
   signal?: AbortSignal;
   timeout?: number;
 }
+export interface InferenceChatCompleteParamsSchema {
+  body: InvokeAIActionParamsSchema;
+}

 export interface TraceOptions {
   evaluationId?: string;
diff --git a/x-pack/platform/plugins/shared/stack_connectors/common/openai/schema.ts b/x-pack/platform/plugins/shared/stack_connectors/common/openai/schema.ts
index 8a08da157b163..7c3d4afcb8d1e 100644
--- a/x-pack/platform/plugins/shared/stack_connectors/common/openai/schema.ts
+++ b/x-pack/platform/plugins/shared/stack_connectors/common/openai/schema.ts
@@ -68,6 +68,41 @@ const AIMessage = schema.object({
 export const InvokeAIActionParamsSchema = schema.object({
   messages: schema.arrayOf(AIMessage),
   model: schema.maybe(schema.string()),
+  tools: schema.maybe(
+    schema.arrayOf(
+      schema.object(
+        {
+          type: schema.literal('function'),
+          function: schema.object(
+            {
+              description: schema.maybe(schema.string()),
+              name: schema.string(),
+              parameters: schema.object({}, { unknowns: 'allow' }),
+              strict: schema.maybe(schema.boolean()),
+            },
+            { unknowns: 'allow' }
+          ),
+        },
+        // Not sure if this will include other properties, we should pass them if it does
+        { unknowns: 'allow' }
+      )
+    )
+  ),
+  tool_choice: schema.maybe(
+    schema.oneOf([
+      schema.literal('none'),
+      schema.literal('auto'),
+      schema.literal('required'),
+      schema.object(
+        {
+          type: schema.literal('function'),
+          function: schema.object({ name: schema.string() }, { unknowns: 'allow' }),
+        },
+        { unknowns: 'ignore' }
+      ),
+    ])
+  ),
+  // Deprecated in favor of tools
   functions: schema.maybe(
     schema.arrayOf(
       schema.object(
@@ -89,6 +124,7 @@ export const InvokeAIActionParamsSchema = schema.object({
       )
     )
   ),
+  // Deprecated in favor of tool_choice
   function_call: schema.maybe(
     schema.oneOf([
       schema.literal('none'),
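As an example, a params object that the extended `InvokeAIActionParamsSchema` now accepts might look like the following; the weather tool is purely illustrative:

```ts
const params = {
  messages: [{ role: 'user', content: 'What is the weather in Paris?' }],
  tools: [
    {
      type: 'function' as const,
      function: {
        name: 'get_weather', // hypothetical tool name
        description: 'Fetch the current weather for a city',
        parameters: { type: 'object', properties: { city: { type: 'string' } } },
      },
    },
  ],
  tool_choice: 'auto' as const, // or { type: 'function', function: { name: 'get_weather' } }
};
```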
diff --git a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/inference/inference.ts b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/inference/inference.ts
index 63d8904a6af8a..5bb52a3160a45 100644
--- a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/inference/inference.ts
+++ b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/inference/inference.ts
@@ -49,7 +49,7 @@ import { chunksIntoMessage, eventSourceStreamIntoObservable } from './helpers';
 export class InferenceConnector extends SubActionConnector<Config, Secrets> {
   // Not using Axios
   protected getResponseErrorMessage(error: AxiosError): string {
-    throw new Error('Method not implemented.');
+    throw new Error(error.message || 'Method not implemented.');
   }

   private inferenceId;
@@ -128,11 +128,13 @@ export class InferenceConnector extends SubActionConnector<Config, Secrets> {
     const obs$ = from(eventSourceStreamIntoObservable(res as unknown as Readable)).pipe(
       filter((line) => !!line && line !== '[DONE]'),
       map((line) => {
-        return JSON.parse(line) as OpenAI.ChatCompletionChunk | { error: { message: string } };
+        return JSON.parse(line) as
+          | OpenAI.ChatCompletionChunk
+          | { error: { message?: string; reason?: string } };
       }),
       tap((line) => {
         if ('error' in line) {
-          throw new Error(line.error.message);
+          throw new Error(line.error.message || line.error.reason || 'Unknown error');
         }
         if (
           'choices' in line &&
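The widened error type above covers both OpenAI-style payloads, which carry `message`, and Elasticsearch-style payloads, which may carry only `reason`. For example (illustrative payloads, not taken from this change):

```ts
// Two error lines the stream handler can now surface (illustrative):
const openAiStyle = JSON.parse('{"error":{"message":"model overloaded"}}');
const esStyle = JSON.parse('{"error":{"reason":"inference endpoint not found"}}');
// In both cases: line.error.message || line.error.reason || 'Unknown error'
```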
diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.test.ts
new file mode 100644
index 0000000000000..1953dd4d45bf5
--- /dev/null
+++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.test.ts
@@ -0,0 +1,211 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';
+import { callAssistantGraph } from '.';
+import { getDefaultAssistantGraph } from './graph';
+import { invokeGraph, streamGraph } from './helpers';
+import { loggerMock } from '@kbn/logging-mocks';
+import { AgentExecutorParams, AssistantDataClients } from '../../executors/types';
+import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks';
+import { getFindAnonymizationFieldsResultWithSingleHit } from '../../../../__mocks__/response';
+import {
+  createOpenAIToolsAgent,
+  createStructuredChatAgent,
+  createToolCallingAgent,
+} from 'langchain/agents';
+jest.mock('./graph');
+jest.mock('./helpers');
+jest.mock('langchain/agents');
+jest.mock('@kbn/langchain/server/tracers/apm');
+jest.mock('@kbn/langchain/server/tracers/telemetry');
+const getDefaultAssistantGraphMock = getDefaultAssistantGraph as jest.Mock;
+describe('callAssistantGraph', () => {
+  const mockDataClients = {
+    anonymizationFieldsDataClient: {
+      findDocuments: jest.fn(),
+    },
+    kbDataClient: {
+      isInferenceEndpointExists: jest.fn(),
+      getAssistantTools: jest.fn(),
+    },
+  } as unknown as AssistantDataClients;
+
+  const mockRequest = {
+    body: {
+      model: 'test-model',
+    },
+  };
+
+  const defaultParams = {
+    actionsClient: actionsClientMock.create(),
+    alertsIndexPattern: 'test-pattern',
+    assistantTools: [],
+    connectorId: 'test-connector',
+    conversationId: 'test-conversation',
+    dataClients: mockDataClients,
+    esClient: elasticsearchClientMock.createScopedClusterClient().asCurrentUser,
+    inference: {},
+    langChainMessages: [{ content: 'test message' }],
+    llmTasks: { retrieveDocumentationAvailable: jest.fn(), retrieveDocumentation: jest.fn() },
+    llmType: 'openai',
+    isOssModel: false,
+    logger: loggerMock.create(),
+    isStream: false,
+    onLlmResponse: jest.fn(),
+    onNewReplacements: jest.fn(),
+    replacements: [],
+    request: mockRequest,
+    size: 1,
+    systemPrompt: 'test-prompt',
+    telemetry: {},
+    telemetryParams: {},
+    traceOptions: {},
+    responseLanguage: 'English',
+  } as unknown as AgentExecutorParams;
+
+  beforeEach(() => {
+    jest.clearAllMocks();
+    (mockDataClients?.kbDataClient?.isInferenceEndpointExists as jest.Mock).mockResolvedValue(
+      true
+    );
+    getDefaultAssistantGraphMock.mockReturnValue({});
+    (invokeGraph as jest.Mock).mockResolvedValue({
+      output: 'test-output',
+      traceData: {},
+      conversationId: 'new-conversation-id',
+    });
+    (streamGraph as jest.Mock).mockResolvedValue({});
+    (mockDataClients?.anonymizationFieldsDataClient?.findDocuments as jest.Mock).mockResolvedValue(
+      getFindAnonymizationFieldsResultWithSingleHit()
+    );
+  });
+
+  it('calls invokeGraph with correct parameters for non-streaming', async () => {
+    const result = await callAssistantGraph(defaultParams);
+
+    expect(invokeGraph).toHaveBeenCalledWith(
+      expect.objectContaining({
+        inputs: expect.objectContaining({
+          input: 'test message',
+        }),
+      })
+    );
+    expect(result.body).toEqual({
+      connector_id: 'test-connector',
+      data: 'test-output',
+      trace_data: {},
+      replacements: [],
+      status: 'ok',
+      conversationId: 'new-conversation-id',
+    });
+  });
+
+  it('calls streamGraph with correct parameters for streaming', async () => {
+    const params = { ...defaultParams, isStream: true };
+    await callAssistantGraph(params);
+
+    expect(streamGraph).toHaveBeenCalledWith(
+      expect.objectContaining({
+        inputs: expect.objectContaining({
+          input: 'test message',
+        }),
+      })
+    );
+  });
+
+  it('calls getDefaultAssistantGraph without signal for openai', async () => {
+    await callAssistantGraph(defaultParams);
+    expect(getDefaultAssistantGraphMock.mock.calls[0][0]).not.toHaveProperty('signal');
+  });
+
+  it('calls getDefaultAssistantGraph with signal for bedrock', async () => {
+    await callAssistantGraph({ ...defaultParams, llmType: 'bedrock' });
+    expect(getDefaultAssistantGraphMock.mock.calls[0][0]).toHaveProperty('signal');
+  });
+
+  it('handles error when anonymizationFieldsDataClient.findDocuments fails', async () => {
+    (mockDataClients?.anonymizationFieldsDataClient?.findDocuments as jest.Mock).mockRejectedValue(
+      new Error('test error')
+    );
+
+    await expect(callAssistantGraph(defaultParams)).rejects.toThrow('test error');
+  });
+
+  it('handles error when kbDataClient.isInferenceEndpointExists fails', async () => {
+    (mockDataClients?.kbDataClient?.isInferenceEndpointExists as jest.Mock).mockRejectedValue(
+      new Error('test error')
+    );
+
+    await expect(callAssistantGraph(defaultParams)).rejects.toThrow('test error');
+  });
+
+  it('returns correct response when no conversationId is returned', async () => {
+    (invokeGraph as jest.Mock).mockResolvedValue({ output: 'test-output', traceData: {} });
+
+    const result = await callAssistantGraph(defaultParams);
+
+    expect(result.body).toEqual({
+      connector_id: 'test-connector',
+      data: 'test-output',
+      trace_data: {},
+      replacements: [],
+      status: 'ok',
+    });
+  });
+
+  describe('agentRunnable', () => {
+    it('creates OpenAIToolsAgent for openai llmType', async () => {
+      const params = { ...defaultParams, llmType: 'openai' };
+      await callAssistantGraph(params);
+
+      expect(createOpenAIToolsAgent).toHaveBeenCalled();
+      expect(createStructuredChatAgent).not.toHaveBeenCalled();
+      expect(createToolCallingAgent).not.toHaveBeenCalled();
+    });
+
+    it('creates OpenAIToolsAgent for inference llmType', async () => {
+      const params = { ...defaultParams, llmType: 'inference' };
+      await callAssistantGraph(params);
+
+      expect(createOpenAIToolsAgent).toHaveBeenCalled();
+      expect(createStructuredChatAgent).not.toHaveBeenCalled();
+      expect(createToolCallingAgent).not.toHaveBeenCalled();
+    });
+
+    it('creates ToolCallingAgent for bedrock llmType', async () => {
+      const params = { ...defaultParams, llmType: 'bedrock' };
+      await callAssistantGraph(params);
+
+      expect(createToolCallingAgent).toHaveBeenCalled();
+      expect(createOpenAIToolsAgent).not.toHaveBeenCalled();
+      expect(createStructuredChatAgent).not.toHaveBeenCalled();
+    });
+
+    it('creates ToolCallingAgent for gemini llmType', async () => {
+      const params = {
+        ...defaultParams,
+        request: {
+          body: { model: 'gemini-1.5-flash' },
+        } as unknown as AgentExecutorParams['request'],
+        llmType: 'gemini',
+      };
+      await callAssistantGraph(params);
+
+      expect(createToolCallingAgent).toHaveBeenCalled();
+      expect(createOpenAIToolsAgent).not.toHaveBeenCalled();
+      expect(createStructuredChatAgent).not.toHaveBeenCalled();
+    });
+
+    it('creates StructuredChatAgent for oss model', async () => {
+      const params = { ...defaultParams, llmType: 'openai', isOssModel: true };
+      await callAssistantGraph(params);
+
+      expect(createStructuredChatAgent).toHaveBeenCalled();
+      expect(createOpenAIToolsAgent).not.toHaveBeenCalled();
+      expect(createToolCallingAgent).not.toHaveBeenCalled();
+    });
+  });
+});
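The tests above pin down the agent-factory selection that the change below implements. As a quick reference, a sketch of the mapping (names only, mirroring the diff; the helper itself is hypothetical):

```ts
// llmType -> agent factory, per the change below (sketch).
const agentFactoryFor = (llmType?: string, isOssModel = false): string =>
  isOssModel
    ? 'createStructuredChatAgent' // OSS models
    : llmType === 'openai' || llmType === 'inference'
    ? 'createOpenAIToolsAgent'
    : llmType === 'bedrock' || llmType === 'gemini'
    ? 'createToolCallingAgent'
    : 'createStructuredChatAgent';
```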
diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts
index cfcd0f49071b3..2e94e4bcd4ea0 100644
--- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts
+++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts
@@ -8,7 +8,7 @@ import { StructuredTool } from '@langchain/core/tools';
 import { getDefaultArguments } from '@kbn/langchain/server';
 import {
-  createOpenAIFunctionsAgent,
+  createOpenAIToolsAgent,
   createStructuredChatAgent,
   createToolCallingAgent,
 } from 'langchain/agents';
@@ -130,30 +130,31 @@ export const callAssistantGraph: AgentExecutor = async ({
     }
   }

-  const agentRunnable = isOpenAI
-    ? await createOpenAIFunctionsAgent({
-        llm: createLlmInstance(),
-        tools,
-        prompt: formatPrompt(systemPrompts.openai, systemPrompt),
-        streamRunnable: isStream,
-      })
-    : llmType && ['bedrock', 'gemini'].includes(llmType)
-    ? await createToolCallingAgent({
-        llm: createLlmInstance(),
-        tools,
-        prompt:
-          llmType === 'bedrock'
-            ? formatPrompt(systemPrompts.bedrock, systemPrompt)
-            : formatPrompt(systemPrompts.gemini, systemPrompt),
-        streamRunnable: isStream,
-      })
-    : // used with OSS models
-      await createStructuredChatAgent({
-        llm: createLlmInstance(),
-        tools,
-        prompt: formatPromptStructured(systemPrompts.structuredChat, systemPrompt),
-        streamRunnable: isStream,
-      });
+  const agentRunnable =
+    isOpenAI || llmType === 'inference'
+      ? await createOpenAIToolsAgent({
+          llm: createLlmInstance(),
+          tools,
+          prompt: formatPrompt(systemPrompts.openai, systemPrompt),
+          streamRunnable: isStream,
+        })
+      : llmType && ['bedrock', 'gemini'].includes(llmType)
+      ? await createToolCallingAgent({
+          llm: createLlmInstance(),
+          tools,
+          prompt:
+            llmType === 'bedrock'
+              ? formatPrompt(systemPrompts.bedrock, systemPrompt)
+              : formatPrompt(systemPrompts.gemini, systemPrompt),
+          streamRunnable: isStream,
+        })
+      : // used with OSS models
+        await createStructuredChatAgent({
+          llm: createLlmInstance(),
+          tools,
+          prompt: formatPromptStructured(systemPrompts.structuredChat, systemPrompt),
+          streamRunnable: isStream,
+        });

   const apmTracer = new APMTracer({ projectName: traceOptions?.projectName ?? 'default' }, logger);
   const telemetryTracer = telemetryParams
diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/utils.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/utils.ts
index 4cc213f0e0db8..cb38ea78e27bc 100644
--- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/utils.ts
+++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/utils.ts
@@ -177,6 +177,7 @@ export const getLlmType = (actionTypeId: string): string | undefined => {
     [`.gen-ai`]: `openai`,
     [`.bedrock`]: `bedrock`,
     [`.gemini`]: `gemini`,
+    [`.inference`]: `inference`,
   };
   return llmTypeDictionary[actionTypeId];
 };
diff --git a/x-pack/solutions/security/plugins/security_solution/public/assistant/provider.tsx b/x-pack/solutions/security/plugins/security_solution/public/assistant/provider.tsx
index f4161fccbc1c2..9b7ab890de1d3 100644
--- a/x-pack/solutions/security/plugins/security_solution/public/assistant/provider.tsx
+++ b/x-pack/solutions/security/plugins/security_solution/public/assistant/provider.tsx
@@ -144,6 +144,16 @@ export const AssistantProvider: FC<PropsWithChildren> = ({ children }) => {
     docLinks: { ELASTIC_WEBSITE_URL, DOC_LINK_VERSION },
     userProfile,
   } = useKibana().services;
+
+  let inferenceEnabled = false;
+  try {
+    actionTypeRegistry.get('.inference');
+    inferenceEnabled = true;
+  } catch (e) {
+    // swallow error
+    // inferenceEnabled will be false
+  }
+
   const basePath = useBasePath();
   const baseConversations = useBaseConversations();
@@ -222,6 +232,7 @@ export const AssistantProvider: FC<PropsWithChildren> = ({ children }) => {
       baseConversations={baseConversations}
       getComments={getComments}
       http={http}
+      inferenceEnabled={inferenceEnabled}
       navigateToApp={navigateToApp}
       title={ASSISTANT_TITLE}
       toasts={toasts}
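The probe above relies on `actionTypeRegistry.get` throwing for unregistered action types. If more call sites need the same check, it could be wrapped in a small helper; a sketch, where `isActionTypeRegistered` is hypothetical and not part of this change:

```ts
// Hypothetical helper wrapping the try/catch probe used above.
const isActionTypeRegistered = (
  registry: { get: (id: string) => unknown },
  actionTypeId: string
): boolean => {
  try {
    registry.get(actionTypeId); // throws for unregistered action types
    return true;
  } catch {
    return false;
  }
};

const inferenceEnabled = isActionTypeRegistered(actionTypeRegistry, '.inference');
```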