diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index c6efcf64f8b24..842b5f8325c8a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1001,6 +1001,7 @@ x-pack/platform/packages/shared/kbn-failure-store-modal @elastic/kibana-manageme x-pack/platform/packages/shared/kbn-fs @elastic/kibana-security x-pack/platform/packages/shared/kbn-grok-heuristics @elastic/obs-onboarding-team x-pack/platform/packages/shared/kbn-inference-cli @elastic/appex-ai-infra @elastic/search-inference-team +x-pack/platform/packages/shared/kbn-inference-connectors @elastic/search-kibana x-pack/platform/packages/shared/kbn-inference-endpoint-ui-common @elastic/search-kibana x-pack/platform/packages/shared/kbn-inference-prompt-utils @elastic/appex-ai-infra x-pack/platform/packages/shared/kbn-inference-tracing @elastic/appex-ai-infra diff --git a/package.json b/package.json index 174a671e73a8b..c089430f1d424 100644 --- a/package.json +++ b/package.json @@ -716,6 +716,7 @@ "@kbn/index-patterns-test-plugin": "link:src/platform/test/plugin_functional/plugins/index_patterns", "@kbn/indices-metadata-plugin": "link:x-pack/platform/plugins/private/indices_metadata", "@kbn/inference-common": "link:x-pack/platform/packages/shared/ai-infra/inference-common", + "@kbn/inference-connectors": "link:x-pack/platform/packages/shared/kbn-inference-connectors", "@kbn/inference-endpoint-plugin": "link:x-pack/platform/plugins/shared/inference_endpoint", "@kbn/inference-endpoint-ui-common": "link:x-pack/platform/packages/shared/kbn-inference-endpoint-ui-common", "@kbn/inference-langchain": "link:x-pack/platform/packages/shared/ai-infra/inference-langchain", diff --git a/src/platform/plugins/shared/workflows_extensions/server/steps/ai/utils/resolve_connector_id.test.ts b/src/platform/plugins/shared/workflows_extensions/server/steps/ai/utils/resolve_connector_id.test.ts index e8ed095a88075..f7018cd28e8fb 100644 --- a/src/platform/plugins/shared/workflows_extensions/server/steps/ai/utils/resolve_connector_id.test.ts +++ b/src/platform/plugins/shared/workflows_extensions/server/steps/ai/utils/resolve_connector_id.test.ts @@ -26,6 +26,7 @@ describe('resolveConnectorId', () => { config: {}, capabilities: {}, isInferenceEndpoint: false, + isPreconfigured: false, ...partial, }); diff --git a/tsconfig.base.json b/tsconfig.base.json index f188947687ddc..607d2aef297b3 100644 --- a/tsconfig.base.json +++ b/tsconfig.base.json @@ -1366,6 +1366,8 @@ "@kbn/inference-cli/*": ["x-pack/platform/packages/shared/kbn-inference-cli/*"], "@kbn/inference-common": ["x-pack/platform/packages/shared/ai-infra/inference-common"], "@kbn/inference-common/*": ["x-pack/platform/packages/shared/ai-infra/inference-common/*"], + "@kbn/inference-connectors": ["x-pack/platform/packages/shared/kbn-inference-connectors"], + "@kbn/inference-connectors/*": ["x-pack/platform/packages/shared/kbn-inference-connectors/*"], "@kbn/inference-endpoint-plugin": ["x-pack/platform/plugins/shared/inference_endpoint"], "@kbn/inference-endpoint-plugin/*": ["x-pack/platform/plugins/shared/inference_endpoint/*"], "@kbn/inference-endpoint-ui-common": ["x-pack/platform/packages/shared/kbn-inference-endpoint-ui-common"], diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/index.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/index.ts index 6461d8e2aa20b..15de78a0124d8 100644 --- a/x-pack/platform/packages/shared/ai-infra/inference-common/index.ts +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/index.ts @@ -150,6 +150,8 @@ export { contextWindowFromModelName, type InferenceConnector, type InferenceConnectorCapabilities, + type RawConnector, + type RawInferenceConnector, } from './src/connectors'; export { defaultInferenceEndpoints, diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/connector_capabilities.test.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/connector_capabilities.test.ts index 306abcc5e053a..fb94556788f8a 100644 --- a/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/connector_capabilities.test.ts +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/connector_capabilities.test.ts @@ -19,6 +19,7 @@ const createConnector = (parts: Partial): InferenceConnector config: {}, capabilities: {}, isInferenceEndpoint: false, + isPreconfigured: false, ...parts, }; }; diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/connector_config.test.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/connector_config.test.ts index a82f2cc57a7db..536a5cc5c3bc5 100644 --- a/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/connector_config.test.ts +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/connector_config.test.ts @@ -16,6 +16,7 @@ const createConnector = (parts: Partial): InferenceConnector connectorId: 'connectorId', config: {}, isInferenceEndpoint: false, + isPreconfigured: false, capabilities: {}, ...parts, }; diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/connector_to_inference.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/connector_to_inference.ts index e71407f4554ff..3a7cfc260827c 100644 --- a/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/connector_to_inference.ts +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/connector_to_inference.ts @@ -30,6 +30,7 @@ export const connectorToInference = (connector: RawConnector): InferenceConnecto config: connector.config ?? {}, capabilities: {}, isInferenceEndpoint: false, + isPreconfigured: connector.isPreconfigured ?? false, }; inferenceConnector.capabilities.contextWindowSize = getContextWindowSize(inferenceConnector); diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/connectors.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/connectors.ts index b0b97c749de33..b5c4280fec4f3 100644 --- a/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/connectors.ts +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/connectors.ts @@ -41,6 +41,12 @@ export interface InferenceConnector { * rather than a Kibana stack connector. `connectorId` holds the inference endpoint ID. */ isInferenceEndpoint: boolean; + /** + * When true, this connector is preconfigured (i.e. managed by Elastic). + * For native inference endpoints this is determined by the presence of + * `metadata.display.name` on the underlying ES inference endpoint. + */ + isPreconfigured: boolean; } export interface InferenceConnectorCapabilities { @@ -60,6 +66,7 @@ export interface RawConnector { actionTypeId: string; name: string; config?: Record; + isPreconfigured?: boolean; } export interface RawInferenceConnector { @@ -67,4 +74,5 @@ export interface RawInferenceConnector { actionTypeId: InferenceConnectorType; name: string; config?: Record; + isPreconfigured?: boolean; } diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/index.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/index.ts index db7252f644444..f4dd143976680 100644 --- a/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/index.ts +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/src/connectors/index.ts @@ -18,6 +18,8 @@ export { InferenceConnectorType, type InferenceConnector, type InferenceConnectorCapabilities, + type RawConnector, + type RawInferenceConnector, } from './connectors'; export { getModelDefinition } from './known_models'; export { getContextWindowSize, contextWindowFromModelName } from './connector_capabilities'; diff --git a/x-pack/platform/packages/shared/ai-infra/inference-langchain/src/chat_model/inference_chat_model.test.ts b/x-pack/platform/packages/shared/ai-infra/inference-langchain/src/chat_model/inference_chat_model.test.ts index 3c2c826614b57..7a1298bf06830 100644 --- a/x-pack/platform/packages/shared/ai-infra/inference-langchain/src/chat_model/inference_chat_model.test.ts +++ b/x-pack/platform/packages/shared/ai-infra/inference-langchain/src/chat_model/inference_chat_model.test.ts @@ -40,6 +40,7 @@ const createConnector = (parts: Partial = {}): InferenceConn config: {}, capabilities: {}, isInferenceEndpoint: false, + isPreconfigured: false, ...parts, }; }; diff --git a/x-pack/platform/packages/shared/kbn-ai-assistant/src/chat/chat_actions_menu.tsx b/x-pack/platform/packages/shared/kbn-ai-assistant/src/chat/chat_actions_menu.tsx index f1298ac472b1b..e519f065c30b1 100644 --- a/x-pack/platform/packages/shared/kbn-ai-assistant/src/chat/chat_actions_menu.tsx +++ b/x-pack/platform/packages/shared/kbn-ai-assistant/src/chat/chat_actions_menu.tsx @@ -77,8 +77,8 @@ export function ChatActionsMenu({ if (!connectors.connectors) return [[], []]; return connectors.connectors.reduce( - ([pre, custom], { id, name, isPreconfigured }) => { - const item = { value: id, label: name }; + ([pre, custom], { connectorId, name, isPreconfigured }) => { + const item = { value: connectorId, label: name }; return isPreconfigured ? [[...pre, item], custom] : [pre, [...custom, item]]; }, [[], []] @@ -128,7 +128,11 @@ export function ChatActionsMenu({ defaultMessage: 'Connector', })}{' '} - {connectors.connectors?.find(({ id }) => id === connectors.selectedConnector)?.name} + { + connectors.connectors?.find( + ({ connectorId }) => connectorId === connectors.selectedConnector + )?.name + } ), diff --git a/x-pack/platform/packages/shared/kbn-ai-assistant/src/chat/chat_body.tsx b/x-pack/platform/packages/shared/kbn-ai-assistant/src/chat/chat_body.tsx index 10b4c8a353d41..e8561a4b84128 100644 --- a/x-pack/platform/packages/shared/kbn-ai-assistant/src/chat/chat_body.tsx +++ b/x-pack/platform/packages/shared/kbn-ai-assistant/src/chat/chat_body.tsx @@ -34,7 +34,6 @@ import { type ChatActionClickPayload, type Feedback, aiAssistantSimulatedFunctionCalling, - getElasticManagedLlmConnector, InferenceModelState, } from '@kbn/observability-ai-assistant-plugin/public'; import type { AuthenticatedUser } from '@kbn/security-plugin/common'; @@ -413,13 +412,13 @@ export function ChatBody({ conversation.refresh(); }; - const elasticManagedLlm = getElasticManagedLlmConnector(connectors.connectors); const { conversationCalloutDismissed } = useElasticLlmCalloutsStatus(false); const showElasticLlmCalloutInChat = - !!elasticManagedLlm && - connectors.selectedConnector === elasticManagedLlm.id && - !conversationCalloutDismissed; + (connectors.connectors || []).some( + (connector) => + connector.connectorId === connectors.selectedConnector && connector.isPreconfigured + ) && !conversationCalloutDismissed; const showKnowledgeBaseReIndexingCallout = knowledgeBase.status.value?.enabled === true && diff --git a/x-pack/platform/packages/shared/kbn-ai-assistant/src/chat/elastic_llm_conversation_callout.tsx b/x-pack/platform/packages/shared/kbn-ai-assistant/src/chat/elastic_llm_conversation_callout.tsx index 3812af2afdad6..9a596672d8821 100644 --- a/x-pack/platform/packages/shared/kbn-ai-assistant/src/chat/elastic_llm_conversation_callout.tsx +++ b/x-pack/platform/packages/shared/kbn-ai-assistant/src/chat/elastic_llm_conversation_callout.tsx @@ -46,7 +46,7 @@ export const ElasticLlmConversationCallout = () => { onDismiss={onDismiss} iconType="info" title={i18n.translate('xpack.aiAssistant.elasticLlmCallout.title', { - defaultMessage: `You're using the Elastic Managed LLM connector`, + defaultMessage: `You're using an Elastic Managed LLM connector`, })} size="s" className={elasticLlmCalloutClassName} @@ -54,7 +54,7 @@ export const ElasticLlmConversationCallout = () => {

( ({ useAssistantContext: jest.fn(), +})); + +jest.mock('@kbn/inference-connectors', () => ({ useLoadConnectors: jest.fn(), })); diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/api/chat_complete/use_chat_complete.ts b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/api/chat_complete/use_chat_complete.ts index 2bba137a37cb3..b66c4ccb0d732 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/api/chat_complete/use_chat_complete.ts +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/api/chat_complete/use_chat_complete.ts @@ -8,9 +8,10 @@ import { useCallback, useMemo, useRef, useState } from 'react'; import type { PromptIds, Replacements } from '@kbn/elastic-assistant-common'; import type { HttpFetchQuery } from '@kbn/core-http-browser'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import type { ChatCompleteResponse } from './post_chat_complete'; import { postChatComplete } from './post_chat_complete'; -import { useAssistantContext, useLoadConnectors } from '../../../..'; +import { useAssistantContext } from '../../../..'; interface SendMessageProps { message: string; @@ -30,7 +31,11 @@ export const useChatComplete = ({ connectorId }: { connectorId: string }): UseCh const { alertsIndexPattern, http, traceOptions, settings } = useAssistantContext(); const [isLoading, setIsLoading] = useState(false); const abortController = useRef(new AbortController()); - const { data: connectors } = useLoadConnectors({ http, inferenceEnabled: true, settings }); + const { data: connectors } = useLoadConnectors({ + http, + featureId: 'elastic_assistant', + settings, + }); const actionTypeId = useMemo( () => connectors?.find(({ id }) => id === connectorId)?.actionTypeId ?? '.gen-ai', [connectors, connectorId] diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/assistant_header/index.test.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/assistant_header/index.test.tsx index a58c50a3e3876..666236d5a97c7 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/assistant_header/index.test.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/assistant_header/index.test.tsx @@ -11,7 +11,7 @@ import { act, fireEvent, render, screen, within } from '@testing-library/react'; import { AssistantHeader } from '.'; import { TestProviders } from '../../mock/test_providers/test_providers'; import { alertConvo, welcomeConvo } from '../../mock/conversation'; -import { useLoadConnectors } from '../../connectorland/use_load_connectors'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { mockConnectors } from '../../mock/connectors'; import { CLOSE } from './translations'; import { ConversationSharedState } from '@kbn/elastic-assistant-common'; @@ -55,7 +55,7 @@ const testProps = { setCurrentConversation: jest.fn(), }; -jest.mock('../../connectorland/use_load_connectors', () => ({ +jest.mock('@kbn/inference-connectors', () => ({ useLoadConnectors: jest.fn(() => { return { data: [], diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/conversations/conversation_settings/conversation_settings_editor.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/conversations/conversation_settings/conversation_settings_editor.tsx index b02455d792dfa..90167d4d159a2 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/conversations/conversation_settings/conversation_settings_editor.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/conversations/conversation_settings/conversation_settings_editor.tsx @@ -16,6 +16,7 @@ import { getCurrentConversationOwner, ConversationSharedState, } from '@kbn/elastic-assistant-common'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { ShareSelect } from '../../share_conversation/share_select'; import { useAssistantContext, type Conversation } from '../../../..'; import * as i18n from './translations'; @@ -25,7 +26,6 @@ import type { AIConnector } from '../../../connectorland/connector_selector'; import { ConnectorSelector } from '../../../connectorland/connector_selector'; import { SelectSystemPrompt } from '../../prompt_editor/system_prompt/select_system_prompt'; import { ModelSelector } from '../../../connectorland/models/model_selector/model_selector'; -import { useLoadConnectors } from '../../../connectorland/use_load_connectors'; import { getGenAiConfig } from '../../../connectorland/helpers'; import type { ConversationsBulkActions } from '../../api'; import { getDefaultSystemPrompt } from '../../use_conversation/helpers'; @@ -60,6 +60,7 @@ export const ConversationSettingsEditor: React.FC = ({ // Connector details const { data: connectors, isFetchedAfterMount: isFetchedConnectors } = useLoadConnectors({ http, + featureId: 'elastic_assistant', settings, }); const defaultConnector = useMemo( diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/assistant_settings.test.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/assistant_settings.test.tsx index b423c117185ff..90f037dbe74b1 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/assistant_settings.test.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/assistant_settings.test.tsx @@ -76,7 +76,7 @@ const testProps = { setPaginationObserver: jest.fn(), }; jest.mock('../../assistant_context'); -jest.mock('../../..', () => ({ +jest.mock('@kbn/inference-connectors', () => ({ useLoadConnectors: jest.fn(() => { return { data: [], diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/assistant_settings.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/assistant_settings.tsx index 26b914289b108..a3eca78b8903d 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/assistant_settings.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/assistant_settings.tsx @@ -18,10 +18,10 @@ import { import styled from 'styled-components'; import { css } from '@emotion/react'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { useConversationsUpdater } from './use_settings_updater/use_conversations_updater'; import type { AIConnector } from '../../connectorland/connector_selector'; import type { Conversation } from '../../..'; -import { useLoadConnectors } from '../../..'; import * as i18n from './translations'; import { useAssistantContext } from '../../assistant_context'; import { TEST_IDS } from '../constants'; @@ -86,6 +86,7 @@ export const AssistantSettings: React.FC = React.memo( const { data: connectors } = useLoadConnectors({ http, + featureId: 'elastic_assistant', settings, }); const { diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/assistant_settings_management.test.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/assistant_settings_management.test.tsx index 03fbbdf00bc2f..6366927502acf 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/assistant_settings_management.test.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/assistant_settings_management.test.tsx @@ -99,7 +99,7 @@ jest.mock('.', () => { }; }); -jest.mock('../../connectorland/use_load_connectors', () => ({ +jest.mock('@kbn/inference-connectors', () => ({ useLoadConnectors: jest.fn().mockReturnValue({ data: [] }), })); diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/assistant_settings_management.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/assistant_settings_management.tsx index 72564d5cb83bd..1560af93ba10c 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/assistant_settings_management.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/assistant_settings_management.tsx @@ -16,9 +16,9 @@ import { } from '@elastic/eui'; import { css } from '@emotion/react'; import type { DataViewsContract } from '@kbn/data-views-plugin/public'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import * as i18n from './translations'; import { useAssistantContext } from '../../assistant_context'; -import { useLoadConnectors } from '../../connectorland/use_load_connectors'; import { getDefaultConnector } from '../helpers'; import { ConversationSettingsManagement } from '../conversations/conversation_settings_management'; import { QuickPromptSettingsManagement } from '../quick_prompts/quick_prompt_settings_management'; @@ -69,6 +69,7 @@ export const AssistantSettingsManagement: React.FC = React.memo( const { data: connectors } = useLoadConnectors({ http, + featureId: 'elastic_assistant', settings, }); const defaultConnector = useMemo( diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/evaluation_settings/evaluation_settings.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/evaluation_settings/evaluation_settings.tsx index abfd31135e5d1..c3b49012f6562 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/evaluation_settings/evaluation_settings.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/evaluation_settings/evaluation_settings.tsx @@ -34,10 +34,10 @@ import type { import { isEmpty } from 'lodash/fp'; import moment from 'moment'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import * as i18n from './translations'; import { useAssistantContext } from '../../../assistant_context'; import { DEFAULT_ATTACK_DISCOVERY_MAX_ALERTS } from '../../../assistant_context/constants'; -import { useLoadConnectors } from '../../../connectorland/use_load_connectors'; import { getActionTypeTitle, getGenAiConfig } from '../../../connectorland/helpers'; import { PRECONFIGURED_CONNECTOR } from '../../../connectorland/translations'; import { usePerformEvaluation } from '../../api/evaluate/use_perform_evaluation'; @@ -51,7 +51,11 @@ const AS_PLAIN_TEXT: EuiComboBoxSingleSelectionShape = { asPlainText: true }; export const EvaluationSettings: React.FC = React.memo(() => { const { actionTypeRegistry, http, setTraceOptions, toasts, traceOptions, settings } = useAssistantContext(); - const { data: connectors } = useLoadConnectors({ http, inferenceEnabled: true, settings }); + const { data: connectors } = useLoadConnectors({ + http, + featureId: 'elastic_assistant', + settings, + }); const { mutate: performEvaluation, isLoading: isPerformingEvaluation } = usePerformEvaluation({ http, toasts, diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/search_ai_lake_configurations_settings_management.test.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/search_ai_lake_configurations_settings_management.test.tsx index 38b7fb291a823..f1f83614e0605 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/search_ai_lake_configurations_settings_management.test.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/search_ai_lake_configurations_settings_management.test.tsx @@ -96,7 +96,7 @@ jest.mock('.', () => { }; }); -jest.mock('../../connectorland/use_load_connectors', () => ({ +jest.mock('@kbn/inference-connectors', () => ({ useLoadConnectors: jest.fn().mockReturnValue({ data: [] }), })); diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/search_ai_lake_configurations_settings_management.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/search_ai_lake_configurations_settings_management.tsx index e8fc6611f5c13..3ef7c24ef68a8 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/search_ai_lake_configurations_settings_management.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/settings/search_ai_lake_configurations_settings_management.tsx @@ -15,10 +15,10 @@ import { } from '@elastic/eui'; import { css } from '@emotion/react'; import type { DataViewsContract } from '@kbn/data-views-plugin/public'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { AIForSOCConnectorSettingsManagement } from '../../connectorland/ai_for_soc_connector_settings_management'; import * as i18n from './translations'; import { useAssistantContext } from '../../assistant_context'; -import { useLoadConnectors } from '../../connectorland/use_load_connectors'; import { getDefaultConnector } from '../helpers'; import { ConversationSettingsManagement } from '../conversations/conversation_settings_management'; import { QuickPromptSettingsManagement } from '../quick_prompts/quick_prompt_settings_management'; @@ -67,6 +67,7 @@ export const SearchAILakeConfigurationsSettingsManagement: React.FC = Rea const { data: connectors } = useLoadConnectors({ http, + featureId: 'elastic_assistant', settings, }); const defaultConnector = useMemo( diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/use_assistant_overlay/index.test.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/use_assistant_overlay/index.test.tsx index 142381a19a278..b3bfc314b5d1a 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/use_assistant_overlay/index.test.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/use_assistant_overlay/index.test.tsx @@ -36,7 +36,7 @@ jest.mock('../use_conversation', () => { }); jest.mock('../../connectorland/helpers'); -jest.mock('../../connectorland/use_load_connectors', () => { +jest.mock('@kbn/inference-connectors', () => { return { useLoadConnectors: jest.fn(() => ({ data: mockConnectors, diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector/index.test.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector/index.test.tsx index caf846c7e9411..37d57cc7b76c9 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector/index.test.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector/index.test.tsx @@ -11,7 +11,7 @@ import { act, fireEvent, render, screen } from '@testing-library/react'; import { mockAssistantAvailability, TestProviders } from '../../mock/test_providers/test_providers'; import { mockActionTypes, mockConnectors } from '../../mock/connectors'; import * as i18n from '../translations'; -import { useLoadConnectors } from '../use_load_connectors'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { createMockUseLoadConnectorsResult } from '../../mock/test_helpers'; const onConnectorSelectionChange = jest.fn(); @@ -27,7 +27,7 @@ const connectorTwo = mockConnectors[1]; const mockRefetchConnectors = jest.fn(); -jest.mock('../use_load_connectors', () => ({ +jest.mock('@kbn/inference-connectors', () => ({ useLoadConnectors: jest.fn(), })); diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector/index.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector/index.tsx index 050db71fe5d0f..88bed8daf8c8a 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector/index.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector/index.tsx @@ -28,8 +28,8 @@ import type { OpenAiProviderType } from '@kbn/connector-schemas/openai'; import { some } from 'lodash'; import type { AttackDiscoveryStats } from '@kbn/elastic-assistant-common'; import { GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR } from '@kbn/management-settings-ids'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { AttackDiscoveryStatusIndicator } from './attack_discovery_status_indicator'; -import { useLoadConnectors } from '../use_load_connectors'; import * as i18n from '../translations'; import { useLoadActionTypes } from '../use_load_action_types'; import { useAssistantContext } from '../../assistant_context'; @@ -87,14 +87,8 @@ export const ConnectorSelector: React.FC = React.memo( stats = null, explicitConnectorSelection, }) => { - const { - actionTypeRegistry, - http, - assistantAvailability, - inferenceEnabled, - settings, - navigateToApp, - } = useAssistantContext(); + const { actionTypeRegistry, http, assistantAvailability, settings, navigateToApp } = + useAssistantContext(); const { euiTheme } = useEuiTheme(); const [isConnectorModalVisible, setIsConnectorModalVisible] = useState(false); @@ -105,7 +99,7 @@ export const ConnectorSelector: React.FC = React.memo( const { data: aiConnectors, refetch: refetchConnectors } = useLoadConnectors({ http, - inferenceEnabled, + featureId: 'elastic_assistant', settings, }); diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector_inline/connector_selector_inline.test.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector_inline/connector_selector_inline.test.tsx index 5de875f87858d..87cda8ef0d89b 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector_inline/connector_selector_inline.test.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_selector_inline/connector_selector_inline.test.tsx @@ -12,7 +12,7 @@ import { TestProviders } from '../../mock/test_providers/test_providers'; import { mockConnectors } from '../../mock/connectors'; import { ConnectorSelectorInline } from './connector_selector_inline'; import type { Conversation } from '../../..'; -import { useLoadConnectors } from '../use_load_connectors'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { MOCK_CURRENT_USER } from '../../assistant/use_conversation/sample_conversations'; const setApiConfig = jest.fn(); @@ -38,7 +38,7 @@ jest.mock('@kbn/triggers-actions-ui-plugin/public/common/constants', () => ({ }), })); -jest.mock('../use_load_connectors', () => ({ +jest.mock('@kbn/inference-connectors', () => ({ useLoadConnectors: jest.fn(() => { return { data: mockConnectors, diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_setup/index.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_setup/index.tsx index 6b0e4b1d0125c..69a0bd5be55ba 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_setup/index.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/connector_setup/index.tsx @@ -9,13 +9,13 @@ import React, { useCallback, useMemo, useState } from 'react'; import type { ActionConnector } from '@kbn/triggers-actions-ui-plugin/public/common/constants'; import type { ActionType } from '@kbn/triggers-actions-ui-plugin/public'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { AddConnectorModal } from '../add_connector_modal'; import { WELCOME_CONVERSATION } from '../../assistant/use_conversation/sample_conversations'; import type { Conversation } from '../../..'; import { useLoadActionTypes } from '../use_load_action_types'; import { useConversation } from '../../assistant/use_conversation'; import { useAssistantContext } from '../../assistant_context'; -import { useLoadConnectors } from '../use_load_connectors'; import { getGenAiConfig } from '../helpers'; export interface ConnectorSetupProps { @@ -35,12 +35,15 @@ export const ConnectorSetup = ({ ); const { setApiConfig } = useConversation(); // Access all conversations so we can add connector to all on initial setup - const { actionTypeRegistry, assistantAvailability, http, inferenceEnabled, settings } = - useAssistantContext(); + const { actionTypeRegistry, assistantAvailability, http, settings } = useAssistantContext(); const isMissingConnectorPrivileges = !assistantAvailability.hasConnectorsAllPrivilege; - const { refetch: refetchConnectors } = useLoadConnectors({ http, inferenceEnabled, settings }); + const { refetch: refetchConnectors } = useLoadConnectors({ + http, + featureId: 'elastic_assistant', + settings, + }); const { data: actionTypes } = useLoadActionTypes({ http }); diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_connectors/index.test.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_connectors/index.test.tsx deleted file mode 100644 index 5f85c68e9cc2e..0000000000000 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_connectors/index.test.tsx +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import { waitFor, renderHook } from '@testing-library/react'; -import type { Props } from '.'; -import { useLoadConnectors } from '.'; -import { mockConnectors } from '../../mock/connectors'; -import { TestProviders } from '../../mock/test_providers/test_providers'; -import { - GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR, - GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR_DEFAULT_ONLY, -} from '@kbn/management-settings-ids'; - -const mockConnectorsAndExtras = [ - ...mockConnectors, - // These connectors are not supported for inference - { - ...mockConnectors[0], - id: 'connector-missing-secrets', - name: 'Connector Missing Secrets', - isMissingSecrets: true, - }, - { - ...mockConnectors[0], - - id: 'connector-wrong-action-type', - name: 'Connector Wrong Action Type', - isMissingSecrets: true, - actionTypeId: '.d3', - }, - { - ...mockConnectors[0], - id: 'connector-text-embedding', - name: 'Text Embedding Connector', - isMissingSecrets: false, - actionTypeId: '.inference', - config: { - taskType: 'text_embedding', - }, - }, -]; - -const connectorsApiResponse = mockConnectorsAndExtras.map((c) => ({ - ...c, - connector_type_id: c.actionTypeId, - is_preconfigured: false, - is_deprecated: false, - referenced_by_count: 0, - is_missing_secrets: c.isMissingSecrets, - is_system_action: false, -})); - -const loadConnectorsResult = mockConnectors.map((c) => ({ - ...c, - isPreconfigured: false, - isDeprecated: false, - referencedByCount: 0, - isSystemAction: false, -})); - -const http = { - get: jest.fn().mockResolvedValue(connectorsApiResponse), -}; -const toasts = { - addError: jest.fn(), -}; -const settings = { - client: { - get: jest.fn().mockImplementation((settingKey) => { - if (settingKey === GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR) { - return undefined; - } - if (settingKey === GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR_DEFAULT_ONLY) { - return false; - } - }), - }, -}; -const defaultProps = { http, toasts, settings } as unknown as Props; - -describe('useLoadConnectors', () => { - beforeEach(() => { - jest.clearAllMocks(); - }); - it('should call api to load action types', async () => { - renderHook(() => useLoadConnectors(defaultProps), { - wrapper: TestProviders, - }); - await waitFor(() => { - expect(defaultProps.http.get).toHaveBeenCalledWith('/api/actions/connectors'); - expect(toasts.addError).not.toHaveBeenCalled(); - }); - }); - - it('should return sorted action types, removing isMissingSecrets and wrong action type ids, excluding .inference results', async () => { - const { result } = renderHook(() => useLoadConnectors(defaultProps), { - wrapper: TestProviders, - }); - await waitFor(() => { - expect(result.current.data).toStrictEqual( - loadConnectorsResult - .filter((c) => c.actionTypeId !== '.inference') - // @ts-ignore ts does not like config, but we define it in the mock data - .map((c) => ({ ...c, apiProvider: c.config.apiProvider })) - ); - }); - }); - - it('includes preconfigured .inference results when inferenceEnabled is true', async () => { - const { result } = renderHook( - () => useLoadConnectors({ ...defaultProps, inferenceEnabled: true }), - { - wrapper: TestProviders, - } - ); - await waitFor(() => { - expect(result.current.data).toStrictEqual( - mockConnectors - // @ts-ignore ts does not like config, but we define it in the mock data - .map((c) => ({ ...c, referencedByCount: 0, apiProvider: c?.config?.apiProvider })) - ); - }); - }); - it('should display error toast when api throws error', async () => { - const mockHttp = { - get: jest.fn().mockRejectedValue(new Error('this is an error')), - } as unknown as Props['http']; - renderHook(() => useLoadConnectors({ ...defaultProps, http: mockHttp }), { - wrapper: TestProviders, - }); - await waitFor(() => expect(toasts.addError).toHaveBeenCalled()); - }); - - it('should filter out .inference connectors without chat_completion taskType', async () => { - const { result } = renderHook( - () => useLoadConnectors({ ...defaultProps, inferenceEnabled: true }), - { - wrapper: TestProviders, - } - ); - await waitFor(() => { - const connectorIds = result.current.data?.map((c) => c.id) || []; - - expect(connectorIds).not.toContain('connector-text-embedding'); - expect(connectorIds).not.toContain('text-embedding-connector-id'); - expect(connectorIds).not.toContain('sparse-embedding-connector-id'); - expect(connectorIds).toContain('c29c28a0-20fe-11ee-9386-a1f4d42ec542'); // Regular Inference Connector - expect(connectorIds).toContain('connectorId'); // OpenAI connector - }); - }); -}); diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_connectors/index.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_connectors/index.tsx deleted file mode 100644 index 1c24dbe52bd9d..0000000000000 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/connectorland/use_load_connectors/index.tsx +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -import { useEffect } from 'react'; -import type { UseQueryResult } from '@kbn/react-query'; -import { useQuery } from '@kbn/react-query'; -import type { ServerError } from '@kbn/cases-plugin/public/types'; -import { loadAllActions as loadConnectors } from '@kbn/triggers-actions-ui-plugin/public/common/constants'; -import type { IHttpFetchError, HttpSetup } from '@kbn/core-http-browser'; -import type { IToasts } from '@kbn/core-notifications-browser'; -import type { OpenAiProviderType } from '@kbn/connector-schemas/openai'; -import type { SettingsStart } from '@kbn/core-ui-settings-browser'; -import { isSupportedConnector } from '@kbn/inference-common'; -import { getAvailableAiConnectors } from '@kbn/elastic-assistant-common/impl/connectors/get_available_connectors'; -import type { AIConnector } from '../connector_selector'; -import * as i18n from '../translations'; -/** - * Cache expiration in ms -- 1 minute, useful if connector is deleted/access removed - */ -// const STALE_TIME = 1000 * 60; -const QUERY_KEY = ['elastic-assistant, load-connectors']; - -export interface Props { - http: HttpSetup; - toasts?: IToasts; - inferenceEnabled?: boolean; - settings: SettingsStart; -} - -const actionTypes = ['.bedrock', '.gen-ai', '.gemini']; - -export const useLoadConnectors = ({ - http, - toasts, - inferenceEnabled = false, - settings, -}: Props): UseQueryResult => { - useEffect(() => { - if (inferenceEnabled && !actionTypes.includes('.inference')) { - actionTypes.push('.inference'); - } - }, [inferenceEnabled]); - - return useQuery( - QUERY_KEY, - async () => { - const connectors = await loadConnectors({ http }); - - const allAiConnectors = connectors.flatMap((connector) => { - if ( - !connector.isMissingSecrets && - actionTypes.includes(connector.actionTypeId) && - isSupportedConnector(connector) - ) { - const aiConnector: AIConnector = { - ...connector, - apiProvider: - !connector.isPreconfigured && - !connector.isSystemAction && - connector?.config?.apiProvider - ? (connector?.config?.apiProvider as OpenAiProviderType) - : undefined, - }; - return [aiConnector]; - } - return []; - }); - - return getAvailableAiConnectors({ - allAiConnectors, - settings, - }); - }, - { - retry: false, - keepPreviousData: true, - onError: (error: ServerError) => { - if (error.name !== 'AbortError') { - toasts?.addError( - error.body && error.body.message ? new Error(error.body.message) : error, - { - title: i18n.LOAD_CONNECTORS_ERROR_MESSAGE, - } - ); - } - }, - } - ); -}; diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/index.ts b/x-pack/platform/packages/shared/kbn-elastic-assistant/index.ts index 8a53ee807f8d8..63065b64647d6 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/index.ts +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/index.ts @@ -114,7 +114,6 @@ export { } from './impl/assistant_context/constants'; export type { AIConnector } from './impl/connectorland/connector_selector'; -export { useLoadConnectors } from './impl/connectorland/use_load_connectors'; export type { /** for rendering results in a code block */ diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/moon.yml b/x-pack/platform/packages/shared/kbn-elastic-assistant/moon.yml index cc74ebcd63f09..fdf0933ba4285 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/moon.yml +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/moon.yml @@ -57,6 +57,7 @@ dependsOn: - '@kbn/core-analytics-browser' - '@kbn/shared-ux-utility' - '@kbn/shared-ux-ai-components' + - '@kbn/inference-connectors' tags: - shared-browser - package diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/tsconfig.json b/x-pack/platform/packages/shared/kbn-elastic-assistant/tsconfig.json index 4d9c7b7a5bc25..5bf11266da9fb 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/tsconfig.json +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/tsconfig.json @@ -56,6 +56,7 @@ "@kbn/kibana-react-plugin", "@kbn/core-analytics-browser", "@kbn/shared-ux-utility", - "@kbn/shared-ux-ai-components" + "@kbn/shared-ux-ai-components", + "@kbn/inference-connectors" ] } diff --git a/x-pack/platform/packages/shared/kbn-evals-phoenix-executor/src/with_phoenix_executor.ts b/x-pack/platform/packages/shared/kbn-evals-phoenix-executor/src/with_phoenix_executor.ts index 71fa32495547d..e3bc5e216fd0f 100644 --- a/x-pack/platform/packages/shared/kbn-evals-phoenix-executor/src/with_phoenix_executor.ts +++ b/x-pack/platform/packages/shared/kbn-evals-phoenix-executor/src/with_phoenix_executor.ts @@ -26,6 +26,7 @@ function buildModelFromConnector(connectorWithId: AvailableConnectorWithId): Mod connectorId: connectorWithId.id, name: connectorWithId.name, capabilities: { contextWindowSize: 32000 }, + isPreconfigured: false, isInferenceEndpoint: false, }; diff --git a/x-pack/platform/packages/shared/kbn-evals/src/evaluate.ts b/x-pack/platform/packages/shared/kbn-evals/src/evaluate.ts index 67741b0182d3e..012094ef4b745 100644 --- a/x-pack/platform/packages/shared/kbn-evals/src/evaluate.ts +++ b/x-pack/platform/packages/shared/kbn-evals/src/evaluate.ts @@ -247,6 +247,7 @@ export const evaluate = base.extend<{}, EvaluationSpecificWorkerFixtures>({ config: connectorWithId.config, connectorId: connectorWithId.id, name: connectorWithId.name, + isPreconfigured: false, isInferenceEndpoint: false, capabilities: { contextWindowSize: 32000, diff --git a/x-pack/platform/packages/shared/kbn-inference-connectors/index.ts b/x-pack/platform/packages/shared/kbn-inference-connectors/index.ts new file mode 100644 index 0000000000000..624dea64bb6bd --- /dev/null +++ b/x-pack/platform/packages/shared/kbn-inference-connectors/index.ts @@ -0,0 +1,10 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export { useLoadConnectors } from './src/use_load_connectors'; +export type { UseLoadConnectorsProps } from './src/use_load_connectors'; +export type { AIConnector } from './src/types'; diff --git a/x-pack/platform/packages/shared/kbn-inference-connectors/jest.config.js b/x-pack/platform/packages/shared/kbn-inference-connectors/jest.config.js new file mode 100644 index 0000000000000..95f7aaf1366ab --- /dev/null +++ b/x-pack/platform/packages/shared/kbn-inference-connectors/jest.config.js @@ -0,0 +1,12 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +module.exports = { + preset: '@kbn/test', + rootDir: '../../../../..', + roots: ['/x-pack/platform/packages/shared/kbn-inference-connectors'], +}; diff --git a/x-pack/platform/packages/shared/kbn-inference-connectors/kibana.jsonc b/x-pack/platform/packages/shared/kbn-inference-connectors/kibana.jsonc new file mode 100644 index 0000000000000..3f9e21b986235 --- /dev/null +++ b/x-pack/platform/packages/shared/kbn-inference-connectors/kibana.jsonc @@ -0,0 +1,9 @@ +{ + "type": "shared-browser", + "id": "@kbn/inference-connectors", + "owner": [ + "@elastic/search-kibana" + ], + "group": "platform", + "visibility": "shared" +} diff --git a/x-pack/platform/packages/shared/kbn-inference-connectors/moon.yml b/x-pack/platform/packages/shared/kbn-inference-connectors/moon.yml new file mode 100644 index 0000000000000..f61e2d914ae98 --- /dev/null +++ b/x-pack/platform/packages/shared/kbn-inference-connectors/moon.yml @@ -0,0 +1,63 @@ +# This file is generated by the @kbn/moon package. Any manual edits will be erased! +# To extend this, write your extensions/overrides to 'moon.extend.yml' +# then regenerate this file with: 'node scripts/regenerate_moon_projects.js --update --filter @kbn/inference-connectors' + +$schema: https://moonrepo.dev/schemas/project.json +id: '@kbn/inference-connectors' +layer: unknown +owners: + defaultOwner: '@elastic/search-kibana' +toolchains: + default: node +language: typescript +project: + title: '@kbn/inference-connectors' + description: Moon project for @kbn/inference-connectors + channel: '' + owner: '@elastic/search-kibana' + sourceRoot: x-pack/platform/packages/shared/kbn-inference-connectors +dependsOn: + - '@kbn/i18n' + - '@kbn/react-query' + - '@kbn/core-http-browser' + - '@kbn/core-notifications-browser' + - '@kbn/core-ui-settings-browser' + - '@kbn/connector-schemas' + - '@kbn/inference-common' + - '@kbn/management-settings-ids' + - '@kbn/alerts-ui-shared' +tags: + - shared-browser + - package + - prod + - group-platform + - shared + - jest-unit-tests +fileGroups: + src: + - '**/*.ts' + - '**/*.tsx' + - '!target/**/*' +tasks: + jest: + command: node + args: + - '--no-experimental-require-module' + - $workspaceRoot/scripts/jest + - '--config' + - $projectRoot/jest.config.js + options: + runFromWorkspaceRoot: true + inputs: + - '@group(src)' + jestCI: + command: node + args: + - '--no-experimental-require-module' + - $workspaceRoot/scripts/jest + - '--config' + - $projectRoot/jest.config.js + options: + runFromWorkspaceRoot: true + inputs: + - '@group(src)' diff --git a/x-pack/platform/packages/shared/kbn-inference-connectors/package.json b/x-pack/platform/packages/shared/kbn-inference-connectors/package.json new file mode 100644 index 0000000000000..8f3d296972e6e --- /dev/null +++ b/x-pack/platform/packages/shared/kbn-inference-connectors/package.json @@ -0,0 +1,7 @@ +{ + "name": "@kbn/inference-connectors", + "private": true, + "version": "1.0.0", + "license": "Elastic License 2.0", + "sideEffects": false +} diff --git a/x-pack/platform/packages/shared/kbn-inference-connectors/src/types.ts b/x-pack/platform/packages/shared/kbn-inference-connectors/src/types.ts new file mode 100644 index 0000000000000..f9bfd2c447bc4 --- /dev/null +++ b/x-pack/platform/packages/shared/kbn-inference-connectors/src/types.ts @@ -0,0 +1,14 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { ActionConnector } from '@kbn/alerts-ui-shared/src/common/types'; +import type { OpenAiProviderType } from '@kbn/connector-schemas/openai'; + +export type AIConnector = ActionConnector & { + // related to OpenAI connectors, ex: Azure OpenAI, OpenAI + apiProvider?: OpenAiProviderType; +}; diff --git a/x-pack/platform/packages/shared/kbn-inference-connectors/src/use_load_connectors.ts b/x-pack/platform/packages/shared/kbn-inference-connectors/src/use_load_connectors.ts new file mode 100644 index 0000000000000..05614a364bfce --- /dev/null +++ b/x-pack/platform/packages/shared/kbn-inference-connectors/src/use_load_connectors.ts @@ -0,0 +1,109 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { UseQueryResult } from '@kbn/react-query'; +import { useQuery } from '@kbn/react-query'; +import type { IHttpFetchError, HttpSetup } from '@kbn/core-http-browser'; +import type { IToasts } from '@kbn/core-notifications-browser'; +import type { OpenAiProviderType } from '@kbn/connector-schemas/openai'; +import type { SettingsStart } from '@kbn/core-ui-settings-browser'; +import { type InferenceConnector } from '@kbn/inference-common'; +import { + GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR, + GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR_DEFAULT_ONLY, +} from '@kbn/management-settings-ids'; +import { i18n } from '@kbn/i18n'; +import type { AIConnector } from './types'; + +const INFERENCE_CONNECTORS_PATH = '/internal/inference/connectors'; +const QUERY_KEY = ['kbn-inference-connectors', 'load-connectors']; + +export interface UseLoadConnectorsProps { + http: HttpSetup; + toasts?: IToasts; + /** + * Feature identifier used to scope which inference endpoints are relevant. + * Reserved for future filtering logic. + */ + featureId: string; + settings: SettingsStart; +} + +const toAIConnector = (connector: InferenceConnector): AIConnector => + ({ + id: connector.connectorId, + name: connector.name, + actionTypeId: connector.type, + config: connector.config, + secrets: {}, + isPreconfigured: connector.isPreconfigured, + isSystemAction: false, + isDeprecated: false, + isConnectorTypeDeprecated: false, + isMissingSecrets: false, + apiProvider: + !connector.isPreconfigured && connector.config?.apiProvider + ? (connector.config.apiProvider as OpenAiProviderType) + : undefined, + } as AIConnector); + +const applyConnectorSettings = ( + allConnectors: T[], + settings: SettingsStart +): T[] => { + const defaultConnectorId = settings.client.get(GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR); + const defaultConnectorOnly = settings.client.get( + GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR_DEFAULT_ONLY, + false + ); + + if (defaultConnectorOnly && defaultConnectorId) { + const connector = allConnectors.find((c) => c.id === defaultConnectorId); + return connector ? [connector] : allConnectors; + } + return allConnectors; +}; + +const fetchAllConnectors = async (http: HttpSetup): Promise => { + const { connectors } = await http.get<{ connectors: InferenceConnector[] }>( + INFERENCE_CONNECTORS_PATH + ); + return connectors; +}; + +export const useLoadConnectors = ({ + http, + toasts, + featureId, + settings, +}: UseLoadConnectorsProps): UseQueryResult => { + return useQuery( + [...QUERY_KEY, featureId], + async () => { + const connectors = await fetchAllConnectors(http); + return applyConnectorSettings(connectors.map(toAIConnector), settings); + }, + { + retry: false, + keepPreviousData: true, + onError: (error: IHttpFetchError) => { + if (error.name !== 'AbortError') { + toasts?.addError( + error.body && (error.body as { message?: string }).message + ? new Error((error.body as { message: string }).message) + : error, + { + title: i18n.translate('inferenceConnectors.useLoadConnectors.errorMessage', { + defaultMessage: 'Error loading connectors', + }), + } + ); + } + }, + } + ); +}; diff --git a/x-pack/platform/packages/shared/kbn-inference-connectors/tsconfig.json b/x-pack/platform/packages/shared/kbn-inference-connectors/tsconfig.json new file mode 100644 index 0000000000000..9bf5f00da017f --- /dev/null +++ b/x-pack/platform/packages/shared/kbn-inference-connectors/tsconfig.json @@ -0,0 +1,30 @@ +{ + "extends": "@kbn/tsconfig-base/tsconfig.json", + "compilerOptions": { + "outDir": "target/types", + "types": [ + "jest", + "node", + "react", + "@kbn/ambient-ui-types" + ] + }, + "include": [ + "**/*.ts", + "**/*.tsx" + ], + "exclude": [ + "target/**/*" + ], + "kbn_references": [ + "@kbn/i18n", + "@kbn/react-query", + "@kbn/core-http-browser", + "@kbn/core-notifications-browser", + "@kbn/core-ui-settings-browser", + "@kbn/connector-schemas", + "@kbn/inference-common", + "@kbn/management-settings-ids", + "@kbn/alerts-ui-shared" + ] +} diff --git a/x-pack/platform/packages/shared/kbn-langchain/moon.yml b/x-pack/platform/packages/shared/kbn-langchain/moon.yml index 105e4ffa0d8e4..fa429f1f54db0 100644 --- a/x-pack/platform/packages/shared/kbn-langchain/moon.yml +++ b/x-pack/platform/packages/shared/kbn-langchain/moon.yml @@ -20,6 +20,7 @@ dependsOn: - '@kbn/core' - '@kbn/logging' - '@kbn/actions-plugin' + - '@kbn/inference-common' - '@kbn/logging-mocks' - '@kbn/utility-types' - '@kbn/tooling-log' diff --git a/x-pack/platform/packages/shared/kbn-langchain/server/language_models/llm.test.ts b/x-pack/platform/packages/shared/kbn-langchain/server/language_models/llm.test.ts index 0f37ee77d53b7..aefe051ce4c86 100644 --- a/x-pack/platform/packages/shared/kbn-langchain/server/language_models/llm.test.ts +++ b/x-pack/platform/packages/shared/kbn-langchain/server/language_models/llm.test.ts @@ -7,6 +7,7 @@ import { loggerMock } from '@kbn/logging-mocks'; import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock'; +import type { InferenceClient } from '@kbn/inference-common'; import { ActionsClientLlm } from './llm'; import { mockActionResponse } from './mocks'; @@ -166,5 +167,86 @@ describe('ActionsClientLlm', () => { 'ActionsClientLlm: content should be a string, but it had an unexpected type: number' ); }); + + describe('isInferenceEndpoint', () => { + const mockInferenceClient = { + chatComplete: jest.fn(), + } as unknown as InferenceClient; + + beforeEach(() => { + (mockInferenceClient.chatComplete as jest.Mock).mockResolvedValue({ + content: 'Hello, world', + }); + }); + + it('calls inferenceClient.chatComplete with the correct arguments', async () => { + const actionsClientLlm = new ActionsClientLlm({ + actionsClient, + connectorId, + inferenceClient: mockInferenceClient, + isInferenceEndpoint: true, + logger: mockLogger, + }); + + const result = await actionsClientLlm._call(prompt); + + expect(mockInferenceClient.chatComplete).toHaveBeenCalledWith({ + connectorId, + messages: [{ role: 'user', content: prompt }], + temperature: undefined, + modelName: undefined, + timeout: undefined, + }); + expect(result).toEqual('Hello, world'); + expect(actionsClient.execute).not.toHaveBeenCalled(); + }); + + it('passes model and temperature when provided', async () => { + const actionsClientLlm = new ActionsClientLlm({ + actionsClient, + connectorId, + inferenceClient: mockInferenceClient, + isInferenceEndpoint: true, + model: 'my-model', + temperature: 0.5, + logger: mockLogger, + }); + + await actionsClientLlm._call(prompt); + + expect(mockInferenceClient.chatComplete).toHaveBeenCalledWith( + expect.objectContaining({ modelName: 'my-model', temperature: 0.5 }) + ); + }); + + it('rejects when inferenceClient is not provided', async () => { + const actionsClientLlm = new ActionsClientLlm({ + actionsClient, + connectorId, + isInferenceEndpoint: true, + logger: mockLogger, + }); + + await expect(actionsClientLlm._call(prompt)).rejects.toThrowError( + 'ActionsClientLlm: inferenceClient is required when isInferenceEndpoint is true' + ); + }); + + it('propagates errors from inferenceClient.chatComplete', async () => { + (mockInferenceClient.chatComplete as jest.Mock).mockRejectedValue( + new Error('quota exceeded') + ); + + const actionsClientLlm = new ActionsClientLlm({ + actionsClient, + connectorId, + inferenceClient: mockInferenceClient, + isInferenceEndpoint: true, + logger: mockLogger, + }); + + await expect(actionsClientLlm._call(prompt)).rejects.toThrowError('quota exceeded'); + }); + }); }); }); diff --git a/x-pack/platform/packages/shared/kbn-langchain/server/language_models/llm.ts b/x-pack/platform/packages/shared/kbn-langchain/server/language_models/llm.ts index bea3c094c6803..b99c438a6444d 100644 --- a/x-pack/platform/packages/shared/kbn-langchain/server/language_models/llm.ts +++ b/x-pack/platform/packages/shared/kbn-langchain/server/language_models/llm.ts @@ -12,6 +12,8 @@ import { get } from 'lodash/fp'; import { v4 as uuidv4 } from 'uuid'; import type { PublicMethodsOf } from '@kbn/utility-types'; import type { TelemetryMetadata } from '@kbn/actions-plugin/server/lib'; +import type { InferenceClient } from '@kbn/inference-common'; +import { MessageRole } from '@kbn/inference-common'; import { DEFAULT_TIMEOUT, getDefaultArguments } from './constants'; import { getMessageContentAndRole } from './helpers'; @@ -22,6 +24,8 @@ const LLM_TYPE = 'ActionsClientLlm'; interface ActionsClientLlmParams { actionsClient: PublicMethodsOf; connectorId: string; + inferenceClient?: InferenceClient; + isInferenceEndpoint?: boolean; llmType?: string; logger: Logger; model?: string; @@ -35,6 +39,8 @@ interface ActionsClientLlmParams { export class ActionsClientLlm extends LLM { #actionsClient: PublicMethodsOf; #connectorId: string; + #inferenceClient?: InferenceClient; + #isInferenceEndpoint: boolean; #logger: Logger; #traceId: string; #timeout?: number; @@ -50,6 +56,8 @@ export class ActionsClientLlm extends LLM { constructor({ actionsClient, connectorId, + inferenceClient, + isInferenceEndpoint = false, traceId = uuidv4(), llmType, logger, @@ -65,6 +73,8 @@ export class ActionsClientLlm extends LLM { this.#actionsClient = actionsClient; this.#connectorId = connectorId; + this.#inferenceClient = inferenceClient; + this.#isInferenceEndpoint = isInferenceEndpoint; this.#traceId = traceId; this.llmType = llmType ?? LLM_TYPE; this.#logger = logger; @@ -95,6 +105,24 @@ export class ActionsClientLlm extends LLM { )} ` ); + if (this.#isInferenceEndpoint) { + if (!this.#inferenceClient) { + throw new Error( + `${LLM_TYPE}: inferenceClient is required when isInferenceEndpoint is true` + ); + } + + const result = await this.#inferenceClient.chatComplete({ + connectorId: this.#connectorId, + messages: [{ role: MessageRole.User, content: prompt }], + temperature: this.temperature, + modelName: this.model, + timeout: this.#timeout, + }); + + return result.content; + } + // create a new connector request body with the assistant message: const requestBody = { actionId: this.#connectorId, diff --git a/x-pack/platform/packages/shared/kbn-langchain/tsconfig.json b/x-pack/platform/packages/shared/kbn-langchain/tsconfig.json index 3c39d6714d28b..5b99d13079666 100644 --- a/x-pack/platform/packages/shared/kbn-langchain/tsconfig.json +++ b/x-pack/platform/packages/shared/kbn-langchain/tsconfig.json @@ -17,6 +17,7 @@ "@kbn/core", "@kbn/logging", "@kbn/actions-plugin", + "@kbn/inference-common", "@kbn/logging-mocks", "@kbn/utility-types", "@kbn/tooling-log" diff --git a/x-pack/platform/plugins/private/gen_ai_settings/moon.yml b/x-pack/platform/plugins/private/gen_ai_settings/moon.yml index df8e0cea38d77..426c0a3fcd3a1 100644 --- a/x-pack/platform/plugins/private/gen_ai_settings/moon.yml +++ b/x-pack/platform/plugins/private/gen_ai_settings/moon.yml @@ -53,6 +53,7 @@ dependsOn: - '@kbn/workflows' - '@kbn/agent-builder-browser' - '@kbn/agent-builder-plugin' + - '@kbn/zod' tags: - plugin - prod diff --git a/x-pack/platform/plugins/private/gen_ai_settings/public/components/default_ai_connector/default_ai_connector.test.tsx b/x-pack/platform/plugins/private/gen_ai_settings/public/components/default_ai_connector/default_ai_connector.test.tsx index 48b33b5931501..420d4dc732456 100644 --- a/x-pack/platform/plugins/private/gen_ai_settings/public/components/default_ai_connector/default_ai_connector.test.tsx +++ b/x-pack/platform/plugins/private/gen_ai_settings/public/components/default_ai_connector/default_ai_connector.test.tsx @@ -13,7 +13,7 @@ import { QueryClient, QueryClientProvider } from '@kbn/react-query'; import { I18nProvider } from '@kbn/i18n-react'; import userEvent from '@testing-library/user-event'; import { KibanaContextProvider } from '@kbn/kibana-react-plugin/public'; -import { createMockConnectorFindResult } from '@kbn/actions-plugin/server/application/connector/mocks'; +import { InferenceConnectorType } from '@kbn/inference-common'; function SettingsProbe({ onValue }: { onValue: (v: any) => void }) { const value = useSettingsContext(); @@ -27,19 +27,24 @@ const mockConnectors = { loading: false, reload: jest.fn(), connectors: [ - createMockConnectorFindResult({ - actionTypeId: 'pre-configured.1', - id: 'pre-configured1', - isPreconfigured: true, + { + connectorId: 'pre-configured1', name: 'Pre configured Connector', - referencedByCount: 0, - }), - createMockConnectorFindResult({ - actionTypeId: 'custom.1', - id: 'custom1', + type: InferenceConnectorType.OpenAI, + config: {}, + capabilities: {}, + isPreconfigured: true, + isInferenceEndpoint: false, + }, + { + connectorId: 'custom1', name: 'Custom Connector 1', - referencedByCount: 0, - }), + type: InferenceConnectorType.OpenAI, + config: {}, + capabilities: {}, + isPreconfigured: false, + isInferenceEndpoint: false, + }, ], }; @@ -48,6 +53,8 @@ interface TestWrapperProps { canSaveAdvancedSettings?: boolean; } +const mockGenAiSettingsApi = jest.fn().mockResolvedValue({}); + function TestWrapper({ children, canSaveAdvancedSettings = true }: TestWrapperProps) { const queryClient = new QueryClient(); @@ -66,6 +73,7 @@ function TestWrapper({ children, canSaveAdvancedSettings = true }: TestWrapperPr addDanger: jest.fn(), }, }, + genAiSettingsApi: mockGenAiSettingsApi, }} > @@ -270,15 +278,15 @@ describe('DefaultAIConnector', () => { loading: false, reload: jest.fn(), connectors: [ - createMockConnectorFindResult({ - actionTypeId: 'custom.1', - id: 'custom1', - isDeprecated: false, - isPreconfigured: false, - isSystemAction: false, + { + connectorId: 'custom1', name: 'Custom Connector 1', - referencedByCount: 0, - }), + type: InferenceConnectorType.OpenAI, + config: {}, + capabilities: {}, + isPreconfigured: false, + isInferenceEndpoint: false, + }, ], }; diff --git a/x-pack/platform/plugins/private/gen_ai_settings/public/components/default_ai_connector/default_ai_connector.tsx b/x-pack/platform/plugins/private/gen_ai_settings/public/components/default_ai_connector/default_ai_connector.tsx index c4501220fee81..3880f39889998 100644 --- a/x-pack/platform/plugins/private/gen_ai_settings/public/components/default_ai_connector/default_ai_connector.tsx +++ b/x-pack/platform/plugins/private/gen_ai_settings/public/components/default_ai_connector/default_ai_connector.tsx @@ -23,6 +23,7 @@ import { import type { FieldDefinition, UnsavedFieldChange } from '@kbn/management-settings-types'; import type { UiSettingsType } from '@kbn/core/public'; import { i18n } from '@kbn/i18n'; +import { useConnectorExists } from '../../hooks/use_connector_exists'; import type { UseGenAiConnectorsResult } from '../../hooks/use_genai_connectors'; import { useFieldSettingsContext, type ValidationError } from '../../contexts/settings_context'; import { NO_DEFAULT_CONNECTOR } from '../../../common/constants'; @@ -44,19 +45,16 @@ const NoDefaultOption: EuiComboBoxOptionOption = { const validateDefaultAIConnector = ( unsavedChanges: Record>, fields: Record>, - connectors: UseGenAiConnectorsResult + connectorExists: boolean, + connectorExistsLoading: boolean ): ValidationError[] => { const defaultLlmValue = getDefaultLlmValue(unsavedChanges, fields); const defaultLlmOnlyValue = getDefaultLlmOnlyValue(unsavedChanges, fields); const errors: ValidationError[] = []; - // Check if selected connector exists - const selectedConnectorExists = - connectors.connectors?.some((connector) => connector.id === defaultLlmValue) || - defaultLlmValue === NO_DEFAULT_CONNECTOR; - - if (!selectedConnectorExists && !connectors.loading) { + // Check if selected connector exists (via direct getConnectorById, not the deduped list) + if (!connectorExists && !connectorExistsLoading && defaultLlmValue !== NO_DEFAULT_CONNECTOR) { errors.push({ message: i18n.translate( 'xpack.gen_ai_settings.settings.defaultLLm.select.error.selectedDefaultLlmDoesNotExist.message', @@ -92,7 +90,7 @@ const getOptions = (connectors: UseGenAiConnectorsResult): EuiComboBoxOptionOpti ?.filter((connector) => connector.isPreconfigured) .map((connector) => ({ label: connector.name, - value: connector.id, + value: connector.connectorId, })) ?? []; const custom = @@ -100,7 +98,7 @@ const getOptions = (connectors: UseGenAiConnectorsResult): EuiComboBoxOptionOpti ?.filter((connector) => !connector.isPreconfigured) .map((connector) => ({ label: connector.name, - value: connector.id, + value: connector.connectorId, })) ?? []; return [ @@ -162,16 +160,28 @@ export const DefaultAIConnector: React.FC = ({ connectors }) => { const canEditAdvancedSettings = application.capabilities.advancedSettings?.save; + const defaultLlmValue = getDefaultLlmValue(unsavedChanges, fields); + + // Check existence via direct getConnectorById to avoid false negatives from the deduped list + const { exists: connectorExists, loading: connectorExistsLoading } = + useConnectorExists(defaultLlmValue); + // Calculate and set validation errors automatically React.useEffect(() => { - const errors = validateDefaultAIConnector(unsavedChanges, fields, connectors); + const errors = validateDefaultAIConnector( + unsavedChanges, + fields, + connectorExists, + connectorExistsLoading + ); setValidationErrors(errors); - }, [unsavedChanges, fields, connectors, setValidationErrors]); + }, [unsavedChanges, fields, connectorExists, connectorExistsLoading, setValidationErrors]); // Get current validation errors for inline display const validationErrors = useMemo( - () => validateDefaultAIConnector(unsavedChanges, fields, connectors), - [unsavedChanges, fields, connectors] + () => + validateDefaultAIConnector(unsavedChanges, fields, connectorExists, connectorExistsLoading), + [unsavedChanges, fields, connectorExists, connectorExistsLoading] ); const onChangeDefaultLlm = (selectedOptions: EuiComboBoxOptionOption[]) => { @@ -218,8 +228,6 @@ export const DefaultAIConnector: React.FC = ({ connectors }) => { }); }; - const defaultLlmValue = getDefaultLlmValue(unsavedChanges, fields); - const selectedOptions = useMemo( () => getOptionsByValues(defaultLlmValue, options), [defaultLlmValue, options] diff --git a/x-pack/platform/plugins/private/gen_ai_settings/public/components/gen_ai_settings_app.tsx b/x-pack/platform/plugins/private/gen_ai_settings/public/components/gen_ai_settings_app.tsx index 15917b4ded701..06c226ad396d6 100644 --- a/x-pack/platform/plugins/private/gen_ai_settings/public/components/gen_ai_settings_app.tsx +++ b/x-pack/platform/plugins/private/gen_ai_settings/public/components/gen_ai_settings_app.tsx @@ -32,7 +32,6 @@ import { useEnabledFeatures } from '../contexts/enabled_features_context'; import { useKibana } from '../hooks/use_kibana'; import { GoToSpacesButton } from './go_to_spaces_button'; import { useGenAiConnectors } from '../hooks/use_genai_connectors'; -import { getElasticManagedLlmConnector } from '../utils/get_elastic_managed_llm_connector'; import { useSettingsContext } from '../contexts/settings_context'; import { DefaultAIConnector } from './default_ai_connector/default_ai_connector'; import { BottomBarActions } from './bottom_bar_actions/bottom_bar_actions'; @@ -83,7 +82,9 @@ export const GenAiSettingsApp: React.FC = ({ setBreadcrum application.capabilities.actions?.save === true; const canManageSpaces = application.capabilities.management.kibana.spaces; const connectors = useGenAiConnectors(); - const hasElasticManagedLlm = getElasticManagedLlmConnector(connectors.connectors); + const hasElasticManagedLlm = (connectors.connectors || []).some( + (connector) => connector.isPreconfigured + ); useEffect(() => { const breadcrumbs = [ diff --git a/x-pack/platform/plugins/private/gen_ai_settings/public/hooks/use_connector_exists.ts b/x-pack/platform/plugins/private/gen_ai_settings/public/hooks/use_connector_exists.ts new file mode 100644 index 0000000000000..d1edde2acff42 --- /dev/null +++ b/x-pack/platform/plugins/private/gen_ai_settings/public/hooks/use_connector_exists.ts @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { useEffect, useState } from 'react'; +import { NO_DEFAULT_CONNECTOR } from '../../common/constants'; +import { useKibana } from './use_kibana'; + +/** + * Checks whether a given connector ID actually exists by fetching it directly + * from the inference plugin. This avoids false negatives from the deduped + * connector list returned by getConnectorList. + */ +export function useConnectorExists(connectorId: string): { + exists: boolean; + loading: boolean; +} { + const { + services: { genAiSettingsApi }, + } = useKibana(); + + const [exists, setExists] = useState(true); + const [loading, setLoading] = useState(false); + + useEffect(() => { + if (!connectorId || connectorId === NO_DEFAULT_CONNECTOR) { + setExists(true); + setLoading(false); + return; + } + + const controller = new AbortController(); + setLoading(true); + + genAiSettingsApi('GET /internal/gen_ai_settings/connectors/{connectorId}', { + signal: controller.signal, + params: { path: { connectorId } }, + }) + .then(() => { + setExists(true); + }) + .catch((e) => { + if (e?.name !== 'AbortError') { + setExists(false); + } + }) + .finally(() => { + setLoading(false); + }); + + return () => controller.abort(); + }, [connectorId, genAiSettingsApi]); + + return { exists, loading }; +} diff --git a/x-pack/platform/plugins/private/gen_ai_settings/public/hooks/use_genai_connectors.ts b/x-pack/platform/plugins/private/gen_ai_settings/public/hooks/use_genai_connectors.ts index 2e6d2ca2a9f6e..c1561d2019b62 100644 --- a/x-pack/platform/plugins/private/gen_ai_settings/public/hooks/use_genai_connectors.ts +++ b/x-pack/platform/plugins/private/gen_ai_settings/public/hooks/use_genai_connectors.ts @@ -6,11 +6,11 @@ */ import { useCallback, useEffect, useState } from 'react'; -import type { FindActionResult } from '@kbn/actions-plugin/server'; +import type { InferenceConnector } from '@kbn/inference-common'; import { useKibana } from './use_kibana'; export interface UseGenAiConnectorsResult { - connectors?: FindActionResult[]; + connectors?: InferenceConnector[]; loading: boolean; error?: Error; reload: () => void; @@ -21,7 +21,7 @@ export function useGenAiConnectors(): UseGenAiConnectorsResult { services: { genAiSettingsApi }, } = useKibana(); - const [connectors, setConnectors] = useState(undefined); + const [connectors, setConnectors] = useState(undefined); const [loading, setLoading] = useState(false); const [error, setError] = useState(undefined); diff --git a/x-pack/platform/plugins/private/gen_ai_settings/public/utils/get_elastic_managed_llm_connector.ts b/x-pack/platform/plugins/private/gen_ai_settings/public/utils/get_elastic_managed_llm_connector.ts deleted file mode 100644 index 9613160e58cac..0000000000000 --- a/x-pack/platform/plugins/private/gen_ai_settings/public/utils/get_elastic_managed_llm_connector.ts +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import type { FindActionResult } from '@kbn/actions-plugin/server'; - -export const INFERENCE_CONNECTOR_ACTION_TYPE_ID = '.inference'; - -export const getElasticManagedLlmConnector = (connectors: FindActionResult[] | undefined) => { - if (!Array.isArray(connectors) || connectors.length === 0) { - return undefined; - } - - return connectors.find( - (connector) => - connector.actionTypeId === INFERENCE_CONNECTOR_ACTION_TYPE_ID && - connector.isPreconfigured && - connector.config?.provider === 'elastic' - ); -}; diff --git a/x-pack/platform/plugins/private/gen_ai_settings/server/routes/connectors/route.ts b/x-pack/platform/plugins/private/gen_ai_settings/server/routes/connectors/route.ts index 48040b7ac1e9b..87ebe563aa5d6 100644 --- a/x-pack/platform/plugins/private/gen_ai_settings/server/routes/connectors/route.ts +++ b/x-pack/platform/plugins/private/gen_ai_settings/server/routes/connectors/route.ts @@ -4,8 +4,8 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import type { FindActionResult } from '@kbn/actions-plugin/server'; -import { isSupportedConnector } from '@kbn/inference-common'; +import { z } from '@kbn/zod/v4'; +import type { InferenceConnector } from '@kbn/inference-common'; import { createGenAiSettingsServerRoute } from '../create_gen_ai_settings_server_route'; const listConnectorsRoute = createGenAiSettingsServerRoute({ @@ -16,33 +16,36 @@ const listConnectorsRoute = createGenAiSettingsServerRoute({ reason: 'The route is protected by the actions plugin', }, }, - handler: async (resources): Promise => { + handler: async (resources): Promise => { const { request, plugins } = resources; - const actionsClient = await ( - await plugins.actions.start() - ).getActionsClientWithRequest(request); + const inferenceStart = await plugins.inference.start(); - const [availableTypes, connectors] = await Promise.all([ - actionsClient - .listTypes({ - includeSystemActionTypes: false, - }) - .then((types) => - types - .filter((type) => type.enabled && type.enabledInLicense && type.enabledInConfig) - .map((type) => type.id) - ), - actionsClient.getAll(), - ]); + return inferenceStart.getConnectorList(request); + }, +}); + +const getConnectorByIdRoute = createGenAiSettingsServerRoute({ + endpoint: 'GET /internal/gen_ai_settings/connectors/{connectorId}', + security: { + authz: { + enabled: false, + reason: 'The route is protected by the actions plugin', + }, + }, + params: z.object({ + path: z.object({ connectorId: z.string() }), + }), + handler: async (resources): Promise => { + const { request, params, plugins } = resources; + + const inferenceStart = await plugins.inference.start(); - return connectors.filter( - (connector) => - availableTypes.includes(connector.actionTypeId) && isSupportedConnector(connector) - ); + return inferenceStart.getConnectorById(params.path.connectorId, request); }, }); export const connectorRoutes = { ...listConnectorsRoute, + ...getConnectorByIdRoute, }; diff --git a/x-pack/platform/plugins/private/gen_ai_settings/tsconfig.json b/x-pack/platform/plugins/private/gen_ai_settings/tsconfig.json index 119dbb7941c78..91d0293fa0a6b 100644 --- a/x-pack/platform/plugins/private/gen_ai_settings/tsconfig.json +++ b/x-pack/platform/plugins/private/gen_ai_settings/tsconfig.json @@ -41,6 +41,7 @@ "@kbn/workflows", "@kbn/agent-builder-browser", "@kbn/agent-builder-plugin", + "@kbn/zod", ], "exclude": ["target/**/*"] } diff --git a/x-pack/platform/plugins/shared/agent_builder/moon.yml b/x-pack/platform/plugins/shared/agent_builder/moon.yml index 76f35b2a470fa..e945fc57f53bb 100644 --- a/x-pack/platform/plugins/shared/agent_builder/moon.yml +++ b/x-pack/platform/plugins/shared/agent_builder/moon.yml @@ -86,6 +86,7 @@ dependsOn: - '@kbn/licensing-plugin' - '@kbn/lens-embeddable-utils' - '@kbn/elastic-assistant' + - '@kbn/inference-connectors' - '@kbn/react-query' - '@kbn/shared-ux-utility' - '@kbn/usage-collection-plugin' diff --git a/x-pack/platform/plugins/shared/agent_builder/public/application/components/conversations/conversation_input/input_actions/connector_selector/connector_selector.tsx b/x-pack/platform/plugins/shared/agent_builder/public/application/components/conversations/conversation_input/input_actions/connector_selector/connector_selector.tsx index e3cabeeb9fb90..0d2dbc439618b 100644 --- a/x-pack/platform/plugins/shared/agent_builder/public/application/components/conversations/conversation_input/input_actions/connector_selector/connector_selector.tsx +++ b/x-pack/platform/plugins/shared/agent_builder/public/application/components/conversations/conversation_input/input_actions/connector_selector/connector_selector.tsx @@ -15,7 +15,7 @@ import { EuiPopoverFooter, EuiSelectable, } from '@elastic/eui'; -import { useLoadConnectors } from '@kbn/elastic-assistant'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { i18n } from '@kbn/i18n'; import { FormattedMessage } from '@kbn/i18n-react'; import React, { useEffect, useMemo, useState } from 'react'; @@ -160,8 +160,8 @@ export const ConnectorSelector: React.FC<{}> = () => { const { data: aiConnectors, isLoading } = useLoadConnectors({ http, + featureId: 'agent_builder', settings, - inferenceEnabled: true, }); const connectors = useMemo(() => aiConnectors ?? [], [aiConnectors]); diff --git a/x-pack/platform/plugins/shared/agent_builder/tsconfig.json b/x-pack/platform/plugins/shared/agent_builder/tsconfig.json index efddb8c120d62..9eac179789d0f 100644 --- a/x-pack/platform/plugins/shared/agent_builder/tsconfig.json +++ b/x-pack/platform/plugins/shared/agent_builder/tsconfig.json @@ -83,6 +83,7 @@ "@kbn/licensing-plugin", "@kbn/lens-embeddable-utils", "@kbn/elastic-assistant", + "@kbn/inference-connectors", "@kbn/react-query", "@kbn/shared-ux-utility", "@kbn/usage-collection-plugin", diff --git a/x-pack/platform/plugins/shared/automatic_import/moon.yml b/x-pack/platform/plugins/shared/automatic_import/moon.yml index 6f7715d60bf7a..ec79e177144d0 100644 --- a/x-pack/platform/plugins/shared/automatic_import/moon.yml +++ b/x-pack/platform/plugins/shared/automatic_import/moon.yml @@ -28,6 +28,7 @@ dependsOn: - '@kbn/i18n' - '@kbn/core-http-browser' - '@kbn/elastic-assistant' + - '@kbn/inference-connectors' - '@kbn/kibana-react-plugin' - '@kbn/code-editor' - '@kbn/monaco' diff --git a/x-pack/platform/plugins/shared/automatic_import/public/components/create_integration/create_automatic_import/steps/connector_step/connector_step.test.tsx b/x-pack/platform/plugins/shared/automatic_import/public/components/create_integration/create_automatic_import/steps/connector_step/connector_step.test.tsx index 7d0a1394b648b..5bb6f2cfe71f2 100644 --- a/x-pack/platform/plugins/shared/automatic_import/public/components/create_integration/create_automatic_import/steps/connector_step/connector_step.test.tsx +++ b/x-pack/platform/plugins/shared/automatic_import/public/components/create_integration/create_automatic_import/steps/connector_step/connector_step.test.tsx @@ -26,7 +26,7 @@ const defaultUseMockConnectors: { data: AIConnector[]; isLoading: boolean; refet refetch: jest.fn(), }; const mockUseLoadConnectors = jest.fn(() => defaultUseMockConnectors); -jest.mock('@kbn/elastic-assistant', () => ({ +jest.mock('@kbn/inference-connectors', () => ({ useLoadConnectors: () => mockUseLoadConnectors(), })); diff --git a/x-pack/platform/plugins/shared/automatic_import/public/components/create_integration/create_automatic_import/steps/connector_step/connector_step.tsx b/x-pack/platform/plugins/shared/automatic_import/public/components/create_integration/create_automatic_import/steps/connector_step/connector_step.tsx index 72be16ba090d5..247b9f7e21131 100644 --- a/x-pack/platform/plugins/shared/automatic_import/public/components/create_integration/create_automatic_import/steps/connector_step/connector_step.tsx +++ b/x-pack/platform/plugins/shared/automatic_import/public/components/create_integration/create_automatic_import/steps/connector_step/connector_step.tsx @@ -6,7 +6,7 @@ */ import React, { useCallback, useEffect, useState } from 'react'; -import { useLoadConnectors } from '@kbn/elastic-assistant'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { EuiForm, EuiFlexGroup, @@ -110,7 +110,12 @@ export const ConnectorStep = React.memo(({ connector }) => { isLoading, data: aiConnectors, refetch: refetchConnectors, - } = useLoadConnectors({ http, toasts: notifications.toasts, inferenceEnabled, settings }); + } = useLoadConnectors({ + http, + toasts: notifications.toasts, + featureId: 'automatic_import', + settings, + }); useEffect(() => { if (aiConnectors != null) { diff --git a/x-pack/platform/plugins/shared/automatic_import/tsconfig.json b/x-pack/platform/plugins/shared/automatic_import/tsconfig.json index 0dd330160a129..f4760906471cc 100644 --- a/x-pack/platform/plugins/shared/automatic_import/tsconfig.json +++ b/x-pack/platform/plugins/shared/automatic_import/tsconfig.json @@ -27,6 +27,7 @@ "@kbn/i18n", "@kbn/core-http-browser", "@kbn/elastic-assistant", + "@kbn/inference-connectors", "@kbn/kibana-react-plugin", "@kbn/code-editor", "@kbn/monaco", diff --git a/x-pack/platform/plugins/shared/automatic_import_v2/moon.yml b/x-pack/platform/plugins/shared/automatic_import_v2/moon.yml index 57e9c79f0f7c0..1a967593a1e84 100644 --- a/x-pack/platform/plugins/shared/automatic_import_v2/moon.yml +++ b/x-pack/platform/plugins/shared/automatic_import_v2/moon.yml @@ -46,6 +46,7 @@ dependsOn: - '@kbn/i18n-react' - '@kbn/es-ui-shared-plugin' - '@kbn/ai-assistant-connector-selector-action' + - '@kbn/inference-connectors' - '@kbn/management-settings-ids' - '@kbn/langchain' - '@kbn/core-lifecycle-browser-mocks' diff --git a/x-pack/platform/plugins/shared/automatic_import_v2/public/common/components/connector_selector.test.tsx b/x-pack/platform/plugins/shared/automatic_import_v2/public/common/components/connector_selector.test.tsx index 72600f38a30ee..f3d6bbcc19413 100644 --- a/x-pack/platform/plugins/shared/automatic_import_v2/public/common/components/connector_selector.test.tsx +++ b/x-pack/platform/plugins/shared/automatic_import_v2/public/common/components/connector_selector.test.tsx @@ -14,7 +14,7 @@ import { triggersActionsUiMock } from '@kbn/triggers-actions-ui-plugin/public/mo import { Form, useForm } from '@kbn/es-ui-shared-plugin/static/forms/hook_form_lib'; import { ConnectorSelector } from './connector_selector'; import { ConnectorSetup } from './connector_setup'; -import { useLoadConnectors } from '..'; +import { useLoadConnectors } from '@kbn/inference-connectors'; const mockConnectors = [ { @@ -49,9 +49,9 @@ const mockActionTypes = [ // Mock the useLoadConnectors hook const mockRefetch = jest.fn(); -jest.mock('../hooks/use_load_connectors', () => ({ +jest.mock('@kbn/inference-connectors', () => ({ useLoadConnectors: jest.fn(() => ({ - connectors: [], + data: [], isLoading: false, refetch: jest.fn(), })), @@ -136,7 +136,7 @@ describe('ConnectorSelector', () => { beforeEach(() => { jest.clearAllMocks(); mockUseLoadConnectors.mockReturnValue({ - connectors: mockConnectors, + data: mockConnectors, isLoading: false, refetch: mockRefetch, }); @@ -150,7 +150,7 @@ describe('ConnectorSelector', () => { it('should show loading spinner when connectors are loading', async () => { mockUseLoadConnectors.mockReturnValue({ - connectors: mockConnectors, + data: mockConnectors, isLoading: true, refetch: mockRefetch, }); @@ -161,7 +161,7 @@ describe('ConnectorSelector', () => { it('should show "Add connector" button when no connectors exist', async () => { mockUseLoadConnectors.mockReturnValue({ - connectors: [], + data: [], isLoading: false, refetch: mockRefetch, }); @@ -200,7 +200,7 @@ describe('ConnectorSelector', () => { it('should select first available connector when no default is set and no Elastic LLM', async () => { mockUseLoadConnectors.mockReturnValue({ - connectors: [mockConnectors[0], mockConnectors[1]], // No Elastic Managed LLM + data: [mockConnectors[0], mockConnectors[1]], // No Elastic Managed LLM isLoading: false, refetch: mockRefetch, }); @@ -230,7 +230,7 @@ describe('ConnectorSelector', () => { describe('connector creation', () => { it('should open connector setup when "Add connector" is clicked with no connectors', async () => { mockUseLoadConnectors.mockReturnValue({ - connectors: [], + data: [], isLoading: false, refetch: mockRefetch, }); diff --git a/x-pack/platform/plugins/shared/automatic_import_v2/public/common/components/connector_selector.tsx b/x-pack/platform/plugins/shared/automatic_import_v2/public/common/components/connector_selector.tsx index 01acc7fb76558..4ea82dd3d3027 100644 --- a/x-pack/platform/plugins/shared/automatic_import_v2/public/common/components/connector_selector.tsx +++ b/x-pack/platform/plugins/shared/automatic_import_v2/public/common/components/connector_selector.tsx @@ -27,7 +27,8 @@ import { GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR } from '@kbn/management-settings-i import { UseField } from '@kbn/es-ui-shared-plugin/static/forms/hook_form_lib'; import type { FieldHook } from '@kbn/es-ui-shared-plugin/static/forms/hook_form_lib'; import type { ActionConnector } from '@kbn/triggers-actions-ui-plugin/public'; -import { useLoadConnectors, useKibana } from '..'; +import { useLoadConnectors } from '@kbn/inference-connectors'; +import { useKibana } from '..'; import { ConnectorSetup } from './connector_setup'; import * as i18n from './translations'; @@ -282,12 +283,21 @@ export const ConnectorSelector: React.FC = ({ isDisabled = false, displayFancy, }) => { - const { settings, application } = useKibana().services; + const { http, notifications, settings, application } = useKibana().services; const [isPopoverOpen, setIsPopoverOpen] = useState(false); const [isConnectorModalVisible, setIsConnectorModalVisible] = useState(false); - const { connectors, isLoading, refetch } = useLoadConnectors(); + const { + data: connectors, + isLoading, + refetch, + } = useLoadConnectors({ + http, + toasts: notifications.toasts, + featureId: 'automatic_import_v2', + settings, + }); const settingsDefaultConnectorId = settings?.client.get( GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR diff --git a/x-pack/platform/plugins/shared/automatic_import_v2/public/common/hooks/translations.ts b/x-pack/platform/plugins/shared/automatic_import_v2/public/common/hooks/translations.ts index e1b3ff49d2745..cdaf096311c3f 100644 --- a/x-pack/platform/plugins/shared/automatic_import_v2/public/common/hooks/translations.ts +++ b/x-pack/platform/plugins/shared/automatic_import_v2/public/common/hooks/translations.ts @@ -53,19 +53,6 @@ export const SAVE_INTEGRATION_ERROR = i18n.translate( } ); -export const LOAD_CONNECTORS_ERROR_TITLE = i18n.translate( - 'xpack.automaticImportV2.hooks.loadConnectors.errorTitle', - { - defaultMessage: 'Unable to load connectors', - } -); -export const LOAD_CONNECTORS_ERROR_MESSAGE = i18n.translate( - 'xpack.automaticImportV2.hooks.loadConnectors.errorMessage', - { - defaultMessage: 'Failed to load connectors', - } -); - export const UPLOAD_SAMPLES_SUCCESS = i18n.translate( 'xpack.automaticImportV2.hooks.uploadSamples.success', { diff --git a/x-pack/platform/plugins/shared/automatic_import_v2/public/common/hooks/use_load_connectors.ts b/x-pack/platform/plugins/shared/automatic_import_v2/public/common/hooks/use_load_connectors.ts deleted file mode 100644 index 02211117e0ca3..0000000000000 --- a/x-pack/platform/plugins/shared/automatic_import_v2/public/common/hooks/use_load_connectors.ts +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import { useState, useEffect, useCallback } from 'react'; -import type { ActionConnector } from '@kbn/triggers-actions-ui-plugin/public'; -import { useKibana } from './use_kibana'; -import * as i18n from './translations'; - -const ALLOWED_ACTION_TYPE_IDS = ['.bedrock', '.gen-ai', '.gemini', '.inference']; - -interface UseLoadConnectorsResult { - connectors: ActionConnector[]; - isLoading: boolean; - error: string | undefined; - refetch: () => void; -} - -export const useLoadConnectors = (): UseLoadConnectorsResult => { - const { http, notifications } = useKibana().services; - const [connectors, setConnectors] = useState([]); - const [isLoading, setIsLoading] = useState(true); - const [error, setError] = useState(); - - const fetchConnectors = useCallback(async () => { - try { - setIsLoading(true); - setError(undefined); - - const response = await http.get('/api/actions/connectors'); - - // Filter to only AI-related connectors - const aiConnectors = response.filter((connector) => { - const typeId = connector.actionTypeId || (connector as any).connector_type_id; - return ALLOWED_ACTION_TYPE_IDS.includes(typeId); - }); - - setConnectors(aiConnectors); - } catch (e) { - const errorMessage = e instanceof Error ? e.message : i18n.LOAD_CONNECTORS_ERROR_MESSAGE; - setError(errorMessage); - notifications.toasts.addDanger({ - title: i18n.LOAD_CONNECTORS_ERROR_TITLE, - text: errorMessage, - }); - } finally { - setIsLoading(false); - } - }, [http, notifications.toasts]); - - useEffect(() => { - fetchConnectors(); - }, [fetchConnectors]); - - return { - connectors, - isLoading, - error, - refetch: fetchConnectors, - }; -}; diff --git a/x-pack/platform/plugins/shared/automatic_import_v2/public/common/index.ts b/x-pack/platform/plugins/shared/automatic_import_v2/public/common/index.ts index 78e9cfd303259..e776b7f33410a 100644 --- a/x-pack/platform/plugins/shared/automatic_import_v2/public/common/index.ts +++ b/x-pack/platform/plugins/shared/automatic_import_v2/public/common/index.ts @@ -10,7 +10,6 @@ export { useFetchIndices } from './hooks/use_fetch_indices'; export { useGetIntegrationById } from './hooks/use_get_integration_by_id'; export { useGetAllIntegrations } from './hooks/use_get_all_integrations'; export { useCreateUpdateIntegration } from './hooks/use_create_update_integration'; -export { useLoadConnectors } from './hooks/use_load_connectors'; export { useValidateIndex } from './hooks/use_validate_index'; export { useUploadSamples } from './hooks/use_upload_samples'; export { useDeleteDataStream } from './hooks/use_delete_data_stream'; diff --git a/x-pack/platform/plugins/shared/automatic_import_v2/tsconfig.json b/x-pack/platform/plugins/shared/automatic_import_v2/tsconfig.json index dfced964a7721..273b07a27db93 100644 --- a/x-pack/platform/plugins/shared/automatic_import_v2/tsconfig.json +++ b/x-pack/platform/plugins/shared/automatic_import_v2/tsconfig.json @@ -43,6 +43,7 @@ "@kbn/i18n-react", "@kbn/es-ui-shared-plugin", "@kbn/ai-assistant-connector-selector-action", + "@kbn/inference-connectors", "@kbn/management-settings-ids", "@kbn/langchain", "@kbn/core-lifecycle-browser-mocks", diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts index a68c4eb2308ea..ed34b65f290eb 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts @@ -43,6 +43,7 @@ describe('bedrockClaudeAdapter', () => { config: {}, capabilities: {}, isInferenceEndpoint: false, + isPreconfigured: false, }); }); diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.ts index aee81feb67812..181c2049c3eb1 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.ts @@ -34,6 +34,7 @@ describe('geminiAdapter', () => { config: {}, capabilities: {}, isInferenceEndpoint: false, + isPreconfigured: false, }); processVertexStreamMock.mockReset().mockImplementation(() => tap(noop)); processVertexResponseMock.mockReset().mockImplementation(() => tap(noop)); diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.test.ts index 24122a61be2e0..05a6f5831ab78 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/inference/inference_adapter.test.ts @@ -70,6 +70,7 @@ describe('inferenceAdapter', () => { config: {}, capabilities: {}, isInferenceEndpoint: false, + isPreconfigured: false, }; }); }); diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/function_calling_support.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/function_calling_support.test.ts index e5bd87bd057ae..4a1f648bab7f4 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/function_calling_support.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/function_calling_support.test.ts @@ -19,6 +19,7 @@ const createConnector = ( config: {}, capabilities: {}, isInferenceEndpoint: false, + isPreconfigured: false, ...parts, }; }; diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/inference_executor.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/inference_executor.test.ts index 60bbdf4e4f4d8..3a5a4e984b81a 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/inference_executor.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/utils/inference_executor.test.ts @@ -23,6 +23,7 @@ describe('createInferenceExecutor', () => { config: {}, capabilities: {}, isInferenceEndpoint: false, + isPreconfigured: false, }; beforeEach(() => { diff --git a/x-pack/platform/plugins/shared/inference/server/test_utils/inference_connector.ts b/x-pack/platform/plugins/shared/inference/server/test_utils/inference_connector.ts index f54d38e0ff07d..1d55eecb79734 100644 --- a/x-pack/platform/plugins/shared/inference/server/test_utils/inference_connector.ts +++ b/x-pack/platform/plugins/shared/inference/server/test_utils/inference_connector.ts @@ -18,6 +18,7 @@ export const createInferenceConnectorMock = ( config: {}, capabilities: {}, isInferenceEndpoint: false, + isPreconfigured: false, ...parts, }; }; diff --git a/x-pack/platform/plugins/shared/inference/server/util/get_connector_by_id.test.ts b/x-pack/platform/plugins/shared/inference/server/util/get_connector_by_id.test.ts index 68a38a6c1782e..b3bb9cdade933 100644 --- a/x-pack/platform/plugins/shared/inference/server/util/get_connector_by_id.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/util/get_connector_by_id.test.ts @@ -30,6 +30,7 @@ describe('getConnectorById', () => { type: InferenceConnectorType.OpenAI, config: {}, isInferenceEndpoint: false, + isPreconfigured: false, capabilities: {}, ...parts, }); @@ -85,10 +86,38 @@ describe('getConnectorById', () => { expect(result).toEqual(expected); }); + it('resolves a stack connector ID to its superseding inference endpoint', async () => { + const inferenceEndpoint = createMockInferenceConnector({ + connectorId: 'my-inference-id', + name: 'My EIS Endpoint', + type: InferenceConnectorType.Inference, + isInferenceEndpoint: true, + }); + // The filtered list only contains the endpoint, not the stack connector + getConnectorListMock.mockResolvedValue([inferenceEndpoint]); + + const actionsClient = await actions.getActionsClientWithRequest(request); + (actionsClient.getAll as jest.Mock).mockResolvedValue([ + { + id: connectorId, + actionTypeId: InferenceConnectorType.Inference, + name: 'My Stack Connector', + config: { inferenceId: 'my-inference-id' }, + isPreconfigured: true, + }, + ]); + + const result = await getConnectorById({ actions, request, connectorId, esClient, logger }); + + expect(result).toEqual(inferenceEndpoint); + }); + it('throws if no connector matches the id', async () => { getConnectorListMock.mockResolvedValue([ createMockInferenceConnector({ connectorId: 'other' }), ]); + const actionsClient = await actions.getActionsClientWithRequest(request); + (actionsClient.getAll as jest.Mock).mockResolvedValue([]); await expect( getConnectorById({ actions, request, connectorId, esClient, logger }) @@ -97,6 +126,8 @@ describe('getConnectorById', () => { it('throws if the connector list is empty', async () => { getConnectorListMock.mockResolvedValue([]); + const actionsClient = await actions.getActionsClientWithRequest(request); + (actionsClient.getAll as jest.Mock).mockResolvedValue([]); await expect( getConnectorById({ actions, request, connectorId, esClient, logger }) diff --git a/x-pack/platform/plugins/shared/inference/server/util/get_connector_by_id.ts b/x-pack/platform/plugins/shared/inference/server/util/get_connector_by_id.ts index 4b039b41e753d..db9ebf99afe82 100644 --- a/x-pack/platform/plugins/shared/inference/server/util/get_connector_by_id.ts +++ b/x-pack/platform/plugins/shared/inference/server/util/get_connector_by_id.ts @@ -7,11 +7,19 @@ import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; import type { KibanaRequest, ElasticsearchClient, Logger } from '@kbn/core/server'; -import { createInferenceRequestError, type InferenceConnector } from '@kbn/inference-common'; +import { + createInferenceRequestError, + InferenceConnectorType, + type InferenceConnector, +} from '@kbn/inference-common'; import { getConnectorList } from './get_connector_list'; /** * Retrieves a connector or inference endpoint given the provided `connectorId`. + * + * If the `connectorId` matches a preconfigured `.inference` stack connector that has been + * superseded by its underlying inference endpoint (i.e. `getConnectorList` prefers the + * endpoint representation), the corresponding inference endpoint is returned instead. */ export const getConnectorById = async ({ connectorId, @@ -29,12 +37,29 @@ export const getConnectorById = async ({ const connectors = await getConnectorList({ actions, request, esClient, logger }); const match = connectors.find((c) => c.connectorId === connectorId); - if (!match) { - throw createInferenceRequestError( - `No connector or inference endpoint found for ID '${connectorId}'`, - 404 - ); + if (match) { + return match; } - return match; + // The requested ID may belong to a stack `.inference` connector whose underlying inference + // endpoint was already returned in the list under `inferenceId`. Look up the raw stack + // connector to resolve the alias. + const actionClient = await actions.getActionsClientWithRequest(request); + const allStackConnectors = await actionClient.getAll({ includeSystemActions: false }); + const stackConnector = allStackConnectors.find((c) => c.id === connectorId); + + if (stackConnector?.actionTypeId === InferenceConnectorType.Inference) { + const inferenceId = stackConnector.config?.inferenceId as string | undefined; + if (inferenceId) { + const endpointMatch = connectors.find((c) => c.connectorId === inferenceId); + if (endpointMatch) { + return endpointMatch; + } + } + } + + throw createInferenceRequestError( + `No connector or inference endpoint found for ID '${connectorId}'`, + 404 + ); }; diff --git a/x-pack/platform/plugins/shared/inference/server/util/get_connector_list.test.ts b/x-pack/platform/plugins/shared/inference/server/util/get_connector_list.test.ts new file mode 100644 index 0000000000000..bcd0e60ad4541 --- /dev/null +++ b/x-pack/platform/plugins/shared/inference/server/util/get_connector_list.test.ts @@ -0,0 +1,212 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { actionsClientMock, actionsMock } from '@kbn/actions-plugin/server/mocks'; +import { httpServerMock } from '@kbn/core/server/mocks'; +import { loggerMock } from '@kbn/logging-mocks'; +import { InferenceConnectorType } from '@kbn/inference-common'; +import { getConnectorList } from './get_connector_list'; +import { getInferenceEndpoints } from './get_inference_endpoints'; + +jest.mock('./get_inference_endpoints'); + +const getInferenceEndpointsMock = getInferenceEndpoints as jest.MockedFn< + typeof getInferenceEndpoints +>; + +describe('getConnectorList', () => { + let actions: ReturnType; + let actionsClient: ReturnType; + let request: ReturnType; + const esClient = {} as any; + const logger = loggerMock.create(); + + beforeEach(() => { + actions = actionsMock.createStart(); + actionsClient = actionsClientMock.create(); + request = httpServerMock.createKibanaRequest(); + actions.getActionsClientWithRequest.mockResolvedValue(actionsClient); + actionsClient.getAll.mockResolvedValue([]); + getInferenceEndpointsMock.mockResolvedValue([]); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('returns stack connectors from the actions plugin', async () => { + actionsClient.getAll.mockResolvedValue([ + { + id: 'connector-1', + actionTypeId: InferenceConnectorType.OpenAI, + name: 'My OpenAI Connector', + config: { apiProvider: 'OpenAI' }, + isPreconfigured: false, + isSystemAction: false, + isDeprecated: false, + referencedByCount: 0, + isMissingSecrets: false, + }, + ] as any); + + const result = await getConnectorList({ actions, request, esClient, logger }); + + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + connectorId: 'connector-1', + name: 'My OpenAI Connector', + type: InferenceConnectorType.OpenAI, + isInferenceEndpoint: false, + }); + }); + + it('returns native inference endpoints with display.name when available', async () => { + getInferenceEndpointsMock.mockResolvedValue([ + { + inferenceId: 'my-endpoint', + taskType: 'chat_completion', + service: 'openai', + serviceSettings: { model_id: 'gpt-4' }, + metadata: { display: { name: 'My Preconfigured Endpoint' } }, + }, + ]); + + const result = await getConnectorList({ actions, request, esClient, logger }); + + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + connectorId: 'my-endpoint', + name: 'My Preconfigured Endpoint', + type: InferenceConnectorType.Inference, + isInferenceEndpoint: true, + isPreconfigured: true, + }); + }); + + it('uses the matching stack connector name when the endpoint has no display.name', async () => { + actionsClient.getAll.mockResolvedValue([ + { + id: 'stack-connector-id', + actionTypeId: InferenceConnectorType.Inference, + name: 'My Named Connector', + config: { inferenceId: 'my-endpoint', taskType: 'chat_completion' }, + isPreconfigured: false, + isSystemAction: false, + isDeprecated: false, + referencedByCount: 0, + isMissingSecrets: false, + }, + ] as any); + + getInferenceEndpointsMock.mockResolvedValue([ + { + inferenceId: 'my-endpoint', + taskType: 'chat_completion', + service: 'openai', + serviceSettings: {}, + metadata: {}, + }, + ]); + + const result = await getConnectorList({ actions, request, esClient, logger }); + + const endpoint = result.find((c) => c.isInferenceEndpoint); + expect(endpoint).toMatchObject({ + connectorId: 'my-endpoint', + name: 'My Named Connector', + isInferenceEndpoint: true, + isPreconfigured: false, + }); + }); + + it('falls back to inferenceId as name when no display.name and no matching stack connector', async () => { + getInferenceEndpointsMock.mockResolvedValue([ + { + inferenceId: 'my-endpoint', + taskType: 'chat_completion', + service: 'openai', + serviceSettings: {}, + metadata: {}, + }, + ]); + + const result = await getConnectorList({ actions, request, esClient, logger }); + + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + connectorId: 'my-endpoint', + name: 'my-endpoint', + isInferenceEndpoint: true, + }); + }); + + it('deduplicates: excludes the .inference stack connector when a matching ES endpoint exists', async () => { + actionsClient.getAll.mockResolvedValue([ + { + id: 'stack-connector-id', + actionTypeId: InferenceConnectorType.Inference, + name: 'My Custom Connector', + config: { inferenceId: 'my-endpoint', taskType: 'chat_completion' }, + isPreconfigured: false, + isSystemAction: false, + isDeprecated: false, + referencedByCount: 0, + isMissingSecrets: false, + }, + ] as any); + + getInferenceEndpointsMock.mockResolvedValue([ + { + inferenceId: 'my-endpoint', + taskType: 'chat_completion', + service: 'openai', + serviceSettings: {}, + metadata: {}, + }, + ]); + + const result = await getConnectorList({ actions, request, esClient, logger }); + + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + connectorId: 'my-endpoint', + isInferenceEndpoint: true, + }); + expect(result.find((c) => !c.isInferenceEndpoint)).toBeUndefined(); + }); + + it('prefers display.name over stack connector name', async () => { + actionsClient.getAll.mockResolvedValue([ + { + id: 'stack-connector-id', + actionTypeId: InferenceConnectorType.Inference, + name: 'Stack Connector Name', + config: { inferenceId: 'my-endpoint', taskType: 'chat_completion' }, + isPreconfigured: false, + isSystemAction: false, + isDeprecated: false, + referencedByCount: 0, + isMissingSecrets: false, + }, + ] as any); + + getInferenceEndpointsMock.mockResolvedValue([ + { + inferenceId: 'my-endpoint', + taskType: 'chat_completion', + service: 'openai', + serviceSettings: {}, + metadata: { display: { name: 'Display Name Takes Priority' } }, + }, + ]); + + const result = await getConnectorList({ actions, request, esClient, logger }); + + const endpoint = result.find((c) => c.isInferenceEndpoint); + expect(endpoint?.name).toBe('Display Name Takes Priority'); + }); +}); diff --git a/x-pack/platform/plugins/shared/inference/server/util/get_connector_list.ts b/x-pack/platform/plugins/shared/inference/server/util/get_connector_list.ts index 3ee357fede1f9..3259115f3818b 100644 --- a/x-pack/platform/plugins/shared/inference/server/util/get_connector_list.ts +++ b/x-pack/platform/plugins/shared/inference/server/util/get_connector_list.ts @@ -50,9 +50,18 @@ export const getConnectorList = async ({ const connectors = connectorsResult.status === 'fulfilled' ? connectorsResult.value : []; const endpoints = endpointsResult.status === 'fulfilled' ? endpointsResult.value : []; + const stackConnectorByInferenceId = new Map( + connectors + .filter((c) => c.type === InferenceConnectorType.Inference) + .map((c) => [c.config?.inferenceId as string, c]) + ); + const inferenceEndpointConnectors: InferenceConnector[] = endpoints.map((ep) => ({ type: InferenceConnectorType.Inference, - name: ep.inferenceId, + name: + ep.metadata.display?.name ?? + stackConnectorByInferenceId.get(ep.inferenceId)?.name ?? + ep.inferenceId, connectorId: ep.inferenceId, config: { inferenceId: ep.inferenceId, @@ -65,9 +74,19 @@ export const getConnectorList = async ({ }, capabilities: {}, isInferenceEndpoint: true, + isPreconfigured: !!ep.metadata.display?.name, })); - return [...connectors, ...inferenceEndpointConnectors]; + // Exclude .inference stack connectors that have a corresponding ES inference endpoint, + // since the endpoint representation is preferred (includes native endpoints too). + const endpointInferenceIds = new Set(endpoints.map((ep) => ep.inferenceId)); + const filteredConnectors = connectors.filter( + (c) => + c.type !== InferenceConnectorType.Inference || + !endpointInferenceIds.has(c.config?.inferenceId as string) + ); + + return [...filteredConnectors, ...inferenceEndpointConnectors]; }; const getStackConnectors = async ({ diff --git a/x-pack/platform/plugins/shared/inference/server/util/get_inference_endpoint_by_id.test.ts b/x-pack/platform/plugins/shared/inference/server/util/get_inference_endpoint_by_id.test.ts index 126eb5463a353..93cca80af358c 100644 --- a/x-pack/platform/plugins/shared/inference/server/util/get_inference_endpoint_by_id.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/util/get_inference_endpoint_by_id.test.ts @@ -38,6 +38,7 @@ describe('getInferenceEndpointById', () => { taskType: 'chat_completion', service: 'openai', serviceSettings: { model_id: 'gpt-4o' }, + metadata: {}, }); expect(mockInferenceGet).toHaveBeenCalledWith( diff --git a/x-pack/platform/plugins/shared/inference/server/util/get_inference_endpoint_by_id.ts b/x-pack/platform/plugins/shared/inference/server/util/get_inference_endpoint_by_id.ts index e4c1ce5bd477d..f7c59fb203c19 100644 --- a/x-pack/platform/plugins/shared/inference/server/util/get_inference_endpoint_by_id.ts +++ b/x-pack/platform/plugins/shared/inference/server/util/get_inference_endpoint_by_id.ts @@ -33,5 +33,6 @@ export const getInferenceEndpointById = async ({ taskType: endpoint.task_type, service: endpoint.service, serviceSettings: endpoint.service_settings as Record | undefined, + metadata: 'metadata' in endpoint ? (endpoint.metadata as Record) : {}, }; }; diff --git a/x-pack/platform/plugins/shared/inference/server/util/get_inference_endpoints.test.ts b/x-pack/platform/plugins/shared/inference/server/util/get_inference_endpoints.test.ts index 3d631f503be9a..2c8a60cc774d3 100644 --- a/x-pack/platform/plugins/shared/inference/server/util/get_inference_endpoints.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/util/get_inference_endpoints.test.ts @@ -41,12 +41,14 @@ describe('getInferenceEndpoints', () => { taskType: 'chat_completion', service: 'openai', serviceSettings: { model_id: 'gpt-4o' }, + metadata: {}, }, { inferenceId: 'ep-2', taskType: 'text_embedding', service: 'elasticsearch', serviceSettings: undefined, + metadata: {}, }, ]); }); diff --git a/x-pack/platform/plugins/shared/inference/server/util/get_inference_endpoints.ts b/x-pack/platform/plugins/shared/inference/server/util/get_inference_endpoints.ts index 2e38d4ac0950c..a95352e49b23e 100644 --- a/x-pack/platform/plugins/shared/inference/server/util/get_inference_endpoints.ts +++ b/x-pack/platform/plugins/shared/inference/server/util/get_inference_endpoints.ts @@ -13,6 +13,13 @@ export interface InferenceEndpoint { taskType: InferenceTaskType; service: string; serviceSettings?: Record; + // Fix this typing when ES response is updated to include metadata + metadata: { + display?: { + name?: string; + creator?: string; + }; + }; } /** @@ -33,6 +40,7 @@ export const getInferenceEndpoints = async ({ taskType: ep.task_type, service: ep.service, serviceSettings: ep.service_settings as Record | undefined, + metadata: 'metadata' in ep ? (ep.metadata as Record) : {}, })); if (taskType) { diff --git a/x-pack/platform/plugins/shared/inference/server/util/inference_endpoint_id_cache.ts b/x-pack/platform/plugins/shared/inference/server/util/inference_endpoint_id_cache.ts index b6a40c5918394..8530ccade01fc 100644 --- a/x-pack/platform/plugins/shared/inference/server/util/inference_endpoint_id_cache.ts +++ b/x-pack/platform/plugins/shared/inference/server/util/inference_endpoint_id_cache.ts @@ -39,7 +39,10 @@ export class InferenceEndpointIdCache { void this.updateCacheIfExpired(); // deleted endpoints are very unlikely so safe to refresh lazily without awaiting return true; } - await this.updateCacheIfExpired(); // id not in cache, make sure we have latest data before returning + // Force a refresh on cache miss regardless of TTL, because new endpoints + // may have been created since the last refresh (e.g. after enabling CCM). + this.invalidate(); + await this.updateCacheIfExpired(); return this.knownIds.has(id); } diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/common/utils/get_inference_connector.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/common/utils/get_inference_connector.ts index 97d3f3349675a..a8ce1b0117786 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/common/utils/get_inference_connector.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/common/utils/get_inference_connector.ts @@ -5,12 +5,13 @@ * 2.0. */ -import type { Connector } from '@kbn/actions-plugin/server'; import { getConnectorProvider, getConnectorFamily, getConnectorModel, connectorToInference, + type InferenceConnector as CommonInferenceConnector, + type RawConnector, type InferenceConnectorType, type ModelFamily, type ModelProvider, @@ -26,12 +27,13 @@ export interface InferenceConnector { } export const getInferenceConnectorInfo = ( - connector?: Connector + connector?: CommonInferenceConnector | RawConnector ): InferenceConnector | undefined => { if (!connector) { return; } - const inferenceConnector = connectorToInference(connector); + const inferenceConnector = + 'connectorId' in connector ? connector : connectorToInference(connector); const modelFamily = getConnectorFamily(inferenceConnector); const modelProvider = getConnectorProvider(inferenceConnector); const modelId = getConnectorModel(inferenceConnector); diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/moon.yml b/x-pack/platform/plugins/shared/observability_ai_assistant/moon.yml index 907221580f931..126a0ca874886 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/moon.yml +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/moon.yml @@ -64,7 +64,6 @@ dependsOn: - '@kbn/licensing-types' - '@kbn/llm-tasks-plugin' - '@kbn/product-doc-base-plugin' - - '@kbn/inference-endpoint-plugin' - '@kbn/spaces-utils' - '@kbn/usage-collection-plugin' - '@kbn/ai-agent-confirmation-modal' diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/public/components/connector_selector/connector_selector_base.stories.tsx b/x-pack/platform/plugins/shared/observability_ai_assistant/public/components/connector_selector/connector_selector_base.stories.tsx index b5d8787811957..a7273d6c7f808 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/public/components/connector_selector/connector_selector_base.stories.tsx +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/public/components/connector_selector/connector_selector_base.stories.tsx @@ -6,7 +6,8 @@ */ import React from 'react'; import type { Meta, StoryObj } from '@storybook/react'; -import type { FindActionResult } from '@kbn/actions-plugin/server'; +import type { InferenceConnector } from '@kbn/inference-common'; +import { InferenceConnectorType } from '@kbn/inference-common'; import type { ComponentProps } from 'react'; import { EuiPanel } from '@elastic/eui'; import { ConnectorSelectorBase as Component } from './connector_selector_base'; @@ -31,9 +32,25 @@ export const Loaded: StoryObj = { loading: false, selectedConnector: 'gpt-4', connectors: [ - { id: 'gpt-4', name: 'OpenAI GPT-4' }, - { id: 'gpt-3.5-turbo', name: 'OpenAI GPT-3.5 Turbo' }, - ] as FindActionResult[], + { + connectorId: 'gpt-4', + name: 'OpenAI GPT-4', + type: InferenceConnectorType.OpenAI, + config: {}, + capabilities: {}, + isInferenceEndpoint: false, + isPreconfigured: false, + }, + { + connectorId: 'gpt-3.5-turbo', + name: 'OpenAI GPT-3.5 Turbo', + type: InferenceConnectorType.OpenAI, + config: {}, + capabilities: {}, + isInferenceEndpoint: false, + isPreconfigured: false, + }, + ] as InferenceConnector[], }, render, }; diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/public/components/connector_selector/connector_selector_base.tsx b/x-pack/platform/plugins/shared/observability_ai_assistant/public/components/connector_selector/connector_selector_base.tsx index ecf64b713f3aa..352d2df5cd42a 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/public/components/connector_selector/connector_selector_base.tsx +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/public/components/connector_selector/connector_selector_base.tsx @@ -96,7 +96,7 @@ export function ConnectorSelectorBase(props: ConnectorSelectorBaseProps) { compressed valueOfSelected={props.selectedConnector} options={props.connectors.map((connector) => ({ - value: connector.id, + value: connector.connectorId, inputDisplay: ( diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/public/components/insight/actions_menu.tsx b/x-pack/platform/plugins/shared/observability_ai_assistant/public/components/insight/actions_menu.tsx index 09abe7d923bf2..aa5f8f2ee9e01 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/public/components/insight/actions_menu.tsx +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/public/components/insight/actions_menu.tsx @@ -42,7 +42,11 @@ export function ActionsMenu({ defaultMessage: 'Connector', })}{' '} - {connectors.connectors?.find(({ id }) => id === connectors.selectedConnector)?.name} + { + connectors.connectors?.find( + ({ connectorId }) => connectorId === connectors.selectedConnector + )?.name + } ), diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/public/components/insight/insight_base.stories.tsx b/x-pack/platform/plugins/shared/observability_ai_assistant/public/components/insight/insight_base.stories.tsx index 891b061f34617..584ad60c5fcc4 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/public/components/insight/insight_base.stories.tsx +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/public/components/insight/insight_base.stories.tsx @@ -7,7 +7,8 @@ import React from 'react'; -import type { FindActionResult } from '@kbn/actions-plugin/server'; +import type { InferenceConnector } from '@kbn/inference-common'; +import { InferenceConnectorType } from '@kbn/inference-common'; import { EuiFlexGroup, EuiFlexItem } from '@elastic/eui'; import type { InsightBaseProps } from './insight_base'; import { InsightBase as Component } from './insight_base'; @@ -36,9 +37,25 @@ const defaultProps: InsightBaseProps = { {}, diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/public/hooks/use_genai_connectors.test.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/public/hooks/use_genai_connectors.test.ts index fc09f93e57bbd..5951a7a5ec67d 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/public/hooks/use_genai_connectors.test.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/public/hooks/use_genai_connectors.test.ts @@ -7,14 +7,14 @@ import { renderHook, waitFor } from '@testing-library/react'; import { useGenAIConnectorsWithoutContext } from './use_genai_connectors'; -import type { FindActionResult } from '@kbn/actions-plugin/server'; +import type { InferenceConnector } from '@kbn/inference-common'; +import { InferenceConnectorType } from '@kbn/inference-common'; import useLocalStorage from 'react-use/lib/useLocalStorage'; import type { ObservabilityAIAssistantService } from '../types'; import { GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR, GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR_DEFAULT_ONLY, } from '@kbn/management-settings-ids'; -import { createMockConnectorFindResult } from '@kbn/actions-plugin/server/application/connector/mocks'; // Mock dependencies and data jest.mock('react-use/lib/useLocalStorage', () => jest.fn()); @@ -31,37 +31,34 @@ jest.mock('./use_kibana', () => ({ jest.mock('../../common/utils/get_inference_connector', () => ({ getInferenceConnectorInfo: jest.fn((connector) => connector), })); -const mockConnectors: FindActionResult[] = [ - createMockConnectorFindResult({ - id: 'connector-1', +const mockConnectors: InferenceConnector[] = [ + { + connectorId: 'connector-1', name: 'Connector 1', - actionTypeId: '.gen-ai', + type: InferenceConnectorType.OpenAI, config: {}, - referencedByCount: 0, + capabilities: {}, + isInferenceEndpoint: false, isPreconfigured: false, - isDeprecated: false, - isSystemAction: false, - }), - createMockConnectorFindResult({ - id: 'connector-2', + }, + { + connectorId: 'connector-2', name: 'Connector 2', - actionTypeId: '.gen-ai', + type: InferenceConnectorType.OpenAI, config: {}, - referencedByCount: 0, + capabilities: {}, + isInferenceEndpoint: false, isPreconfigured: false, - isDeprecated: false, - isSystemAction: false, - }), - createMockConnectorFindResult({ - id: 'elastic-llm', + }, + { + connectorId: 'elastic-llm', name: 'Elastic LLM', - actionTypeId: '.inference', + type: InferenceConnectorType.Inference, config: { inferenceId: 'inf-1' }, - referencedByCount: 0, + capabilities: {}, + isInferenceEndpoint: true, isPreconfigured: true, - isDeprecated: false, - isSystemAction: false, - }), + }, ]; const mockAssistant: Partial = { diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/public/hooks/use_genai_connectors.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/public/hooks/use_genai_connectors.ts index fb92e03ac1321..0866d1c221ab4 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/public/hooks/use_genai_connectors.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/public/hooks/use_genai_connectors.ts @@ -6,7 +6,7 @@ */ import { useCallback, useEffect, useMemo, useState } from 'react'; -import type { FindActionResult } from '@kbn/actions-plugin/server'; +import type { InferenceConnector as CommonInferenceConnector } from '@kbn/inference-common'; import useLocalStorage from 'react-use/lib/useLocalStorage'; import { GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR, @@ -24,7 +24,7 @@ import { const NO_DEFAULT_CONNECTOR = 'NO_DEFAULT_CONNECTOR'; export interface UseGenAIConnectorsResult { - connectors?: FindActionResult[]; + connectors?: CommonInferenceConnector[]; selectedConnector?: string; loading: boolean; error?: Error; @@ -44,7 +44,7 @@ export function useGenAIConnectors(): UseGenAIConnectorsResult { export function useGenAIConnectorsWithoutContext( assistant: ObservabilityAIAssistantService ): UseGenAIConnectorsResult { - const [connectors, setConnectors] = useState(undefined); + const [connectors, setConnectors] = useState(undefined); const { services: { uiSettings }, } = useKibana(); @@ -86,16 +86,27 @@ export function useGenAIConnectorsWithoutContext( const fetchConnectors = useCallback(async () => { setLoading(true); try { - let results = await assistant.callApi('GET /internal/observability_ai_assistant/connectors', { - signal: controller.signal, - }); + let results: CommonInferenceConnector[]; if (isConnectorSelectionRestricted) { - const defaultC = results.find((con) => con.id === defaultConnector); - results = defaultC ? [defaultC] : []; + const connector = await assistant.callApi( + 'GET /internal/observability_ai_assistant/connectors/{connectorId}', + { + signal: controller.signal, + params: { path: { connectorId: defaultConnector } }, + } + ); + results = [connector]; + } else { + results = await assistant.callApi('GET /internal/observability_ai_assistant/connectors', { + signal: controller.signal, + }); } setConnectors(results); setLastUsedConnector((connectorId) => { - if (connectorId && results.findIndex((result) => result.id === connectorId) === -1) { + if ( + connectorId && + results.findIndex((result) => result.connectorId === connectorId) === -1 + ) { return ''; } return connectorId; @@ -119,7 +130,7 @@ export function useGenAIConnectorsWithoutContext( }, [controller, fetchConnectors]); const getConnector = (id: string) => { - const connector = connectors?.find((_connector) => _connector.id === id); + const connector = connectors?.find((_connector) => _connector.connectorId === id); return getInferenceConnectorInfo(connector); }; @@ -127,7 +138,7 @@ export function useGenAIConnectorsWithoutContext( connectors, loading, error, - selectedConnector: selectedConnector || connectors?.[0]?.id, + selectedConnector: selectedConnector || connectors?.[0]?.connectorId, selectConnector: (id: string) => { setLastUsedConnector(id); }, diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/public/index.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/public/index.ts index 46f63de941d36..56f5ed3db1abf 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/public/index.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/public/index.ts @@ -113,10 +113,7 @@ export { aiAssistantSearchConnectorIndexPattern, } from '../common/ui_settings/settings_keys'; -export { - getElasticManagedLlmConnector, - INFERENCE_CONNECTOR_ACTION_TYPE_ID, -} from './utils/get_elastic_managed_llm_connector'; +export { INFERENCE_CONNECTOR_ACTION_TYPE_ID } from './utils/get_elastic_managed_llm_connector'; export const elasticAiAssistantImage = elasticAiAssistantImg; diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/public/utils/get_elastic_managed_llm_connector.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/public/utils/get_elastic_managed_llm_connector.ts index 15de5f1d07d3b..6e20aeea411e7 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/public/utils/get_elastic_managed_llm_connector.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/public/utils/get_elastic_managed_llm_connector.ts @@ -18,8 +18,8 @@ export const getElasticManagedLlmConnector = ( return connectors.find( (connector) => - connector.actionTypeId === INFERENCE_CONNECTOR_ACTION_TYPE_ID && + connector.type === INFERENCE_CONNECTOR_ACTION_TYPE_ID && connector.isPreconfigured && - connector.config?.provider === 'elastic' + connector.config?.service === 'elastic' ); }; diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/routes/chat/route.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/routes/chat/route.ts index 017af160e060f..4b0eae83a99c3 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/routes/chat/route.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/routes/chat/route.ts @@ -4,7 +4,8 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import { notImplemented } from '@hapi/boom'; +import { boomify, notImplemented } from '@hapi/boom'; +import { isInferenceError } from '@kbn/inference-common'; import { toBooleanRt } from '@kbn/io-ts-utils'; import * as t from 'io-ts'; import type { Observable } from 'rxjs'; @@ -84,7 +85,7 @@ const chatCompletePublicRt = t.intersection([ async function initializeChatRequest({ context, request, - plugins: { cloud, actions }, + plugins: { cloud, inference }, params: { body: { connectorId, scopes }, }, @@ -92,16 +93,17 @@ async function initializeChatRequest({ }: ObservabilityAIAssistantRouteHandlerResources & { params: { body: { connectorId: string; scopes: AssistantScope[] } }; }) { - await withAssistantSpan('guard_against_invalid_connector', async () => { - const actionsClient = await (await actions.start()).getActionsClientWithRequest(request); - - const connector = await actionsClient.get({ - id: connectorId, - throwIfSystemAction: true, + try { + await withAssistantSpan('guard_against_invalid_connector', async () => { + const inferenceStart = await inference.start(); + return inferenceStart.getConnectorById(connectorId, request); }); - - return connector; - }); + } catch (error) { + if (isInferenceError(error) && error.status) { + throw boomify(error, { statusCode: error.status }); + } + throw error; + } const [client, cloudStart, simulateFunctionCalling] = await Promise.all([ service.getClient({ request, scopes }), diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/server/routes/connectors/route.ts b/x-pack/platform/plugins/shared/observability_ai_assistant/server/routes/connectors/route.ts index 0ec7d5a2ab0d6..d362abd2110ba 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/server/routes/connectors/route.ts +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/server/routes/connectors/route.ts @@ -4,9 +4,8 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import type { FindActionResult } from '@kbn/actions-plugin/server'; -import { InferenceConnectorType, isSupportedConnector } from '@kbn/inference-common'; -import { inferenceEndpointExists } from '@kbn/inference-endpoint-plugin/server/lib/inference_endpoint_exists'; +import * as t from 'io-ts'; +import type { InferenceConnector } from '@kbn/inference-common'; import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route'; const listConnectorsRoute = createObservabilityAIAssistantServerRoute({ @@ -16,48 +15,33 @@ const listConnectorsRoute = createObservabilityAIAssistantServerRoute({ requiredPrivileges: ['ai_assistant'], }, }, - handler: async (resources): Promise => { - const { request, plugins, context } = resources; - const esClient = (await context.core).elasticsearch.client.asInternalUser; - - const actionsClient = await ( - await plugins.actions.start() - ).getActionsClientWithRequest(request); - - const [availableTypes, connectors] = await Promise.all([ - actionsClient - .listTypes({ - includeSystemActionTypes: false, - }) - .then((types) => - types - .filter((type) => type.enabled && type.enabledInLicense && type.enabledInConfig) - .map((type) => type.id) - ), - actionsClient.getAll(), - ]); - const filteredConnectors: typeof connectors = []; - - for (const connector of connectors) { - const hasAllowedType = availableTypes.includes(connector.actionTypeId); - const isSupported = isSupportedConnector(connector); - if (!hasAllowedType || !isSupported) continue; - - if (connector.actionTypeId === InferenceConnectorType.Inference) { - const endpointExists = await inferenceEndpointExists( - esClient, - connector.config?.inferenceId - ); - if (!endpointExists) continue; - } - - filteredConnectors.push(connector); - } + handler: async (resources): Promise => { + const { request, plugins } = resources; + const inferenceStart = await plugins.inference.start(); + return inferenceStart.getConnectorList(request); + }, +}); - return filteredConnectors; +const getConnectorByIdRoute = createObservabilityAIAssistantServerRoute({ + endpoint: 'GET /internal/observability_ai_assistant/connectors/{connectorId}', + params: t.type({ + path: t.type({ + connectorId: t.string, + }), + }), + security: { + authz: { + requiredPrivileges: ['ai_assistant'], + }, + }, + handler: async (resources): Promise => { + const { request, plugins, params } = resources; + const inferenceStart = await plugins.inference.start(); + return inferenceStart.getConnectorById(params.path.connectorId, request); }, }); export const connectorRoutes = { ...listConnectorsRoute, + ...getConnectorByIdRoute, }; diff --git a/x-pack/platform/plugins/shared/observability_ai_assistant/tsconfig.json b/x-pack/platform/plugins/shared/observability_ai_assistant/tsconfig.json index 8cedf44f309b9..a007fda0b001f 100644 --- a/x-pack/platform/plugins/shared/observability_ai_assistant/tsconfig.json +++ b/x-pack/platform/plugins/shared/observability_ai_assistant/tsconfig.json @@ -60,7 +60,6 @@ "@kbn/licensing-types", "@kbn/llm-tasks-plugin", "@kbn/product-doc-base-plugin", - "@kbn/inference-endpoint-plugin", "@kbn/spaces-utils", "@kbn/usage-collection-plugin", "@kbn/ai-agent-confirmation-modal", diff --git a/x-pack/platform/plugins/shared/search_inference_endpoints/server/inference_endpoints.test.ts b/x-pack/platform/plugins/shared/search_inference_endpoints/server/inference_endpoints.test.ts index ef936446a266d..ae89fe8bec04d 100644 --- a/x-pack/platform/plugins/shared/search_inference_endpoints/server/inference_endpoints.test.ts +++ b/x-pack/platform/plugins/shared/search_inference_endpoints/server/inference_endpoints.test.ts @@ -47,6 +47,7 @@ const createExpectedConnector = (id: string): InferenceConnector => ({ serviceSettings: undefined, }, capabilities: {}, + isPreconfigured: false, isInferenceEndpoint: true, }); diff --git a/x-pack/platform/plugins/shared/search_inference_endpoints/server/inference_endpoints.ts b/x-pack/platform/plugins/shared/search_inference_endpoints/server/inference_endpoints.ts index 98988d870b473..3b868ab99b27a 100644 --- a/x-pack/platform/plugins/shared/search_inference_endpoints/server/inference_endpoints.ts +++ b/x-pack/platform/plugins/shared/search_inference_endpoints/server/inference_endpoints.ts @@ -176,6 +176,7 @@ const fetchEndpoints = async ( serviceSettings, }, capabilities: {}, + isPreconfigured: false, isInferenceEndpoint: true, }; endpoints.push(connector); diff --git a/x-pack/platform/plugins/shared/streams/server/plugin.ts b/x-pack/platform/plugins/shared/streams/server/plugin.ts index e2aea4d88349a..0373220a4d22d 100644 --- a/x-pack/platform/plugins/shared/streams/server/plugin.ts +++ b/x-pack/platform/plugins/shared/streams/server/plugin.ts @@ -363,6 +363,7 @@ export class StreamsPlugin this.server.security = plugins.security; this.server.actions = plugins.actions; this.server.encryptedSavedObjects = plugins.encryptedSavedObjects; + this.server.inference = plugins.inference; this.server.taskManager = plugins.taskManager; } diff --git a/x-pack/platform/plugins/shared/streams/server/routes/internal/connectors/route.test.ts b/x-pack/platform/plugins/shared/streams/server/routes/internal/connectors/route.test.ts deleted file mode 100644 index 96f5bbe931f63..0000000000000 --- a/x-pack/platform/plugins/shared/streams/server/routes/internal/connectors/route.test.ts +++ /dev/null @@ -1,213 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import { - filterSupportedConnectors, - INFERENCE_CONNECTOR_TYPE, - type ConnectorWithConfig, -} from './route'; - -describe('filterSupportedConnectors', () => { - const createConnector = ( - overrides: Partial & Pick - ): ConnectorWithConfig => ({ - name: `Connector ${overrides.id}`, - ...overrides, - }); - - describe('connector type filtering', () => { - it('includes supported connector types (OpenAI, Bedrock, Gemini)', async () => { - const connectors = [ - createConnector({ id: 'openai-1', actionTypeId: '.gen-ai' }), - createConnector({ id: 'bedrock-1', actionTypeId: '.bedrock' }), - createConnector({ id: 'gemini-1', actionTypeId: '.gemini' }), - ]; - - const mockCheckEndpoint = jest.fn().mockResolvedValue(true); - const result = await filterSupportedConnectors(connectors, mockCheckEndpoint); - - expect(result).toHaveLength(3); - expect(result.map((c) => c.id)).toEqual(['openai-1', 'bedrock-1', 'gemini-1']); - // Non-inference connectors shouldn't trigger endpoint check - expect(mockCheckEndpoint).not.toHaveBeenCalled(); - }); - - it('filters out unsupported connector types', async () => { - const connectors = [ - createConnector({ id: 'openai-1', actionTypeId: '.gen-ai' }), - createConnector({ id: 'slack-1', actionTypeId: '.slack' }), - createConnector({ id: 'email-1', actionTypeId: '.email' }), - createConnector({ id: 'webhook-1', actionTypeId: '.webhook' }), - ]; - - const mockCheckEndpoint = jest.fn().mockResolvedValue(true); - const result = await filterSupportedConnectors(connectors, mockCheckEndpoint); - - expect(result).toHaveLength(1); - expect(result[0].id).toBe('openai-1'); - }); - }); - - describe('inference connector filtering', () => { - it('includes .inference connectors with taskType chat_completion', async () => { - const connectors = [ - createConnector({ - id: 'inference-1', - actionTypeId: INFERENCE_CONNECTOR_TYPE, - config: { taskType: 'chat_completion', inferenceId: 'my-endpoint' }, - }), - ]; - - const mockCheckEndpoint = jest.fn().mockResolvedValue(true); - const result = await filterSupportedConnectors(connectors, mockCheckEndpoint); - - expect(result).toHaveLength(1); - expect(result[0].id).toBe('inference-1'); - }); - - it('filters out .inference connectors with wrong taskType', async () => { - const connectors = [ - createConnector({ - id: 'inference-embed', - actionTypeId: INFERENCE_CONNECTOR_TYPE, - config: { taskType: 'text_embedding', inferenceId: 'embed-endpoint' }, - }), - createConnector({ - id: 'inference-sparse', - actionTypeId: INFERENCE_CONNECTOR_TYPE, - config: { taskType: 'sparse_embedding', inferenceId: 'sparse-endpoint' }, - }), - createConnector({ - id: 'inference-chat', - actionTypeId: INFERENCE_CONNECTOR_TYPE, - config: { taskType: 'chat_completion', inferenceId: 'chat-endpoint' }, - }), - ]; - - const mockCheckEndpoint = jest.fn().mockResolvedValue(true); - const result = await filterSupportedConnectors(connectors, mockCheckEndpoint); - - // Only chat_completion taskType should pass - expect(result).toHaveLength(1); - expect(result[0].id).toBe('inference-chat'); - }); - - it('filters out .inference connectors without taskType', async () => { - const connectors = [ - createConnector({ - id: 'inference-no-task', - actionTypeId: INFERENCE_CONNECTOR_TYPE, - config: { inferenceId: 'some-endpoint' }, - }), - ]; - - const mockCheckEndpoint = jest.fn().mockResolvedValue(true); - const result = await filterSupportedConnectors(connectors, mockCheckEndpoint); - - expect(result).toHaveLength(0); - }); - }); - - describe('inference endpoint validation', () => { - it('filters out .inference connectors when endpoint does not exist', async () => { - const connectors = [ - createConnector({ - id: 'inference-valid', - actionTypeId: INFERENCE_CONNECTOR_TYPE, - config: { taskType: 'chat_completion', inferenceId: 'valid-endpoint' }, - }), - createConnector({ - id: 'inference-invalid', - actionTypeId: INFERENCE_CONNECTOR_TYPE, - config: { taskType: 'chat_completion', inferenceId: 'invalid-endpoint' }, - }), - ]; - - const mockCheckEndpoint = jest.fn().mockImplementation((inferenceId: string) => { - return Promise.resolve(inferenceId === 'valid-endpoint'); - }); - - const result = await filterSupportedConnectors(connectors, mockCheckEndpoint); - - expect(result).toHaveLength(1); - expect(result[0].id).toBe('inference-valid'); - expect(mockCheckEndpoint).toHaveBeenCalledWith('valid-endpoint'); - expect(mockCheckEndpoint).toHaveBeenCalledWith('invalid-endpoint'); - }); - - it('includes .inference connectors without inferenceId in config', async () => { - const connectors = [ - createConnector({ - id: 'inference-no-id', - actionTypeId: INFERENCE_CONNECTOR_TYPE, - config: { taskType: 'chat_completion' }, - }), - ]; - - const mockCheckEndpoint = jest.fn().mockResolvedValue(false); - const result = await filterSupportedConnectors(connectors, mockCheckEndpoint); - - // Connector passes through because there's no inferenceId to validate - expect(result).toHaveLength(1); - expect(result[0].id).toBe('inference-no-id'); - expect(mockCheckEndpoint).not.toHaveBeenCalled(); - }); - }); - - describe('mixed connector scenarios', () => { - it('correctly filters a mix of supported and unsupported connectors', async () => { - const connectors = [ - createConnector({ id: 'openai-1', actionTypeId: '.gen-ai' }), - createConnector({ id: 'slack-1', actionTypeId: '.slack' }), - createConnector({ id: 'bedrock-1', actionTypeId: '.bedrock' }), - createConnector({ - id: 'inference-chat', - actionTypeId: INFERENCE_CONNECTOR_TYPE, - config: { taskType: 'chat_completion', inferenceId: 'chat-endpoint' }, - }), - createConnector({ - id: 'inference-embed', - actionTypeId: INFERENCE_CONNECTOR_TYPE, - config: { taskType: 'text_embedding', inferenceId: 'embed-endpoint' }, - }), - createConnector({ id: 'webhook-1', actionTypeId: '.webhook' }), - createConnector({ id: 'gemini-1', actionTypeId: '.gemini' }), - ]; - - const mockCheckEndpoint = jest.fn().mockResolvedValue(true); - const result = await filterSupportedConnectors(connectors, mockCheckEndpoint); - - expect(result).toHaveLength(4); - expect(result.map((c) => c.id)).toEqual([ - 'openai-1', - 'bedrock-1', - 'inference-chat', - 'gemini-1', - ]); - }); - - it('returns empty array when no connectors are supported', async () => { - const connectors = [ - createConnector({ id: 'slack-1', actionTypeId: '.slack' }), - createConnector({ id: 'email-1', actionTypeId: '.email' }), - ]; - - const mockCheckEndpoint = jest.fn().mockResolvedValue(true); - const result = await filterSupportedConnectors(connectors, mockCheckEndpoint); - - expect(result).toHaveLength(0); - }); - - it('returns empty array when given empty input', async () => { - const mockCheckEndpoint = jest.fn().mockResolvedValue(true); - const result = await filterSupportedConnectors([], mockCheckEndpoint); - - expect(result).toHaveLength(0); - expect(mockCheckEndpoint).not.toHaveBeenCalled(); - }); - }); -}); diff --git a/x-pack/platform/plugins/shared/streams/server/routes/internal/connectors/route.ts b/x-pack/platform/plugins/shared/streams/server/routes/internal/connectors/route.ts index 97795710d54aa..dd255fd3d196b 100644 --- a/x-pack/platform/plugins/shared/streams/server/routes/internal/connectors/route.ts +++ b/x-pack/platform/plugins/shared/streams/server/routes/internal/connectors/route.ts @@ -5,72 +5,10 @@ * 2.0. */ -import type { ElasticsearchClient } from '@kbn/core/server'; -import { isSupportedConnector, type InferenceConnector } from '@kbn/inference-common'; +import { z } from '@kbn/zod/v4'; import { STREAMS_API_PRIVILEGES } from '../../../../common/constants'; import { createServerRoute } from '../../create_server_route'; -export const INFERENCE_CONNECTOR_TYPE = '.inference'; - -/** - * Minimal connector interface for filtering. Compatible with the connector - * type returned by the actions plugin. - */ -export interface ConnectorWithConfig { - id: string; - actionTypeId: string; - name: string; - config?: Record; -} - -export async function inferenceEndpointExists( - esClient: ElasticsearchClient, - inferenceId: string -): Promise { - try { - const endpoints = await esClient.inference.get({ inference_id: inferenceId }); - return endpoints.endpoints.some((endpoint) => endpoint.inference_id === inferenceId); - } catch (error) { - return false; - } -} - -/** - * Filters connectors to only include supported GenAI connectors. - * For .inference connectors, also validates that the inference endpoint exists. - * - * @param connectors - List of all connectors to filter - * @param checkInferenceEndpointExists - Function to check if an inference endpoint exists - * @returns List of supported connectors with validated inference endpoints - */ -export async function filterSupportedConnectors( - connectors: T[], - checkInferenceEndpointExists: (inferenceId: string) => Promise -): Promise { - // Filter to only supported GenAI connector types - // Uses isSupportedConnector which also validates .inference connectors have taskType: 'chat_completion' - const supportedConnectors = connectors.filter((connector) => isSupportedConnector(connector)); - - // Validate inference connectors have endpoints - const validatedConnectors = await Promise.all( - supportedConnectors.map(async (connector) => { - if (connector.actionTypeId === INFERENCE_CONNECTOR_TYPE) { - const inferenceId = (connector.config as InferenceConnector['config'])?.inferenceId; - if (inferenceId) { - const exists = await checkInferenceEndpointExists(inferenceId); - if (!exists) { - return null; - } - } - } - return connector; - }) - ); - - // Type assertion is safe here - we're only filtering out nulls, the remaining values are T - return validatedConnectors.filter((connector) => connector !== null) as T[]; -} - export const getConnectorsRoute = createServerRoute({ endpoint: 'GET /internal/streams/connectors', options: { @@ -83,28 +21,34 @@ export const getConnectorsRoute = createServerRoute({ requiredPrivileges: [STREAMS_API_PRIVILEGES.read], }, }, - handler: async ({ request, getScopedClients, server }) => { - const { scopedClusterClient } = await getScopedClients({ request }); - - // Get actions client with request - const actionsClient = await server.actions.getActionsClientWithRequest(request); - - if (!actionsClient) { - throw new Error('Actions client not available'); - } - - const connectors = await actionsClient.getAll(); + handler: async ({ request, server }) => { + const connectors = await server.inference.getConnectorList(request); - const filteredConnectors = await filterSupportedConnectors(connectors, (inferenceId) => - inferenceEndpointExists(scopedClusterClient.asCurrentUser, inferenceId) - ); + return { connectors }; + }, +}); - return { - connectors: filteredConnectors, - }; +export const getConnectorByIdRoute = createServerRoute({ + endpoint: 'GET /internal/streams/connectors/{connectorId}', + options: { + access: 'internal', + summary: 'Get a GenAI connector by ID', + description: 'Fetches a single GenAI connector by its ID', + }, + security: { + authz: { + requiredPrivileges: [STREAMS_API_PRIVILEGES.read], + }, + }, + params: z.object({ + path: z.object({ connectorId: z.string() }), + }), + handler: async ({ request, params, server }) => { + return server.inference.getConnectorById(params.path.connectorId, request); }, }); export const connectorRoutes = { ...getConnectorsRoute, + ...getConnectorByIdRoute, }; diff --git a/x-pack/platform/plugins/shared/streams/server/types.ts b/x-pack/platform/plugins/shared/streams/server/types.ts index 7fd55f54be98a..a94cb6131feb3 100644 --- a/x-pack/platform/plugins/shared/streams/server/types.ts +++ b/x-pack/platform/plugins/shared/streams/server/types.ts @@ -40,6 +40,7 @@ export interface StreamsServer { security: SecurityPluginStart; actions: ActionsPluginStart; encryptedSavedObjects: EncryptedSavedObjectsPluginStart; + inference: InferenceServerStart; isServerless: boolean; taskManager: TaskManagerStartContract; } diff --git a/x-pack/platform/plugins/shared/streams_app/public/components/connector_list_button/connector_list_button.tsx b/x-pack/platform/plugins/shared/streams_app/public/components/connector_list_button/connector_list_button.tsx index 622877777eb3d..81ce5eb3dea44 100644 --- a/x-pack/platform/plugins/shared/streams_app/public/components/connector_list_button/connector_list_button.tsx +++ b/x-pack/platform/plugins/shared/streams_app/public/components/connector_list_button/connector_list_button.tsx @@ -92,10 +92,12 @@ export function ConnectorListButtonBase({ size="s" items={connectorsResult.connectors.map((connector) => ( { - connectorsResult.selectConnector(connector.id); + connectorsResult.selectConnector(connector.connectorId); closePopover(); }} > diff --git a/x-pack/platform/plugins/shared/streams_app/public/components/data_management/stream_detail_routing/review_suggestions_form/generate_suggestions_button.test.tsx b/x-pack/platform/plugins/shared/streams_app/public/components/data_management/stream_detail_routing/review_suggestions_form/generate_suggestions_button.test.tsx index 7ca85a94e0466..e19f830461408 100644 --- a/x-pack/platform/plugins/shared/streams_app/public/components/data_management/stream_detail_routing/review_suggestions_form/generate_suggestions_button.test.tsx +++ b/x-pack/platform/plugins/shared/streams_app/public/components/data_management/stream_detail_routing/review_suggestions_form/generate_suggestions_button.test.tsx @@ -14,7 +14,8 @@ import { } from './generate_suggestions_button'; import { AdditionalChargesCallout } from '../../shared/additional_charges_callout'; import type { AIFeatures } from '../../../../hooks/use_ai_features'; -import type { UseGenAIConnectorsResult, Connector } from '../../../../hooks/use_genai_connectors'; +import type { UseGenAIConnectorsResult } from '../../../../hooks/use_genai_connectors'; +import { InferenceConnectorType } from '@kbn/inference-common'; jest.mock('../../../../hooks/use_kibana', () => ({ useKibana: () => ({ @@ -36,10 +37,14 @@ jest.mock('../../../../hooks/use_kibana', () => ({ }), })); -const createMockConnector = (id: string, name: string): Connector => ({ - id, +const createMockConnector = (connectorId: string, name: string) => ({ + connectorId, name, - actionTypeId: '.gen-ai', + type: InferenceConnectorType.OpenAI, + config: {}, + capabilities: {}, + isPreconfigured: false, + isInferenceEndpoint: false, }); const createMockGenAiConnectors = ( diff --git a/x-pack/platform/plugins/shared/streams_app/public/components/significant_events_discovery/components/settings/tab.tsx b/x-pack/platform/plugins/shared/streams_app/public/components/significant_events_discovery/components/settings/tab.tsx index 1807e9d860d3e..0fb15103835de 100644 --- a/x-pack/platform/plugins/shared/streams_app/public/components/significant_events_discovery/components/settings/tab.tsx +++ b/x-pack/platform/plugins/shared/streams_app/public/components/significant_events_discovery/components/settings/tab.tsx @@ -54,6 +54,17 @@ export function SettingsTab() { uiSettings: core.uiSettings, }); + const defaultConnectorFetch = useStreamsAppFetch( + async ({ signal }) => { + if (!genAiConnectors.defaultConnector) return undefined; + return streams.streamsRepositoryClient.fetch( + 'GET /internal/streams/connectors/{connectorId}', + { signal, params: { path: { connectorId: genAiConnectors.defaultConnector } } } + ); + }, + [streams.streamsRepositoryClient, genAiConnectors.defaultConnector] + ); + const settingsFetch = useStreamsAppFetch( async ({ signal }) => streams.streamsRepositoryClient.fetch('GET /internal/streams/_significant_events/settings', { @@ -120,7 +131,7 @@ export function SettingsTab() { defaultMessage: 'Use default (genAiSettings:defaultAIConnector)', }), }, - ...(genAiConnectors.connectors ?? []).map((c) => ({ value: c.id, text: c.name })), + ...(genAiConnectors.connectors ?? []).map((c) => ({ value: c.connectorId, text: c.name })), ]; if (settingsFetch.loading && !settingsFetch.value) { @@ -134,9 +145,7 @@ export function SettingsTab() { discovery === NOT_SET_VALUE; const showNoDefaultCallout = !genAiConnectors.loading && !hasDefaultConnector && anyUsesDefault; const defaultConnectorName = - hasDefaultConnector && anyUsesDefault - ? genAiConnectors.connectors?.find((c) => c.id === genAiConnectors.defaultConnector)?.name - : undefined; + hasDefaultConnector && anyUsesDefault ? defaultConnectorFetch.value?.name : undefined; return ( diff --git a/x-pack/platform/plugins/shared/streams_app/public/hooks/use_ai_features.tsx b/x-pack/platform/plugins/shared/streams_app/public/hooks/use_ai_features.tsx index 7eeb407a96031..a802bd508d355 100644 --- a/x-pack/platform/plugins/shared/streams_app/public/hooks/use_ai_features.tsx +++ b/x-pack/platform/plugins/shared/streams_app/public/hooks/use_ai_features.tsx @@ -14,7 +14,6 @@ import { import { STREAMS_TIERED_AI_FEATURE } from '@kbn/streams-plugin/common'; import { useKibana } from './use_kibana'; import { useGenAIConnectors, type UseGenAIConnectorsResult } from './use_genai_connectors'; -import { getElasticManagedLlmConnector } from '../utils/get_elastic_managed_llm_connector'; export interface AIFeatures { loading: boolean; @@ -61,8 +60,6 @@ export function useAIFeatures(): AIFeatures | null { }; } - const elasticManagedLlmConnector = getElasticManagedLlmConnector(genAiConnectors.connectors); - // Check for actions.show permission (read access is sufficient for listing connectors) const hasActionsPermission = core.application.capabilities.actions?.show || false; @@ -74,9 +71,10 @@ export function useAIFeatures(): AIFeatures | null { const couldBeEnabled = Boolean( license?.hasAtLeast('enterprise') && core.application.capabilities.actions?.show ); - const isManagedAIConnector = elasticManagedLlmConnector - ? elasticManagedLlmConnector.id === genAiConnectors.selectedConnector - : false; + const selectedConnector = (genAiConnectors.connectors || []).find( + (connector) => connector.connectorId === genAiConnectors.selectedConnector + ); + const isManagedAIConnector = selectedConnector?.isPreconfigured || false; return { loading: false, diff --git a/x-pack/platform/plugins/shared/streams_app/public/hooks/use_genai_connectors.test.ts b/x-pack/platform/plugins/shared/streams_app/public/hooks/use_genai_connectors.test.ts index 39378fa018af8..c8ce2d33cef4e 100644 --- a/x-pack/platform/plugins/shared/streams_app/public/hooks/use_genai_connectors.test.ts +++ b/x-pack/platform/plugins/shared/streams_app/public/hooks/use_genai_connectors.test.ts @@ -6,9 +6,10 @@ */ import { renderHook, act, waitFor } from '@testing-library/react'; -import { useGenAIConnectors, type Connector } from './use_genai_connectors'; +import { useGenAIConnectors } from './use_genai_connectors'; import type { StreamsRepositoryClient } from '@kbn/streams-plugin/public/api'; import type { IUiSettingsClient } from '@kbn/core/public'; +import { InferenceConnectorType } from '@kbn/inference-common'; import { GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR, GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR_DEFAULT_ONLY, @@ -17,10 +18,14 @@ import { const STREAMS_CONNECTOR_STORAGE_KEY = 'xpack.streamsApp.lastUsedConnector'; const OLD_STORAGE_KEY = 'xpack.observabilityAiAssistant.lastUsedConnector'; -const createMockConnector = (id: string, name: string): Connector => ({ - id, +const createMockConnector = (connectorId: string, name: string) => ({ + connectorId, name, - actionTypeId: '.gen-ai', + type: InferenceConnectorType.OpenAI, + config: {}, + capabilities: {}, + isPreconfigured: false, + isInferenceEndpoint: false, }); describe('useGenAIConnectors', () => { diff --git a/x-pack/platform/plugins/shared/streams_app/public/hooks/use_genai_connectors.ts b/x-pack/platform/plugins/shared/streams_app/public/hooks/use_genai_connectors.ts index dea472870aded..88868fa9f60b3 100644 --- a/x-pack/platform/plugins/shared/streams_app/public/hooks/use_genai_connectors.ts +++ b/x-pack/platform/plugins/shared/streams_app/public/hooks/use_genai_connectors.ts @@ -9,6 +9,7 @@ import { useState, useEffect, useCallback, useMemo } from 'react'; import useLocalStorage from 'react-use/lib/useLocalStorage'; import type { IUiSettingsClient } from '@kbn/core/public'; import type { StreamsRepositoryClient } from '@kbn/streams-plugin/public/api'; +import type { InferenceConnector } from '@kbn/inference-common'; import { GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR, GEN_AI_SETTINGS_DEFAULT_AI_CONNECTOR_DEFAULT_ONLY, @@ -19,20 +20,8 @@ const OLD_STORAGE_KEY = 'xpack.observabilityAiAssistant.lastUsedConnector'; // TODO: Import from gen-ai-settings-plugin (package) once available const NO_DEFAULT_CONNECTOR = 'NO_DEFAULT_CONNECTOR'; -export interface Connector { - id: string; - name: string; - actionTypeId: string; - config?: Record; - isPreconfigured?: boolean; - isDeprecated?: boolean; - isSystemAction?: boolean; - isMissingSecrets?: boolean; - referencedByCount?: number; -} - export interface UseGenAIConnectorsResult { - connectors: Connector[] | undefined; + connectors: InferenceConnector[] | undefined; selectedConnector: string | undefined; loading: boolean; error: Error | undefined; @@ -49,7 +38,7 @@ export function useGenAIConnectors({ streamsRepositoryClient: StreamsRepositoryClient; uiSettings: IUiSettingsClient; }): UseGenAIConnectorsResult { - const [connectors, setConnectors] = useState(); + const [connectors, setConnectors] = useState(); const [loading, setLoading] = useState(true); const [error, setError] = useState(); @@ -85,7 +74,7 @@ export function useGenAIConnectors({ // If connector selection is restricted, only return the default connector if (isConnectorSelectionRestricted) { - const defaultC = results.find((con) => con.id === defaultConnector); + const defaultC = results.find((con) => con.connectorId === defaultConnector); results = defaultC ? [defaultC] : []; } @@ -93,7 +82,10 @@ export function useGenAIConnectors({ // Clear lastUsedConnector if it's no longer in the list setLastUsedConnector((connectorId) => { - if (connectorId && results.findIndex((result) => result.id === connectorId) === -1) { + if ( + connectorId && + results.findIndex((result) => result.connectorId === connectorId) === -1 + ) { return undefined; } return connectorId; @@ -148,7 +140,7 @@ export function useGenAIConnectors({ // If the selected connector is no longer available, select the first available connector useEffect(() => { - const availableConnectors = connectors?.map((connector) => connector.id); + const availableConnectors = connectors?.map((connector) => connector.connectorId); if ( selectedConnector && @@ -161,7 +153,7 @@ export function useGenAIConnectors({ return { connectors, - selectedConnector: selectedConnector || connectors?.[0]?.id, + selectedConnector: selectedConnector || connectors?.[0]?.connectorId, loading, error, selectConnector, diff --git a/x-pack/platform/plugins/shared/streams_app/public/utils/get_elastic_managed_llm_connector.ts b/x-pack/platform/plugins/shared/streams_app/public/utils/get_elastic_managed_llm_connector.ts deleted file mode 100644 index 9756bfd44a890..0000000000000 --- a/x-pack/platform/plugins/shared/streams_app/public/utils/get_elastic_managed_llm_connector.ts +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import type { Connector } from '../hooks/use_genai_connectors'; - -export const INFERENCE_CONNECTOR_ACTION_TYPE_ID = '.inference'; - -export const getElasticManagedLlmConnector = (connectors: Connector[] | undefined) => { - if (!Array.isArray(connectors) || connectors.length === 0) { - return undefined; - } - - return connectors.find( - (connector) => - connector.actionTypeId === INFERENCE_CONNECTOR_ACTION_TYPE_ID && - connector.isPreconfigured && - (connector.config as { provider?: string })?.provider === 'elastic' - ); -}; diff --git a/x-pack/solutions/observability/plugins/observability_agent_builder/public/hooks/use_genai_connectors.test.ts b/x-pack/solutions/observability/plugins/observability_agent_builder/public/hooks/use_genai_connectors.test.ts index 408ed384f7192..2bd027d769946 100644 --- a/x-pack/solutions/observability/plugins/observability_agent_builder/public/hooks/use_genai_connectors.test.ts +++ b/x-pack/solutions/observability/plugins/observability_agent_builder/public/hooks/use_genai_connectors.test.ts @@ -23,6 +23,7 @@ const mockConnectors: InferenceConnector[] = [ type: InferenceConnectorType.OpenAI, config: {}, capabilities: {}, + isPreconfigured: false, isInferenceEndpoint: false, }, { @@ -31,6 +32,7 @@ const mockConnectors: InferenceConnector[] = [ type: InferenceConnectorType.Bedrock, config: {}, capabilities: {}, + isPreconfigured: false, isInferenceEndpoint: false, }, ]; diff --git a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/public/rule_connector/ai_assistant_params.tsx b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/public/rule_connector/ai_assistant_params.tsx index 678265105158f..28fa582427f6b 100644 --- a/x-pack/solutions/observability/plugins/observability_ai_assistant_app/public/rule_connector/ai_assistant_params.tsx +++ b/x-pack/solutions/observability/plugins/observability_ai_assistant_app/public/rule_connector/ai_assistant_params.tsx @@ -108,7 +108,7 @@ const ObsAIAssistantParamsFields: React.FunctionComponent< // @ts-expect-error upgrade typescript v5.1.6 isInvalid={errors.connector?.length > 0} options={connectors?.map((connector) => { - return { value: connector.id, text: connector.name }; + return { value: connector.connectorId, text: connector.name }; })} onChange={(event) => { selectConnector(event.target.value); diff --git a/x-pack/solutions/observability/test/api_integration_deployment_agnostic/apis/ai_assistant/connectors/connectors.spec.ts b/x-pack/solutions/observability/test/api_integration_deployment_agnostic/apis/ai_assistant/connectors/connectors.spec.ts index e100fbc5c4adb..a8367c1d28b6a 100644 --- a/x-pack/solutions/observability/test/api_integration_deployment_agnostic/apis/ai_assistant/connectors/connectors.spec.ts +++ b/x-pack/solutions/observability/test/api_integration_deployment_agnostic/apis/ai_assistant/connectors/connectors.spec.ts @@ -28,14 +28,12 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon expect(status).to.be(200); }); - it('returns an empty list of connectors', async () => { + it('returns only preconfigured connectors', async () => { const res = await observabilityAIAssistantAPIClient.editor({ endpoint: 'GET /internal/observability_ai_assistant/connectors', }); - const connectorsExcludingPreconfiguredInference = res.body.filter( - (c) => c.actionTypeId !== '.inference' - ); + const connectorsExcludingPreconfiguredInference = res.body.filter((c) => !c.isPreconfigured); expect(connectorsExcludingPreconfiguredInference.length).to.be(0); }); @@ -48,9 +46,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon endpoint: 'GET /internal/observability_ai_assistant/connectors', }); - const connectorsExcludingPreconfiguredInference = res.body.filter( - (c) => c.actionTypeId !== '.inference' - ); + const connectorsExcludingPreconfiguredInference = res.body.filter((c) => !c.isPreconfigured); expect(connectorsExcludingPreconfiguredInference.length).to.be(1); await observabilityAIAssistantAPIClient.deleteActionConnector({ actionId: connectorId }); diff --git a/x-pack/solutions/observability/test/observability_ai_assistant_functional/tests/feature_controls/assistant_security.spec.ts b/x-pack/solutions/observability/test/observability_ai_assistant_functional/tests/feature_controls/assistant_security.spec.ts index d0cf9a857523c..30762a71b44b3 100644 --- a/x-pack/solutions/observability/test/observability_ai_assistant_functional/tests/feature_controls/assistant_security.spec.ts +++ b/x-pack/solutions/observability/test/observability_ai_assistant_functional/tests/feature_controls/assistant_security.spec.ts @@ -99,13 +99,13 @@ export default function ({ getPageObjects, getService }: FtrProviderContext) { observabilityAIAssistant: ['all'], }); }); - it('loads conversations UI with connector error message', async () => { + it('loads conversations UI with setup connector message', async () => { await PageObjects.common.navigateToUrl('obsAIAssistant', '', { ensureCurrentUrl: false, shouldLoginIfPrompted: false, shouldUseHashForSubUrl: false, }); - await testSubjects.existOrFail(ui.pages.conversations.connectorsErrorMsg); + await testSubjects.existOrFail(ui.pages.conversations.setupGenAiConnectorsButtonSelector); }); after(async () => { await deleteAndLogoutUser(getService, getPageObjects); diff --git a/x-pack/solutions/security/packages/security-ai-prompts/moon.yml b/x-pack/solutions/security/packages/security-ai-prompts/moon.yml index ea50c2fb4907e..276f618a56288 100644 --- a/x-pack/solutions/security/packages/security-ai-prompts/moon.yml +++ b/x-pack/solutions/security/packages/security-ai-prompts/moon.yml @@ -17,11 +17,9 @@ project: owner: '@elastic/security-generative-ai' sourceRoot: x-pack/solutions/security/packages/security-ai-prompts dependsOn: - - '@kbn/core' - - '@kbn/actions-plugin' - '@kbn/core-saved-objects-api-server' - - '@kbn/utility-types' - '@kbn/inference-common' + - '@kbn/core' tags: - shared-server - package diff --git a/x-pack/solutions/security/packages/security-ai-prompts/src/get_prompt.test.ts b/x-pack/solutions/security/packages/security-ai-prompts/src/get_prompt.test.ts index 46ab1caba833c..b504bc351a328 100644 --- a/x-pack/solutions/security/packages/security-ai-prompts/src/get_prompt.test.ts +++ b/x-pack/solutions/security/packages/security-ai-prompts/src/get_prompt.test.ts @@ -7,20 +7,98 @@ import { getPrompt, getPromptsByGroupId } from './get_prompt'; import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server'; -import type { ActionsClient } from '@kbn/actions-plugin/server'; import { localPrompts, promptDictionary, promptGroupId } from './mock_prompts'; -import { createMockConnector } from '@kbn/actions-plugin/server/application/connector/mocks'; jest.mock('@kbn/core-saved-objects-api-server'); -jest.mock('@kbn/actions-plugin/server'); -const defaultConnector = createMockConnector({ - id: 'mock', - name: 'Mock', - actionTypeId: '.inference', -}); + +const bedrockConnector = { + type: '.bedrock' as const, + name: 'Bedrock', + connectorId: 'connector-123', + config: { defaultModel: 'us.anthropic.claude-3-5-sonnet-20240620-v1:0' }, + capabilities: {}, + isInferenceEndpoint: false, + isPreconfigured: false, +}; + +const openaiConnector = { + type: '.gen-ai' as const, + name: 'OpenAI', + connectorId: 'connector-123', + config: { defaultModel: 'gpt-4o' }, + capabilities: {}, + isInferenceEndpoint: false, + isPreconfigured: false, +}; + +const inferenceBedrockConnector = { + type: '.inference' as const, + name: 'Inference Bedrock', + connectorId: 'connector-123', + config: { + provider: 'bedrock', + providerConfig: { model_id: 'us.anthropic.claude-3-5-sonnet-20240620-v1:0' }, + }, + capabilities: {}, + isInferenceEndpoint: false, + isPreconfigured: false, +}; + +const inferenceElasticConnector = { + type: '.inference' as const, + name: 'Inference Elastic', + connectorId: 'connector-123', + config: { provider: 'elastic', providerConfig: { model_id: 'rainbow-sprinkles' } }, + capabilities: {}, + isInferenceEndpoint: false, + isPreconfigured: false, +}; + +const inferenceElasticUnknownConnector = { + type: '.inference' as const, + name: 'Inference Elastic Unknown', + connectorId: 'connector-123', + config: { provider: 'elastic', providerConfig: { model_id: 'unknown-model' } }, + capabilities: {}, + isInferenceEndpoint: false, + isPreconfigured: false, +}; + +const inferenceEndpointAmazonBedrock = { + type: '.inference' as const, + name: 'my-endpoint', + connectorId: 'my-bedrock-endpoint', + config: { + service: 'amazonbedrock', + providerConfig: { model_id: 'us.anthropic.claude-3-5-sonnet-20240620-v1:0' }, + }, + capabilities: {}, + isInferenceEndpoint: true, + isPreconfigured: false, +}; + +const geminiConnector = { + type: '.gemini' as const, + name: 'Gemini', + connectorId: 'connector-123', + config: { defaultModel: 'gemini-1.5-pro-002' }, + capabilities: {}, + isInferenceEndpoint: false, + isPreconfigured: false, +}; + +const inferenceElasticConnectorRainbow = { + type: '.inference' as const, + name: 'Inference Elastic Rainbow', + connectorId: 'connector-123', + config: { provider: 'elastic', providerConfig: { model_id: 'rainbow-sprinkles' } }, + capabilities: {}, + isInferenceEndpoint: false, + isPreconfigured: false, +}; + describe('get_prompt', () => { let savedObjectsClient: jest.Mocked; - let actionsClient: jest.Mocked; beforeEach(() => { jest.clearAllMocks(); @@ -167,18 +245,11 @@ describe('get_prompt', () => { ], }), } as unknown as jest.Mocked; - - actionsClient = { - get: jest.fn().mockResolvedValue({ - config: { - provider: 'openai', - providerConfig: { model_id: 'gpt-4o' }, - }, - }), - } as unknown as jest.Mocked; }); + describe('getPrompt', () => { - it('returns the prompt matching provider and model', async () => { + it('returns the prompt matching provider and model (no connector lookup needed)', async () => { + const getInferenceConnectorById = jest.fn(); const result = await getPrompt({ savedObjectsClient, localPrompts, @@ -186,11 +257,10 @@ describe('get_prompt', () => { promptGroupId: promptGroupId.aiAssistant, provider: 'openai', model: 'gpt-4o', - actionsClient, + getInferenceConnectorById, connectorId: 'connector-123', }); - expect(actionsClient.get).not.toHaveBeenCalled(); - + expect(getInferenceConnectorById).not.toHaveBeenCalled(); expect(result).toBe('Hello world this is a system prompt'); }); @@ -202,46 +272,38 @@ describe('get_prompt', () => { promptGroupId: promptGroupId.aiAssistant, provider: 'openai', model: 'gpt-4o-mini', - actionsClient, - connectorId: 'connector-123', }); - expect(actionsClient.get).not.toHaveBeenCalled(); - expect(result).toBe('Hello world this is a system prompt no model'); }); - it('returns the prompt matching provider when model is not provided', async () => { + it('calls getInferenceConnectorById when only provider is given', async () => { + const getInferenceConnectorById = jest.fn().mockResolvedValue(openaiConnector); const result = await getPrompt({ savedObjectsClient, localPrompts, promptId: promptDictionary.systemPrompt, promptGroupId: promptGroupId.aiAssistant, provider: 'openai', - actionsClient, + getInferenceConnectorById, connectorId: 'connector-123', }); - expect(actionsClient.get).toHaveBeenCalled(); - - expect(result).toBe('Hello world this is a system prompt no model'); + expect(getInferenceConnectorById).toHaveBeenCalledWith('connector-123'); + expect(result).toBe('Hello world this is a system prompt'); }); - it('returns the default prompt when there is no match on provider', async () => { + it('returns the default prompt when provider has no match', async () => { const result = await getPrompt({ savedObjectsClient, localPrompts, promptId: promptDictionary.systemPrompt, promptGroupId: promptGroupId.aiAssistant, provider: 'badone', - actionsClient, - connectorId: 'connector-123', }); - expect(result).toBe('Hello world this is a system prompt no model, no provider'); }); - it('defaults provider to bedrock if provider is "inference"', async () => { - actionsClient.get.mockResolvedValue(defaultConnector); - + it('resolves the real provider when provider is "inference" via getInferenceConnectorById', async () => { + const getInferenceConnectorById = jest.fn().mockResolvedValue(inferenceBedrockConnector); const result = await getPrompt({ savedObjectsClient, localPrompts, @@ -249,76 +311,57 @@ describe('get_prompt', () => { promptGroupId: promptGroupId.aiAssistant, provider: 'inference', model: 'gpt-4o', - actionsClient, + getInferenceConnectorById, connectorId: 'connector-123', }); - - expect(result).toBe('Hello world this is a system prompt for bedrock'); + expect(getInferenceConnectorById).toHaveBeenCalledWith('connector-123'); + expect(result).toBe('Hello world this is a system prompt for bedrock claude-3-5-sonnet'); }); - it('returns the expected prompt from when provider is "elastic" and model matches in elasticModelDictionary', async () => { - actionsClient.get.mockResolvedValue({ - ...defaultConnector, - config: { - provider: 'elastic', - providerConfig: { model_id: 'rainbow-sprinkles' }, - }, - }); - + it('returns the expected prompt when provider is "elastic" and model matches in elasticModelDictionary', async () => { + const getInferenceConnectorById = jest.fn().mockResolvedValue(inferenceElasticConnector); const result = await getPrompt({ savedObjectsClient, localPrompts, promptId: promptDictionary.systemPrompt, promptGroupId: promptGroupId.aiAssistant, provider: 'inference', - actionsClient, + getInferenceConnectorById, connectorId: 'connector-123', }); - expect(result).toBe('Hello world this is a system prompt for bedrock claude-3-7-sonnet'); }); it('returns the bedrock prompt when provider is "elastic" but model does not match elasticModelDictionary', async () => { - actionsClient.get.mockResolvedValue({ - ...defaultConnector, - config: { - provider: 'elastic', - providerConfig: { model_id: 'unknown-model' }, - }, - }); - + const getInferenceConnectorById = jest + .fn() + .mockResolvedValue(inferenceElasticUnknownConnector); const result = await getPrompt({ savedObjectsClient, localPrompts, promptId: promptDictionary.systemPrompt, promptGroupId: promptGroupId.aiAssistant, provider: 'inference', - actionsClient, + getInferenceConnectorById, connectorId: 'connector-123', }); - - expect(result).toBe('Hello world this is a system prompt for bedrock'); + expect(result).toBe('Hello world this is a system prompt no model, no provider'); }); - it('returns the model prompt when no prompts are found and model is provided', async () => { - savedObjectsClient.find.mockResolvedValue({ - page: 1, - per_page: 20, - total: 0, - saved_objects: [], - }); - + it('returns the provider-specific prompt when connector has no model', async () => { + const getInferenceConnectorById = jest + .fn() + .mockResolvedValue({ ...bedrockConnector, config: {} }); const result = await getPrompt({ savedObjectsClient, localPrompts, promptId: promptDictionary.systemPrompt, promptGroupId: promptGroupId.aiAssistant, - actionsClient, provider: 'bedrock', + getInferenceConnectorById, connectorId: 'connector-123', }); - - expect(result).toBe('provider:bedrock default system prompt'); + expect(result).toBe('Hello world this is a system prompt for bedrock'); }); it('returns the default prompt when no prompts are found', async () => { @@ -328,16 +371,13 @@ describe('get_prompt', () => { total: 0, saved_objects: [], }); - const result = await getPrompt({ savedObjectsClient, localPrompts, promptId: promptDictionary.systemPrompt, promptGroupId: promptGroupId.aiAssistant, - actionsClient, connectorId: 'connector-123', }); - expect(result).toBe('default system prompt'); }); @@ -348,14 +388,12 @@ describe('get_prompt', () => { total: 0, saved_objects: [], }); - await expect( getPrompt({ savedObjectsClient, localPrompts, promptId: 'nonexistent-prompt', promptGroupId: 'nonexistent-group', - actionsClient, connectorId: 'connector-123', }) ).rejects.toThrow( @@ -363,76 +401,82 @@ describe('get_prompt', () => { ); }); - it('handles invalid connector configuration gracefully when provider is "inference"', async () => { - actionsClient.get.mockResolvedValue({ - ...defaultConnector, - config: {}, - }); + it('handles empty connector config gracefully when provider is "inference"', async () => { + const getInferenceConnectorById = jest + .fn() + .mockResolvedValue({ ...inferenceBedrockConnector, config: {} }); const result = await getPrompt({ savedObjectsClient, localPrompts, promptId: promptDictionary.systemPrompt, promptGroupId: promptGroupId.aiAssistant, provider: 'inference', - actionsClient, + getInferenceConnectorById, connectorId: 'connector-123', }); - - expect(result).toBe('Hello world this is a system prompt for bedrock'); + expect(result).toBe('Hello world this is a system prompt no model, no provider'); }); - it('retrieves the connector when no model or provider is provided', async () => { - actionsClient.get.mockResolvedValue({ - ...defaultConnector, - actionTypeId: '.bedrock', - config: { - defaultModel: 'us.anthropic.claude-3-5-sonnet-20240620-v1:0', - }, - }); + it('resolves provider and model from getInferenceConnectorById when none are provided', async () => { + const getInferenceConnectorById = jest.fn().mockResolvedValue(bedrockConnector); const result = await getPrompt({ savedObjectsClient, localPrompts, promptId: promptDictionary.systemPrompt, promptGroupId: promptGroupId.aiAssistant, - actionsClient, + getInferenceConnectorById, connectorId: 'connector-123', }); - expect(actionsClient.get).toHaveBeenCalled(); - + expect(getInferenceConnectorById).toHaveBeenCalled(); expect(result).toBe('Hello world this is a system prompt for bedrock claude-3-5-sonnet'); }); - it('retrieves the connector when no model is provided', async () => { - actionsClient.get.mockResolvedValue({ - ...defaultConnector, - actionTypeId: '.bedrock', - config: { - defaultModel: 'us.anthropic.claude-3-5-sonnet-20240620-v1:0', - }, - }); + it('finds the default prompt if no provider/model are indicated and no connector details are provided', async () => { const result = await getPrompt({ savedObjectsClient, localPrompts, promptId: promptDictionary.systemPrompt, promptGroupId: promptGroupId.aiAssistant, - provider: 'bedrock', - actionsClient, - connectorId: 'connector-123', }); - expect(actionsClient.get).toHaveBeenCalled(); + expect(result).toEqual('Hello world this is a system prompt no model, no provider'); + }); + it('uses getInferenceConnectorById for native ES inference endpoints', async () => { + const getInferenceConnectorById = jest.fn().mockResolvedValue(inferenceEndpointAmazonBedrock); + const result = await getPrompt({ + savedObjectsClient, + localPrompts, + promptId: promptDictionary.systemPrompt, + promptGroupId: promptGroupId.aiAssistant, + getInferenceConnectorById, + connectorId: 'my-bedrock-endpoint', + }); + expect(getInferenceConnectorById).toHaveBeenCalledWith('my-bedrock-endpoint'); expect(result).toBe('Hello world this is a system prompt for bedrock claude-3-5-sonnet'); }); - it('finds the default prompt if no provider/model are indicated and no connector details are provided', async () => { + it('falls back to default prompts when getInferenceConnectorById fails', async () => { + const getInferenceConnectorById = jest.fn().mockRejectedValue(new Error('Not found')); const result = await getPrompt({ savedObjectsClient, localPrompts, promptId: promptDictionary.systemPrompt, promptGroupId: promptGroupId.aiAssistant, + getInferenceConnectorById, + connectorId: 'unknown-endpoint', }); + expect(result).toBe('Hello world this is a system prompt no model, no provider'); + }); - expect(result).toEqual('Hello world this is a system prompt no model, no provider'); + it('falls back to default prompts when no getInferenceConnectorById is provided', async () => { + const result = await getPrompt({ + savedObjectsClient, + localPrompts, + promptId: promptDictionary.systemPrompt, + promptGroupId: promptGroupId.aiAssistant, + connectorId: 'unknown-endpoint', + }); + expect(result).toBe('Hello world this is a system prompt no model, no provider'); }); }); @@ -445,7 +489,6 @@ describe('get_prompt', () => { promptGroupId: promptGroupId.aiAssistant, provider: 'openai', model: 'gpt-4o', - actionsClient, connectorId: 'connector-123', }); expect(savedObjectsClient.find).toHaveBeenCalledWith({ @@ -453,7 +496,6 @@ describe('get_prompt', () => { searchFields: ['promptGroupId'], search: promptGroupId.aiAssistant, }); - expect(result).toEqual([ { promptId: promptDictionary.systemPrompt, @@ -469,10 +511,8 @@ describe('get_prompt', () => { promptIds: [promptDictionary.systemPrompt, promptDictionary.userPrompt], promptGroupId: promptGroupId.aiAssistant, provider: 'gemini', - actionsClient, connectorId: 'connector-123', }); - expect(result).toEqual([ { promptId: promptDictionary.systemPrompt, @@ -485,24 +525,16 @@ describe('get_prompt', () => { ]); }); - it('returns prompts matching the provided promptIds when connector is given', async () => { + it('returns prompts using getInferenceConnectorById for gemini connector', async () => { + const getInferenceConnectorById = jest.fn().mockResolvedValue(geminiConnector); const result = await getPromptsByGroupId({ savedObjectsClient, localPrompts, promptIds: [promptDictionary.systemPrompt, promptDictionary.userPrompt], promptGroupId: promptGroupId.aiAssistant, - connector: createMockConnector({ - actionTypeId: '.gemini', - config: { - defaultModel: 'gemini-1.5-pro-002', - }, - id: 'connector-123', - name: 'Gemini', - }), - actionsClient, + getInferenceConnectorById, connectorId: 'connector-123', }); - expect(result).toEqual([ { promptId: promptDictionary.systemPrompt, @@ -514,32 +546,27 @@ describe('get_prompt', () => { }, ]); }); - it('returns prompts matching the provided promptIds when inference connector is given', async () => { + + it('returns prompts using getInferenceConnectorById for inference connector with elastic provider', async () => { + const getInferenceConnectorById = jest + .fn() + .mockResolvedValue(inferenceElasticConnectorRainbow); const result = await getPromptsByGroupId({ savedObjectsClient, localPrompts, promptIds: [promptDictionary.systemPrompt], promptGroupId: promptGroupId.aiAssistant, - connector: createMockConnector({ - actionTypeId: '.inference', - config: { - provider: 'elastic', - providerConfig: { model_id: 'rainbow-sprinkles' }, - }, - id: 'connector-123', - name: 'Inference', - }), - actionsClient, + getInferenceConnectorById, connectorId: 'connector-123', }); - expect(result).toEqual([ { promptId: promptDictionary.systemPrompt, - prompt: 'Hello world this is a system prompt for bedrock', + prompt: 'Hello world this is a system prompt for bedrock claude-3-7-sonnet', }, ]); }); + it('throws an error when a prompt is missing', async () => { savedObjectsClient.find.mockResolvedValue({ page: 1, @@ -547,14 +574,12 @@ describe('get_prompt', () => { total: 0, saved_objects: [], }); - await expect( getPromptsByGroupId({ savedObjectsClient, localPrompts, promptIds: [promptDictionary.systemPrompt, 'fake-id'], promptGroupId: promptGroupId.aiAssistant, - actionsClient, connectorId: 'connector-123', }) ).rejects.toThrow('Prompt not found for promptId: fake-id and promptGroupId: aiAssistant'); @@ -567,7 +592,6 @@ describe('get_prompt', () => { promptIds: [promptDictionary.systemPrompt], promptGroupId: promptGroupId.aiAssistant, }); - expect(result).toEqual([ { promptId: promptDictionary.systemPrompt, diff --git a/x-pack/solutions/security/packages/security-ai-prompts/src/get_prompt.ts b/x-pack/solutions/security/packages/security-ai-prompts/src/get_prompt.ts index 7af1b44828108..7f9e426966a8f 100644 --- a/x-pack/solutions/security/packages/security-ai-prompts/src/get_prompt.ts +++ b/x-pack/solutions/security/packages/security-ai-prompts/src/get_prompt.ts @@ -5,10 +5,12 @@ * 2.0. */ -import type { PublicMethodsOf } from '@kbn/utility-types'; -import type { ActionsClient } from '@kbn/actions-plugin/server'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; -import { elasticModelDictionary } from '@kbn/inference-common'; +import { + elasticModelDictionary, + InferenceConnectorType, + InferenceEndpointProvider, +} from '@kbn/inference-common'; +import type { InferenceConnector } from '@kbn/inference-common'; import type { PromptArray, Prompt, GetPromptArgs, GetPromptsByGroupIdArgs } from './types'; import { getProviderFromActionTypeId } from './utils'; import { promptSavedObjectType } from './saved_object_mappings'; @@ -27,9 +29,8 @@ import { promptSavedObjectType } from './saved_object_mappings'; * @param savedObjectsClient - saved objects client */ export const getPromptsByGroupId = async ({ - actionsClient, - connector, connectorId, + getInferenceConnectorById, localPrompts, model: providedModel, promptGroupId, @@ -41,8 +42,7 @@ export const getPromptsByGroupId = async ({ providedProvider, providedModel, connectorId, - actionsClient, - providedConnector: connector, + getInferenceConnectorById, }); const prompts = await savedObjectsClient.find({ @@ -88,9 +88,8 @@ export const getPromptsByGroupId = async ({ * @param savedObjectsClient - saved objects client */ export const getPrompt = async ({ - actionsClient, - connector, connectorId, + getInferenceConnectorById, localPrompts, model: providedModel, promptGroupId, @@ -102,8 +101,7 @@ export const getPrompt = async ({ providedProvider, providedModel, connectorId, - actionsClient, - providedConnector: connector, + getInferenceConnectorById, }); const prompts = await savedObjectsClient.find({ @@ -133,40 +131,58 @@ export const resolveProviderAndModel = async ({ providedProvider, providedModel, connectorId, - actionsClient, - providedConnector, + getInferenceConnectorById, }: { providedProvider?: string; providedModel?: string; connectorId?: string; - actionsClient?: PublicMethodsOf; - providedConnector?: Connector; + getInferenceConnectorById?: (id: string) => Promise; }): Promise<{ provider?: string; model?: string }> => { - let model = providedModel; - let provider = providedProvider; - if (!provider || !model || provider === 'inference') { - let connector = providedConnector; - if (!connector && connectorId != null && actionsClient) { - connector = await actionsClient.get({ id: connectorId }); - } - if (!connector) { - return {}; - } - if (provider === 'inference' && connector.config) { - provider = connector.config.provider || provider; - model = connector.config.providerConfig?.model_id || model; - - if (provider === 'elastic' && model) { - provider = elasticModelDictionary[model]?.provider || 'inference'; - model = elasticModelDictionary[model]?.model; - } - } else if (connector.config) { - provider = provider || getProviderFromActionTypeId(connector.actionTypeId); - model = model || connector.config.defaultModel; + if (providedProvider && providedModel && providedProvider !== 'inference') { + return { provider: providedProvider, model: providedModel }; + } + + if (connectorId != null && getInferenceConnectorById) { + try { + return resolveFromInferenceConnector(await getInferenceConnectorById(connectorId)); + } catch { + return { provider: providedProvider, model: providedModel }; } } - return { provider: provider === 'inference' ? 'bedrock' : provider, model }; + return { provider: providedProvider, model: providedModel }; +}; + +// Maps ES inference endpoint service names to the provider names used in prompt lookup +const inferenceServiceToProvider: Partial> = { + [InferenceEndpointProvider.AmazonBedrock]: 'bedrock', + [InferenceEndpointProvider.GoogleVertexAI]: 'gemini', + [InferenceEndpointProvider.OpenAI]: 'openai', + [InferenceEndpointProvider.AzureOpenAI]: 'openai', + [InferenceEndpointProvider.Elastic]: 'elastic', +}; + +const resolveFromInferenceConnector = ({ + type, + config, +}: InferenceConnector): { provider?: string; model?: string } => { + if (type === InferenceConnectorType.Inference) { + // .inference connectors: Kibana stack connectors use `provider`, native endpoints use `service` + const rawProvider: string | undefined = + config.provider || (config.service ? inferenceServiceToProvider[config.service] : undefined); + const rawModel: string | undefined = config.providerConfig?.model_id; + if (rawProvider === 'elastic' && rawModel) { + return { + provider: elasticModelDictionary[rawModel]?.provider || 'inference', + model: elasticModelDictionary[rawModel]?.model, + }; + } + return { provider: rawProvider, model: rawModel }; + } + return { + provider: getProviderFromActionTypeId(type), + model: config.defaultModel, + }; }; const findPrompt = ({ diff --git a/x-pack/solutions/security/packages/security-ai-prompts/src/types.ts b/x-pack/solutions/security/packages/security-ai-prompts/src/types.ts index ef42acb69f3df..8ffb9daba1803 100644 --- a/x-pack/solutions/security/packages/security-ai-prompts/src/types.ts +++ b/x-pack/solutions/security/packages/security-ai-prompts/src/types.ts @@ -5,10 +5,10 @@ * 2.0. */ -import type { PublicMethodsOf } from '@kbn/utility-types'; -import type { ActionsClient } from '@kbn/actions-plugin/server'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server'; +import type { InferenceConnector } from '@kbn/inference-common'; + +export type { InferenceConnector }; export interface Prompt { promptId: string; @@ -24,9 +24,8 @@ export interface Prompt { export type PromptArray = Array<{ promptId: string; prompt: string }>; export interface GetPromptArgs { - actionsClient?: PublicMethodsOf; - connector?: Connector; connectorId?: string; + getInferenceConnectorById?: (id: string) => Promise; localPrompts: Prompt[]; model?: string; promptId: string; diff --git a/x-pack/solutions/security/packages/security-ai-prompts/tsconfig.json b/x-pack/solutions/security/packages/security-ai-prompts/tsconfig.json index 8921b9ce45a58..c582e2e9708fd 100644 --- a/x-pack/solutions/security/packages/security-ai-prompts/tsconfig.json +++ b/x-pack/solutions/security/packages/security-ai-prompts/tsconfig.json @@ -9,11 +9,9 @@ }, "include": ["**/*.ts"], "kbn_references": [ - "@kbn/core", - "@kbn/actions-plugin", "@kbn/core-saved-objects-api-server", - "@kbn/utility-types", - "@kbn/inference-common" + "@kbn/inference-common", + "@kbn/core" ], "exclude": [ "target/**/*" diff --git a/x-pack/solutions/security/plugins/elastic_assistant/scripts/draw_graph_script.ts b/x-pack/solutions/security/plugins/elastic_assistant/scripts/draw_graph_script.ts index 99a49fd3e28be..ee7ea00dcc6fe 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/scripts/draw_graph_script.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/scripts/draw_graph_script.ts @@ -17,8 +17,6 @@ import type { Logger } from '@kbn/logging'; import { FakeChatModel, FakeLLM } from '@langchain/core/utils/testing'; import type { ContentReferencesStore } from '@kbn/elastic-assistant-common'; import { DefendInsightType } from '@kbn/elastic-assistant-common'; -import type { PublicMethodsOf } from '@kbn/utility-types'; -import type { ActionsClient } from '@kbn/actions-plugin/server'; import { MemorySaver } from '@langchain/langgraph-checkpoint'; import { ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN, @@ -62,7 +60,9 @@ const createLlmInstance = () => { async function getAssistantGraph(logger: Logger): Promise { const graph = await getDefaultAssistantGraph({ - actionsClient: {} as unknown as PublicMethodsOf, + getInferenceConnectorById: async () => { + throw new Error('not implemented'); + }, logger, createLlmInstance, tools: [], diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/__mocks__/mock_experiment_connector.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/__mocks__/mock_experiment_connector.ts index a61d12357135b..9372dbc95993a 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/__mocks__/mock_experiment_connector.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/__mocks__/mock_experiment_connector.ts @@ -5,18 +5,20 @@ * 2.0. */ -import { createMockConnector } from '@kbn/actions-plugin/server/application/connector/mocks'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; +import type { InferenceConnector } from '@kbn/inference-common'; +import { InferenceConnectorType } from '@kbn/inference-common'; -export const mockExperimentConnector: Connector = createMockConnector({ +export const mockExperimentConnector: InferenceConnector = { + type: InferenceConnectorType.Gemini, name: 'Gemini 1.5 Pro 002', - actionTypeId: '.gemini', + connectorId: 'gemini-1-5-pro-002', config: { apiUrl: 'https://example.com', defaultModel: 'gemini-1.5-pro-002', gcpRegion: 'test-region', gcpProjectID: 'test-project-id', }, - id: 'gemini-1-5-pro-002', + capabilities: {}, + isInferenceEndpoint: false, isPreconfigured: true, -}); +}; diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/helpers/get_evaluator_llm/index.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/helpers/get_evaluator_llm/index.test.ts index a119610497c87..8b0a2b9c7db91 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/helpers/get_evaluator_llm/index.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/helpers/get_evaluator_llm/index.test.ts @@ -6,12 +6,12 @@ */ import type { ActionsClient } from '@kbn/actions-plugin/server'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; import { ActionsClientLlm } from '@kbn/langchain/server'; import { loggerMock } from '@kbn/logging-mocks'; +import type { InferenceConnector } from '@kbn/inference-common'; +import { InferenceConnectorType } from '@kbn/inference-common'; import { getEvaluatorLlm } from '.'; -import { createMockConnector } from '@kbn/actions-plugin/server/application/connector/mocks'; jest.mock('@kbn/langchain/server', () => ({ ...jest.requireActual('@kbn/langchain/server'), @@ -22,27 +22,30 @@ jest.mock('@kbn/langchain/server', () => ({ const connectorTimeout = 1000; const evaluatorConnectorId = 'evaluator-connector-id'; -const evaluatorConnector = { - id: 'evaluatorConnectorId', - actionTypeId: '.gen-ai', +const evaluatorConnector: InferenceConnector = { + connectorId: 'evaluatorConnectorId', + type: InferenceConnectorType.OpenAI, name: 'GPT-4o', + config: {}, + capabilities: {}, + isInferenceEndpoint: false, isPreconfigured: true, - isSystemAction: false, - isDeprecated: false, -} as Connector; +}; -const experimentConnector: Connector = createMockConnector({ +const experimentConnector: InferenceConnector = { + connectorId: 'gemini-1-5-pro-002', + type: InferenceConnectorType.Gemini, name: 'Gemini 1.5 Pro 002', - actionTypeId: '.gemini', config: { apiUrl: 'https://example.com', defaultModel: 'gemini-1.5-pro-002', gcpRegion: 'test-region', gcpProjectID: 'test-project-id', }, - id: 'gemini-1-5-pro-002', + capabilities: {}, + isInferenceEndpoint: false, isPreconfigured: true, -}); +}; const logger = loggerMock.create(); @@ -50,78 +53,72 @@ describe('getEvaluatorLlm', () => { beforeEach(() => jest.clearAllMocks()); describe('getting the evaluation connector', () => { - it("calls actionsClient.get with the evaluator connector ID when it's provided", async () => { - const actionsClient = { - get: jest.fn(), - } as unknown as ActionsClient; + it("calls getInferenceConnectorById with the evaluator connector ID when it's provided", async () => { + const actionsClient = {} as unknown as ActionsClient; + const getInferenceConnectorById = jest.fn().mockResolvedValue(evaluatorConnector); await getEvaluatorLlm({ actionsClient, connectorTimeout, evaluatorConnectorId, experimentConnector, + getInferenceConnectorById, langSmithApiKey: undefined, logger, }); - expect(actionsClient.get).toHaveBeenCalledWith({ - id: evaluatorConnectorId, - throwIfSystemAction: false, - }); + expect(getInferenceConnectorById).toHaveBeenCalledWith(evaluatorConnectorId); }); - it("calls actionsClient.get with the experiment connector ID when the evaluator connector ID isn't provided", async () => { - const actionsClient = { - get: jest.fn().mockResolvedValue(null), - } as unknown as ActionsClient; + it("calls getInferenceConnectorById with the experiment connector ID when the evaluator connector ID isn't provided", async () => { + const actionsClient = {} as unknown as ActionsClient; + const getInferenceConnectorById = jest.fn().mockResolvedValue(experimentConnector); await getEvaluatorLlm({ actionsClient, connectorTimeout, evaluatorConnectorId: undefined, experimentConnector, + getInferenceConnectorById, langSmithApiKey: undefined, logger, }); - expect(actionsClient.get).toHaveBeenCalledWith({ - id: experimentConnector.id, - throwIfSystemAction: false, - }); + expect(getInferenceConnectorById).toHaveBeenCalledWith(experimentConnector.connectorId); }); - it('falls back to the experiment connector when the evaluator connector is not found', async () => { - const actionsClient = { - get: jest.fn().mockResolvedValue(null), - } as unknown as ActionsClient; + it('falls back to the experiment connector when getInferenceConnectorById throws', async () => { + const actionsClient = {} as unknown as ActionsClient; + const getInferenceConnectorById = jest.fn().mockRejectedValue(new Error('Not found')); await getEvaluatorLlm({ actionsClient, connectorTimeout, evaluatorConnectorId, experimentConnector, + getInferenceConnectorById, langSmithApiKey: undefined, logger, }); expect(ActionsClientLlm).toHaveBeenCalledWith( expect.objectContaining({ - connectorId: experimentConnector.id, + connectorId: experimentConnector.connectorId, }) ); }); }); it('logs the expected connector names and types', async () => { - const actionsClient = { - get: jest.fn().mockResolvedValue(evaluatorConnector), - } as unknown as ActionsClient; + const actionsClient = {} as unknown as ActionsClient; + const getInferenceConnectorById = jest.fn().mockResolvedValue(evaluatorConnector); await getEvaluatorLlm({ actionsClient, connectorTimeout, evaluatorConnectorId, experimentConnector, + getInferenceConnectorById, langSmithApiKey: undefined, logger, }); @@ -132,15 +129,15 @@ describe('getEvaluatorLlm', () => { }); it('creates a new ActionsClientLlm instance with the expected traceOptions', async () => { - const actionsClient = { - get: jest.fn().mockResolvedValue(evaluatorConnector), - } as unknown as ActionsClient; + const actionsClient = {} as unknown as ActionsClient; + const getInferenceConnectorById = jest.fn().mockResolvedValue(evaluatorConnector); await getEvaluatorLlm({ actionsClient, connectorTimeout, evaluatorConnectorId, experimentConnector, + getInferenceConnectorById, langSmithApiKey: 'test-api-key', logger, }); diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/helpers/get_evaluator_llm/index.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/helpers/get_evaluator_llm/index.ts index 8e8293a8cf82d..2ef6fe91af537 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/helpers/get_evaluator_llm/index.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/helpers/get_evaluator_llm/index.ts @@ -6,11 +6,12 @@ */ import type { ActionsClient } from '@kbn/actions-plugin/server'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; import type { Logger } from '@kbn/core/server'; import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith'; import { ActionsClientLlm } from '@kbn/langchain/server'; import type { PublicMethodsOf } from '@kbn/utility-types'; +import type { InferenceConnector } from '@kbn/inference-common'; +import { getConnectorDefaultModel } from '@kbn/inference-common'; import { getLlmType } from '../../../../../routes/utils'; @@ -19,24 +20,29 @@ export const getEvaluatorLlm = async ({ connectorTimeout, evaluatorConnectorId, experimentConnector, + getInferenceConnectorById, langSmithApiKey, logger, }: { actionsClient: PublicMethodsOf; connectorTimeout: number; evaluatorConnectorId: string | undefined; - experimentConnector: Connector; + experimentConnector: InferenceConnector; + getInferenceConnectorById: (id: string) => Promise; langSmithApiKey: string | undefined; logger: Logger; }): Promise => { - const evaluatorConnector = - (await actionsClient.get({ - id: evaluatorConnectorId ?? experimentConnector.id, // fallback to the experiment connector if the evaluator connector is not found: - throwIfSystemAction: false, - })) ?? experimentConnector; + let evaluatorConnector: InferenceConnector; + try { + evaluatorConnector = await getInferenceConnectorById( + evaluatorConnectorId ?? experimentConnector.connectorId + ); + } catch { + evaluatorConnector = experimentConnector; + } - const evaluatorLlmType = getLlmType(evaluatorConnector.actionTypeId); - const experimentLlmType = getLlmType(experimentConnector.actionTypeId); + const evaluatorLlmType = getLlmType(evaluatorConnector.type); + const experimentLlmType = getLlmType(experimentConnector.type); logger.info( `The ${evaluatorConnector.name} (${evaluatorLlmType}) connector will judge output from the ${experimentConnector.name} (${experimentLlmType}) connector` @@ -55,9 +61,10 @@ export const getEvaluatorLlm = async ({ return new ActionsClientLlm({ actionsClient, - connectorId: evaluatorConnector.id, + connectorId: evaluatorConnector.connectorId, llmType: evaluatorLlmType, logger, + model: getConnectorDefaultModel(evaluatorConnector), temperature: 0, // zero temperature for evaluation timeout: connectorTimeout, traceOptions, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/index.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/index.test.ts index 6e2f550186eaa..b3a47a3ede1ca 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/index.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/index.test.ts @@ -5,12 +5,12 @@ * 2.0. */ import type { ActionsClient } from '@kbn/actions-plugin/server'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; import { elasticsearchServiceMock } from '@kbn/core-elasticsearch-server-mocks'; import type { ActionsClientLlm } from '@kbn/langchain/server'; import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith'; import { loggerMock } from '@kbn/logging-mocks'; import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain'; +import type { InferenceConnector } from '@kbn/inference-common'; import { evaluateAttackDiscovery } from '.'; import type { DefaultAttackDiscoveryGraph } from '../graphs/default_attack_discovery_graph'; @@ -53,6 +53,7 @@ const connectorTimeout = 1000; const datasetName = 'test-dataset'; const evaluationId = 'test-evaluation-id'; const evaluatorConnectorId = 'test-evaluator-connector-id'; +const getInferenceConnectorById = jest.fn(); const langSmithApiKey = 'test-api-key'; const langSmithProject = 'test-lang-smith-project'; const logger = loggerMock.create(); @@ -80,7 +81,7 @@ const connectors = [ const projectName = 'test-lang-smith-project'; const graphs: Array<{ - connector: Connector; + connector: InferenceConnector; graph: DefaultAttackDiscoveryGraph; llmType: string | undefined; name: string; @@ -89,7 +90,7 @@ const graphs: Array<{ tracers: LangChainTracer[]; }; }> = connectors.map((connector) => { - const llmType = getLlmType(connector.actionTypeId); + const llmType = getLlmType(connector.type); const traceOptions = { projectName, @@ -137,6 +138,7 @@ describe('evaluateAttackDiscovery', () => { esClientInternalUser: mockEsClientInternalUser, evaluationId, evaluatorConnectorId, + getInferenceConnectorById, langSmithApiKey, langSmithProject, logger, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/index.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/index.ts index c38a39f750f10..95ac3ebe05bfa 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/index.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/index.ts @@ -6,7 +6,6 @@ */ import type { ActionsClient } from '@kbn/actions-plugin/server'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; import type { ElasticsearchClient } from '@kbn/core-elasticsearch-server'; import type { Logger } from '@kbn/core/server'; import type { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas'; @@ -15,6 +14,8 @@ import { ActionsClientLlm } from '@kbn/langchain/server'; import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith'; import { asyncForEach } from '@kbn/std'; import type { PublicMethodsOf } from '@kbn/utility-types'; +import type { InferenceConnector } from '@kbn/inference-common'; +import { getConnectorDefaultModel } from '@kbn/inference-common'; import type { CombinedPrompts } from '../graphs/default_attack_discovery_graph/prompts'; import { DEFAULT_EVAL_ANONYMIZATION_FIELDS } from './constants'; @@ -24,7 +25,7 @@ import { getLlmType } from '../../../routes/utils'; import { runEvaluations } from './run_evaluations'; import { createOrUpdateEvaluationResults, EvaluationStatus } from '../../../routes/evaluate/utils'; -interface ConnectorWithPrompts extends Connector { +interface ConnectorWithPrompts extends InferenceConnector { prompts: CombinedPrompts; } export const evaluateAttackDiscovery = async ({ @@ -39,6 +40,7 @@ export const evaluateAttackDiscovery = async ({ esClientInternalUser, evaluationId, evaluatorConnectorId, + getInferenceConnectorById, langSmithApiKey, langSmithProject, logger, @@ -56,6 +58,7 @@ export const evaluateAttackDiscovery = async ({ esClientInternalUser: ElasticsearchClient; evaluationId: string; evaluatorConnectorId: string | undefined; + getInferenceConnectorById: (id: string) => Promise; langSmithApiKey: string | undefined; langSmithProject: string | undefined; logger: Logger; @@ -65,7 +68,7 @@ export const evaluateAttackDiscovery = async ({ await asyncForEach(attackDiscoveryGraphs, async ({ getDefaultAttackDiscoveryGraph }) => { // create a graph for every connector: const graphs: Array<{ - connector: Connector; + connector: InferenceConnector; graph: DefaultAttackDiscoveryGraph; llmType: string | undefined; name: string; @@ -74,7 +77,7 @@ export const evaluateAttackDiscovery = async ({ tracers: LangChainTracer[]; }; }> = connectors.map((connector) => { - const llmType = getLlmType(connector.actionTypeId); + const llmType = getLlmType(connector.type); const traceOptions = { projectName: langSmithProject, @@ -89,7 +92,7 @@ export const evaluateAttackDiscovery = async ({ const llm = new ActionsClientLlm({ actionsClient, - connectorId: connector.id, + connectorId: connector.connectorId, llmType, logger, temperature: 0, // zero temperature for attack discovery, because we want structured JSON output @@ -98,7 +101,7 @@ export const evaluateAttackDiscovery = async ({ telemetryMetadata: { pluginId: 'security_attack_discovery', }, - model: connector.config?.defaultModel, + model: getConnectorDefaultModel(connector), }); const graph = getDefaultAttackDiscoveryGraph({ @@ -126,6 +129,7 @@ export const evaluateAttackDiscovery = async ({ connectorTimeout, evaluatorConnectorId, datasetName, + getInferenceConnectorById, graphs, langSmithApiKey, logger, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/run_evaluations/index.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/run_evaluations/index.test.ts index 909c279218f1c..7ca142901d6c8 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/run_evaluations/index.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/run_evaluations/index.test.ts @@ -6,11 +6,11 @@ */ import type { ActionsClient } from '@kbn/actions-plugin/server'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; import type { ActionsClientLlm } from '@kbn/langchain/server'; import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith'; import { loggerMock } from '@kbn/logging-mocks'; import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain'; +import type { InferenceConnector } from '@kbn/inference-common'; import { runEvaluations } from '.'; import { type DefaultAttackDiscoveryGraph } from '../../graphs/default_attack_discovery_graph'; @@ -51,6 +51,7 @@ const actionsClient = { const connectorTimeout = 1000; const datasetName = 'test-dataset'; const evaluatorConnectorId = 'test-evaluator-connector-id'; +const getInferenceConnectorById = jest.fn(); const langSmithApiKey = 'test-api-key'; const logger = loggerMock.create(); const connectors = [mockExperimentConnector]; @@ -58,7 +59,7 @@ const connectors = [mockExperimentConnector]; const projectName = 'test-lang-smith-project'; const graphs: Array<{ - connector: Connector; + connector: InferenceConnector; graph: DefaultAttackDiscoveryGraph; llmType: string | undefined; name: string; @@ -67,7 +68,7 @@ const graphs: Array<{ tracers: LangChainTracer[]; }; }> = connectors.map((connector) => { - const llmType = getLlmType(connector.actionTypeId); + const llmType = getLlmType(connector.type); const traceOptions = { projectName, @@ -102,6 +103,7 @@ describe('runEvaluations', () => { connectorTimeout, datasetName, evaluatorConnectorId, + getInferenceConnectorById, graphs, langSmithApiKey, logger, @@ -129,6 +131,7 @@ describe('runEvaluations', () => { connectorTimeout, datasetName, evaluatorConnectorId, + getInferenceConnectorById, graphs, langSmithApiKey, logger, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/run_evaluations/index.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/run_evaluations/index.ts index 1125d8c3cf29a..2039e198ddae1 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/run_evaluations/index.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/evaluation/run_evaluations/index.ts @@ -6,11 +6,11 @@ */ import type { ActionsClient } from '@kbn/actions-plugin/server'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; import type { Logger } from '@kbn/core/server'; import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain'; import { asyncForEach } from '@kbn/std'; import type { PublicMethodsOf } from '@kbn/utility-types'; +import type { InferenceConnector } from '@kbn/inference-common'; import { Client } from 'langsmith'; import { evaluate } from 'langsmith/evaluation'; @@ -30,6 +30,7 @@ export const runEvaluations = async ({ connectorTimeout, evaluatorConnectorId, datasetName, + getInferenceConnectorById, graphs, langSmithApiKey, logger, @@ -38,8 +39,9 @@ export const runEvaluations = async ({ connectorTimeout: number; evaluatorConnectorId: string | undefined; datasetName: string; + getInferenceConnectorById: (id: string) => Promise; graphs: Array<{ - connector: Connector; + connector: InferenceConnector; graph: DefaultAttackDiscoveryGraph; llmType: string | undefined; name: string; @@ -83,6 +85,7 @@ export const runEvaluations = async ({ connectorTimeout, evaluatorConnectorId, experimentConnector: connector, + getInferenceConnectorById, langSmithApiKey, logger, }); diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/graphs/default_attack_discovery_graph/prompts/index.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/graphs/default_attack_discovery_graph/prompts/index.test.ts index 7ba240e7b7c78..b675a3ec6916f 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/graphs/default_attack_discovery_graph/prompts/index.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/graphs/default_attack_discovery_graph/prompts/index.test.ts @@ -5,9 +5,7 @@ * 2.0. */ -import type { ActionsClient } from '@kbn/actions-plugin/server'; import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server'; -import type { PublicMethodsOf } from '@kbn/utility-types'; import { getAttackDiscoveryPrompts } from '.'; import { getPromptsByGroupId, promptDictionary } from '../../../../prompt'; import { promptGroupId } from '../../../../prompt/local_prompt_object'; @@ -22,7 +20,6 @@ jest.mock('../../../../prompt', () => { const mockGetPromptsByGroupId = getPromptsByGroupId as jest.Mock; describe('getAttackDiscoveryPrompts', () => { - const actionsClient = {} as jest.Mocked>; const savedObjectsClient = {} as jest.Mocked; beforeEach(() => { @@ -54,7 +51,6 @@ describe('getAttackDiscoveryPrompts', () => { it('should return all prompts', async () => { const result = await getAttackDiscoveryPrompts({ - actionsClient, connectorId: 'test-connector-id', savedObjectsClient, model: '2', @@ -98,7 +94,6 @@ describe('getAttackDiscoveryPrompts', () => { mockGetPromptsByGroupId.mockResolvedValue([]); const result = await getAttackDiscoveryPrompts({ - actionsClient, connectorId: 'test-connector-id', savedObjectsClient, }); diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/graphs/default_attack_discovery_graph/prompts/index.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/graphs/default_attack_discovery_graph/prompts/index.ts index 8148db7d78555..1c0358adc8c8b 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/graphs/default_attack_discovery_graph/prompts/index.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/attack_discovery/graphs/default_attack_discovery_graph/prompts/index.ts @@ -5,10 +5,8 @@ * 2.0. */ -import type { PublicMethodsOf } from '@kbn/utility-types'; -import type { ActionsClient } from '@kbn/actions-plugin/server'; import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; +import type { InferenceConnector } from '@kbn/inference-common'; import { getPromptsByGroupId, promptDictionary } from '../../../../prompt'; import { promptGroupId } from '../../../../prompt/local_prompt_object'; @@ -30,23 +28,20 @@ export interface GenerationPrompts { export interface CombinedPrompts extends AttackDiscoveryPrompts, GenerationPrompts {} export const getAttackDiscoveryPrompts = async ({ - actionsClient, - connector, + getInferenceConnectorById, connectorId, model, provider, savedObjectsClient, }: { - actionsClient: PublicMethodsOf; - connector?: Connector; + getInferenceConnectorById?: (id: string) => Promise; connectorId: string; model?: string; provider?: string; savedObjectsClient: SavedObjectsClientContract; }): Promise => { const prompts = await getPromptsByGroupId({ - actionsClient, - connector, + getInferenceConnectorById, connectorId, // if in future oss has different prompt, add it as model here model, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/helpers/get_evaluator_llm/index.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/helpers/get_evaluator_llm/index.test.ts index 7026dc426a49d..353afdbfa1ac5 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/helpers/get_evaluator_llm/index.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/helpers/get_evaluator_llm/index.test.ts @@ -6,12 +6,12 @@ */ import type { ActionsClient } from '@kbn/actions-plugin/server'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; import { ActionsClientLlm } from '@kbn/langchain/server'; import { loggerMock } from '@kbn/logging-mocks'; +import type { InferenceConnector } from '@kbn/inference-common'; +import { InferenceConnectorType } from '@kbn/inference-common'; import { getEvaluatorLlm } from '.'; -import { createMockConnector } from '@kbn/actions-plugin/server/application/connector/mocks'; jest.mock('@kbn/langchain/server', () => ({ ...jest.requireActual('@kbn/langchain/server'), @@ -35,22 +35,25 @@ jest.mock('../../../../../routes/utils', () => ({ const connectorTimeout = 1500; const evaluatorConnectorId = 'evaluator-connector-id'; -const evaluatorConnector: Connector = { - id: 'evaluator-connector-id', - actionTypeId: '.gen-ai', +const evaluatorConnector: InferenceConnector = { + connectorId: 'evaluator-connector-id', + type: InferenceConnectorType.OpenAI, name: 'OpenAI Evaluator', + config: {}, + capabilities: {}, + isInferenceEndpoint: false, isPreconfigured: false, - isSystemAction: false, - isDeprecated: false, -} as Connector; +}; -const experimentConnector: Connector = createMockConnector({ - id: 'experiment-connector-id', - actionTypeId: '.gemini', +const experimentConnector: InferenceConnector = { + connectorId: 'experiment-connector-id', + type: InferenceConnectorType.Gemini, name: 'Gemini Experiment', config: {}, + capabilities: {}, + isInferenceEndpoint: false, isPreconfigured: true, -}); +}; const logger = loggerMock.create(); @@ -61,77 +64,71 @@ describe('getEvaluatorLlm', () => { describe('evaluator connector resolution', () => { it('uses the provided evaluatorConnectorId if available', async () => { - const actionsClient = { - get: jest.fn(), - } as unknown as ActionsClient; + const actionsClient = {} as unknown as ActionsClient; + const getInferenceConnectorById = jest.fn().mockResolvedValue(evaluatorConnector); await getEvaluatorLlm({ actionsClient, connectorTimeout, evaluatorConnectorId, experimentConnector, + getInferenceConnectorById, langSmithApiKey: undefined, logger, }); - expect(actionsClient.get).toHaveBeenCalledWith({ - id: evaluatorConnectorId, - throwIfSystemAction: false, - }); + expect(getInferenceConnectorById).toHaveBeenCalledWith(evaluatorConnectorId); }); - it('falls back to experimentConnector.id if no evaluatorConnectorId is provided', async () => { - const actionsClient = { - get: jest.fn(), - } as unknown as ActionsClient; + it('falls back to experimentConnector.connectorId if no evaluatorConnectorId is provided', async () => { + const actionsClient = {} as unknown as ActionsClient; + const getInferenceConnectorById = jest.fn().mockResolvedValue(experimentConnector); await getEvaluatorLlm({ actionsClient, connectorTimeout, evaluatorConnectorId: undefined, experimentConnector, + getInferenceConnectorById, langSmithApiKey: undefined, logger, }); - expect(actionsClient.get).toHaveBeenCalledWith({ - id: experimentConnector.id, - throwIfSystemAction: false, - }); + expect(getInferenceConnectorById).toHaveBeenCalledWith(experimentConnector.connectorId); }); - it('uses the experimentConnector if get() returns null', async () => { - const actionsClient = { - get: jest.fn().mockResolvedValue(null), - } as unknown as ActionsClient; + it('uses the experimentConnector if getInferenceConnectorById throws', async () => { + const actionsClient = {} as unknown as ActionsClient; + const getInferenceConnectorById = jest.fn().mockRejectedValue(new Error('Not found')); await getEvaluatorLlm({ actionsClient, connectorTimeout, evaluatorConnectorId, experimentConnector, + getInferenceConnectorById, langSmithApiKey: undefined, logger, }); expect(ActionsClientLlm).toHaveBeenCalledWith( expect.objectContaining({ - connectorId: experimentConnector.id, + connectorId: experimentConnector.connectorId, }) ); }); }); it('logs a message with connector names and llm types', async () => { - const actionsClient = { - get: jest.fn().mockResolvedValue(evaluatorConnector), - } as unknown as ActionsClient; + const actionsClient = {} as unknown as ActionsClient; + const getInferenceConnectorById = jest.fn().mockResolvedValue(evaluatorConnector); await getEvaluatorLlm({ actionsClient, connectorTimeout, evaluatorConnectorId, experimentConnector, + getInferenceConnectorById, langSmithApiKey: undefined, logger, }); @@ -142,15 +139,15 @@ describe('getEvaluatorLlm', () => { }); it('passes expected traceOptions and config to ActionsClientLlm', async () => { - const actionsClient = { - get: jest.fn().mockResolvedValue(evaluatorConnector), - } as unknown as ActionsClient; + const actionsClient = {} as unknown as ActionsClient; + const getInferenceConnectorById = jest.fn().mockResolvedValue(evaluatorConnector); await getEvaluatorLlm({ actionsClient, connectorTimeout, evaluatorConnectorId, experimentConnector, + getInferenceConnectorById, langSmithApiKey: 'some-key', logger, }); diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/helpers/get_evaluator_llm/index.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/helpers/get_evaluator_llm/index.ts index dd7135f1fa1b8..325d64a6a5bf5 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/helpers/get_evaluator_llm/index.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/helpers/get_evaluator_llm/index.ts @@ -6,11 +6,12 @@ */ import type { ActionsClient } from '@kbn/actions-plugin/server'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; import type { Logger } from '@kbn/core/server'; import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith'; import { ActionsClientLlm } from '@kbn/langchain/server'; import type { PublicMethodsOf } from '@kbn/utility-types'; +import type { InferenceConnector } from '@kbn/inference-common'; +import { getConnectorDefaultModel } from '@kbn/inference-common'; import { getLlmType } from '../../../../../routes/utils'; @@ -19,24 +20,29 @@ export const getEvaluatorLlm = async ({ connectorTimeout, evaluatorConnectorId, experimentConnector, + getInferenceConnectorById, langSmithApiKey, logger, }: { actionsClient: PublicMethodsOf; connectorTimeout: number; evaluatorConnectorId: string | undefined; - experimentConnector: Connector; + experimentConnector: InferenceConnector; + getInferenceConnectorById: (id: string) => Promise; langSmithApiKey: string | undefined; logger: Logger; }): Promise => { - const evaluatorConnector = - (await actionsClient.get({ - id: evaluatorConnectorId ?? experimentConnector.id, // fallback to the experiment connector if the evaluator connector is not found: - throwIfSystemAction: false, - })) ?? experimentConnector; + let evaluatorConnector: InferenceConnector; + try { + evaluatorConnector = await getInferenceConnectorById( + evaluatorConnectorId ?? experimentConnector.connectorId + ); + } catch { + evaluatorConnector = experimentConnector; + } - const evaluatorLlmType = getLlmType(evaluatorConnector.actionTypeId); - const experimentLlmType = getLlmType(experimentConnector.actionTypeId); + const evaluatorLlmType = getLlmType(evaluatorConnector.type); + const experimentLlmType = getLlmType(experimentConnector.type); logger.info( `The ${evaluatorConnector.name} (${evaluatorLlmType}) connector will judge output from the ${experimentConnector.name} (${experimentLlmType}) connector` @@ -55,9 +61,10 @@ export const getEvaluatorLlm = async ({ return new ActionsClientLlm({ actionsClient, - connectorId: evaluatorConnector.id, + connectorId: evaluatorConnector.connectorId, llmType: evaluatorLlmType, logger, + model: getConnectorDefaultModel(evaluatorConnector), temperature: 0, // zero temperature for evaluation timeout: connectorTimeout, traceOptions, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/index.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/index.test.ts index 6bb5ac55387bd..609b277bdd788 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/index.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/index.test.ts @@ -8,16 +8,17 @@ import type { Logger } from '@kbn/logging'; import type { ElasticsearchClient } from '@kbn/core-elasticsearch-server'; import type { PublicMethodsOf } from '@kbn/utility-types'; -import type { ActionsClient, Connector } from '@kbn/actions-plugin/server'; +import type { ActionsClient } from '@kbn/actions-plugin/server'; import { ActionsClientLlm } from '@kbn/langchain/server'; import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith'; import { savedObjectsClientMock } from '@kbn/core/server/mocks'; import { DefendInsightType } from '@kbn/elastic-assistant-common'; +import type { InferenceConnector } from '@kbn/inference-common'; +import { InferenceConnectorType } from '@kbn/inference-common'; import { getLlmType } from '../../../routes/utils'; import { runDefendInsightsEvaluations } from './run_evaluations'; import { evaluateDefendInsights } from '.'; -import { createMockConnector } from '@kbn/actions-plugin/server/application/connector/mocks'; jest.mock('./run_evaluations'); jest.mock('@kbn/langchain/server', () => ({ @@ -61,31 +62,27 @@ describe('evaluateDefendInsights', () => { }, ]; - const mockConnectors = [ - createMockConnector({ - id: '1', + const mockConnectors: InferenceConnector[] = [ + { + connectorId: '1', + type: InferenceConnectorType.OpenAI, name: 'Test Connector', - actionTypeId: '.test', - prompts: { - default: 'default', - refine: 'refine', - continue: 'continue', - group: 'group', - events: 'events', - eventsId: 'eventsId', - eventsEndpointId: 'eventsEndpointId', - eventsValue: 'eventsValue', - }, - } as unknown as Connector), + config: {}, + capabilities: {}, + isInferenceEndpoint: false, + isPreconfigured: false, + }, ]; const mockActionsClient = {} as unknown as PublicMethodsOf; const mockEsClient = {} as unknown as ElasticsearchClient; const mockSoClient = savedObjectsClientMock.create(); const mockEsClientInternalUser = {} as unknown as ElasticsearchClient; + const mockGetInferenceConnectorById = jest.fn(); await evaluateDefendInsights({ actionsClient: mockActionsClient, + getInferenceConnectorById: mockGetInferenceConnectorById, defendInsightsGraphs: mockGraphMetadata, anonymizationFields: [], connectors: mockConnectors, @@ -103,7 +100,7 @@ describe('evaluateDefendInsights', () => { size: 10, }); - expect(getLlmType).toHaveBeenCalledWith('.test'); + expect(getLlmType).toHaveBeenCalledWith(InferenceConnectorType.OpenAI); expect(getLangSmithTracer).toHaveBeenCalledWith({ apiKey: 'api-key', projectName: 'project-name', @@ -120,6 +117,7 @@ describe('evaluateDefendInsights', () => { projectName: 'project-name', tracers: ['mockTracer'], }, + model: undefined, }); expect(mockGetDefaultDefendInsightsGraph).toHaveBeenCalledWith({ diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/index.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/index.ts index 26bbeec4e200e..2e8a6aaced8d7 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/index.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/index.ts @@ -7,9 +7,10 @@ import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain'; import type { ActionsClient } from '@kbn/actions-plugin/server'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; import type { PublicMethodsOf } from '@kbn/utility-types'; import type { ElasticsearchClient } from '@kbn/core-elasticsearch-server'; +import type { InferenceConnector } from '@kbn/inference-common'; +import { getConnectorDefaultModel } from '@kbn/inference-common'; import type { Logger } from '@kbn/logging'; import type { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas'; import type { SavedObjectsClientContract } from '@kbn/core/server'; @@ -28,6 +29,7 @@ import { runDefendInsightsEvaluations } from './run_evaluations'; export const evaluateDefendInsights = async ({ actionsClient, + getInferenceConnectorById, defendInsightsGraphs, anonymizationFields = DEFAULT_EVAL_ANONYMIZATION_FIELDS, // determines which fields are included in the alerts connectors, @@ -46,9 +48,10 @@ export const evaluateDefendInsights = async ({ size, }: { actionsClient: PublicMethodsOf; + getInferenceConnectorById: (id: string) => Promise; defendInsightsGraphs: DefendInsightsGraphMetadata[]; anonymizationFields?: AnonymizationFieldResponse[]; - connectors: Connector[]; + connectors: InferenceConnector[]; connectorTimeout: number; datasetName: string; esClient: ElasticsearchClient; @@ -68,7 +71,7 @@ export const evaluateDefendInsights = async ({ async ({ getDefaultDefendInsightsGraph, insightType }) => { // create a graph for every connector: const graphs: Array<{ - connector: Connector; + connector: InferenceConnector; graph: DefaultDefendInsightsGraph; llmType: string | undefined; name: string; @@ -78,12 +81,11 @@ export const evaluateDefendInsights = async ({ }; }> = await Promise.all( connectors.map(async (connector) => { - const llmType = getLlmType(connector.actionTypeId); + const llmType = getLlmType(connector.type); const prompts = await getDefendInsightsPrompt({ type: insightType, - actionsClient, - connectorId: connector.id, - connector, + getInferenceConnectorById, + connectorId: connector.connectorId, savedObjectsClient: soClient, }); @@ -100,13 +102,13 @@ export const evaluateDefendInsights = async ({ const llm = new ActionsClientLlm({ actionsClient, - connectorId: connector.id, + connectorId: connector.connectorId, llmType, logger, temperature: 0, // zero temperature for defend insights, because we want structured JSON output timeout: connectorTimeout, traceOptions, - model: connector.config?.defaultModel, + model: getConnectorDefaultModel(connector), }); const graph = getDefaultDefendInsightsGraph({ diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/run_evaluations/index.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/run_evaluations/index.test.ts index f30a37db40d21..116ba0bf97b58 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/run_evaluations/index.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/run_evaluations/index.test.ts @@ -5,16 +5,16 @@ * 2.0. */ -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; import { loggerMock } from '@kbn/logging-mocks'; import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain'; import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith'; +import type { InferenceConnector } from '@kbn/inference-common'; +import { InferenceConnectorType } from '@kbn/inference-common'; import { runDefendInsightsEvaluations } from '.'; import type { DefaultDefendInsightsGraph } from '../../graphs/default_defend_insights_graph'; import { getLlmType } from '../../../../routes/utils'; import { DefendInsightType } from '@kbn/elastic-assistant-common'; -import { createMockConnector } from '@kbn/actions-plugin/server/application/connector/mocks'; jest.mock('langsmith/evaluation', () => ({ evaluate: jest.fn(async (predict: Function) => @@ -34,18 +34,20 @@ jest.mock('../helpers/get_graph_input_overrides', () => ({ getDefendInsightsGraphInputOverrides: jest.fn((input) => input.overrides ?? {}), })); -const mockExperimentConnector = createMockConnector({ +const mockExperimentConnector: InferenceConnector = { + type: InferenceConnectorType.Gemini, name: 'Gemini 1.5 Pro 002', - actionTypeId: '.gemini', + connectorId: 'gemini-1-5-pro-002', config: { apiUrl: 'https://example.com', defaultModel: 'gemini-1.5-pro-002', gcpRegion: 'test-region', gcpProjectID: 'test-project-id', }, - id: 'gemini-1-5-pro-002', + capabilities: {}, + isInferenceEndpoint: false, isPreconfigured: true, -}); +}; const datasetName = 'test-dataset'; const evaluatorConnectorId = 'test-evaluator-connector-id'; @@ -55,7 +57,7 @@ const connectors = [mockExperimentConnector]; const projectName = 'test-lang-smith-project'; const graphs: Array<{ - connector: Connector; + connector: InferenceConnector; graph: DefaultDefendInsightsGraph; llmType: string | undefined; name: string; @@ -64,7 +66,7 @@ const graphs: Array<{ tracers: LangChainTracer[]; }; }> = connectors.map((connector) => { - const llmType = getLlmType(connector.actionTypeId); + const llmType = getLlmType(connector.type); const traceOptions = { projectName, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/run_evaluations/index.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/run_evaluations/index.ts index 9a4f3e81135eb..cf030b0ec42c6 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/run_evaluations/index.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/evaluation/run_evaluations/index.ts @@ -6,7 +6,7 @@ */ import type { LangChainTracer } from '@langchain/core/tracers/tracer_langchain'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; +import type { InferenceConnector } from '@kbn/inference-common'; import type { Logger } from '@kbn/logging'; import type { DefendInsightType } from '@kbn/elastic-assistant-common'; import { Client } from 'langsmith'; @@ -32,7 +32,7 @@ export const runDefendInsightsEvaluations = async ({ evaluatorConnectorId: string | undefined; datasetName: string; graphs: Array<{ - connector: Connector; + connector: InferenceConnector; graph: DefaultDefendInsightsGraph; llmType: string | undefined; name: string; diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/incompatible_antivirus.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/incompatible_antivirus.test.ts index cfa0ecaffde36..e80b0b95d6204 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/incompatible_antivirus.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/incompatible_antivirus.test.ts @@ -5,9 +5,7 @@ * 2.0. */ -import type { ActionsClient } from '@kbn/actions-plugin/server'; import type { SavedObjectsClientContract } from '@kbn/core/server'; -import type { PublicMethodsOf } from '@kbn/utility-types'; import { promptDictionary, getPromptsByGroupId } from '../../../../prompt'; import { getIncompatibleAntivirusPrompt } from './incompatible_antivirus'; import { promptGroupId } from '../../../../prompt/local_prompt_object'; @@ -23,7 +21,6 @@ jest.mock('../../../../prompt', () => { const mockGetPromptsByGroupId = getPromptsByGroupId as jest.Mock; describe('getIncompatibleAntivirusPrompt', () => { - const actionsClient = {} as jest.Mocked>; const savedObjectsClient = {} as jest.Mocked; beforeEach(() => { @@ -66,7 +63,6 @@ describe('getIncompatibleAntivirusPrompt', () => { it('should return all prompts', async () => { const result = await getIncompatibleAntivirusPrompt({ - actionsClient, connectorId: 'test-connector-id', savedObjectsClient, model: '4', @@ -108,7 +104,6 @@ describe('getIncompatibleAntivirusPrompt', () => { mockGetPromptsByGroupId.mockResolvedValue([]); const result = await getIncompatibleAntivirusPrompt({ - actionsClient, connectorId: 'test-connector-id', savedObjectsClient, }); diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/incompatible_antivirus.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/incompatible_antivirus.ts index 9d4c5392fc911..259a043cfa598 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/incompatible_antivirus.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/incompatible_antivirus.ts @@ -5,33 +5,28 @@ * 2.0. */ -import type { PublicMethodsOf } from '@kbn/utility-types'; -import type { ActionsClient } from '@kbn/actions-plugin/server'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server'; +import type { InferenceConnector } from '@kbn/inference-common'; import type { DefendInsightsCombinedPrompts } from '.'; import { promptDictionary, getPromptsByGroupId } from '../../../../prompt'; import { promptGroupId } from '../../../../prompt/local_prompt_object'; export async function getIncompatibleAntivirusPrompt({ - actionsClient, - connector, + getInferenceConnectorById, connectorId, model, provider, savedObjectsClient, }: { - actionsClient: PublicMethodsOf; - connector?: Connector; + getInferenceConnectorById?: (id: string) => Promise; connectorId: string; model?: string; provider?: string; savedObjectsClient: SavedObjectsClientContract; }): Promise { const prompts = await getPromptsByGroupId({ - actionsClient, - connector, + getInferenceConnectorById, connectorId, model, provider, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/index.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/index.test.ts index 941e9ceb6c171..19de0907de487 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/index.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/index.test.ts @@ -5,9 +5,7 @@ * 2.0. */ -import type { ActionsClient } from '@kbn/actions-plugin/server'; import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server'; -import type { PublicMethodsOf } from '@kbn/utility-types'; import { DefendInsightType } from '@kbn/elastic-assistant-common'; import type { DefendInsightsCombinedPrompts } from '.'; @@ -25,7 +23,6 @@ jest.mock('./policy_response_failure', () => ({ describe('getDefendInsightsPrompt', () => { const mockArgs = { - actionsClient: {} as unknown as PublicMethodsOf, connector: undefined, connectorId: 'mock-connector-id', model: 'mock-model', diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/index.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/index.ts index 6e3798b8d403b..324c8eb695231 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/index.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/index.ts @@ -5,10 +5,8 @@ * 2.0. */ -import type { PublicMethodsOf } from '@kbn/utility-types'; -import type { ActionsClient } from '@kbn/actions-plugin/server'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server'; +import type { InferenceConnector } from '@kbn/inference-common'; import { DefendInsightType } from '@kbn/elastic-assistant-common'; import { InvalidDefendInsightTypeError } from '../../../errors'; @@ -39,8 +37,7 @@ export function getDefendInsightsPrompt({ ...args }: { type: DefendInsightType; - actionsClient: PublicMethodsOf; - connector?: Connector; + getInferenceConnectorById?: (id: string) => Promise; connectorId: string; model?: string; provider?: string; diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/policy_response_failure.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/policy_response_failure.test.ts index fad2369032219..5b30a907a2b6e 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/policy_response_failure.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/policy_response_failure.test.ts @@ -5,9 +5,7 @@ * 2.0. */ -import type { ActionsClient } from '@kbn/actions-plugin/server'; import type { SavedObjectsClientContract } from '@kbn/core/server'; -import type { PublicMethodsOf } from '@kbn/utility-types'; import { promptDictionary, getPromptsByGroupId } from '../../../../prompt'; import { getPolicyResponseFailurePrompt } from './policy_response_failure'; import { promptGroupId } from '../../../../prompt/local_prompt_object'; @@ -23,7 +21,6 @@ jest.mock('../../../../prompt', () => { const mockGetPromptsByGroupId = getPromptsByGroupId as jest.Mock; describe('getPolicyResponseFailurePrompt', () => { - const actionsClient = {} as jest.Mocked>; const savedObjectsClient = {} as jest.Mocked; beforeEach(() => { @@ -78,7 +75,6 @@ describe('getPolicyResponseFailurePrompt', () => { it('should return all prompts', async () => { const result = await getPolicyResponseFailurePrompt({ - actionsClient, connectorId: 'test-connector-id', savedObjectsClient, model: '4', @@ -126,7 +122,6 @@ describe('getPolicyResponseFailurePrompt', () => { mockGetPromptsByGroupId.mockResolvedValue([]); const result = await getPolicyResponseFailurePrompt({ - actionsClient, connectorId: 'test-connector-id', savedObjectsClient, }); diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/policy_response_failure.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/policy_response_failure.ts index ffba144845f77..c1359328ac132 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/policy_response_failure.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/defend_insights/graphs/default_defend_insights_graph/prompts/policy_response_failure.ts @@ -5,33 +5,28 @@ * 2.0. */ -import type { ActionsClient } from '@kbn/actions-plugin/server'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; -import type { PublicMethodsOf } from '@kbn/utility-types'; import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server'; +import type { InferenceConnector } from '@kbn/inference-common'; import type { DefendInsightsCombinedPrompts } from '.'; import { promptGroupId } from '../../../../prompt/local_prompt_object'; import { promptDictionary, getPromptsByGroupId } from '../../../../prompt'; export async function getPolicyResponseFailurePrompt({ - actionsClient, - connector, + getInferenceConnectorById, connectorId, model, provider, savedObjectsClient, }: { - actionsClient: PublicMethodsOf; - connector?: Connector; + getInferenceConnectorById?: (id: string) => Promise; connectorId: string; model?: string; provider?: string; savedObjectsClient: SavedObjectsClientContract; }): Promise { const prompts = await getPromptsByGroupId({ - actionsClient, - connector, + getInferenceConnectorById, connectorId, model, provider, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts index 13da7e492f077..b249f745abbe7 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts @@ -12,9 +12,8 @@ import type { Logger } from '@kbn/logging'; import type { BaseChatModel } from '@langchain/core/language_models/chat_models'; import type { ContentReferencesStore } from '@kbn/elastic-assistant-common'; -import type { PublicMethodsOf } from '@kbn/utility-types'; -import type { ActionsClient } from '@kbn/actions-plugin/server'; import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server'; +import type { InferenceConnector } from '@kbn/inference-common'; import { ToolNode } from '@langchain/langgraph/prebuilt'; import type { AgentState, NodeParamsBase } from './types'; @@ -26,7 +25,7 @@ import { AssistantStateAnnotation } from './state'; export const DEFAULT_ASSISTANT_GRAPH_ID = 'Default Security Assistant Graph'; export interface GetDefaultAssistantGraphParams { - actionsClient: PublicMethodsOf; + getInferenceConnectorById: (id: string) => Promise; createLlmInstance: () => Promise; logger: Logger; savedObjectsClient: SavedObjectsClientContract; @@ -39,7 +38,7 @@ export interface GetDefaultAssistantGraphParams { export type DefaultAssistantGraph = Awaited>; export const getDefaultAssistantGraph = async ({ - actionsClient, + getInferenceConnectorById, checkpointSaver, contentReferencesStore, createLlmInstance, @@ -52,7 +51,7 @@ export const getDefaultAssistantGraph = async ({ try { // Default node parameters const nodeParams: NodeParamsBase = { - actionsClient, + getInferenceConnectorById, logger, savedObjectsClient, contentReferencesStore, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index 03c4dd78b49ef..f68ad242c1a77 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -77,48 +77,49 @@ export const callAssistantGraph: AgentExecutor = async ({ * the state unintentionally. For this reason, only call createLlmInstance at runtime */ const createLlmInstance = async () => { - const connector = await actionsClient.get({ id: connectorId }); - const defaultModel = connector?.config?.defaultModel; - return !inferenceChatModelDisabled - ? inference.getChatModel({ - request, - connectorId, - chatModelOptions: { - model: request.body.model, - signal: abortSignal, - temperature: getDefaultArguments(llmType).temperature, - // prevents the agent from retrying on failure - // failure could be due to bad connector, we should deliver that result to the client asap - maxRetries: 0, - telemetryMetadata: { - pluginId: 'security_ai_assistant', - }, - // TODO add timeout to inference once resolved https://github.com/elastic/kibana/issues/221318 - // timeout, - }, - }) - : new llmClass({ - actionsClient, - connectorId, - llmType, - logger, - // possible client model override, - // let this be undefined otherwise so the connector handles the model - model: request.body.model ?? defaultModel, - // ensure this is defined because we default to it in the language_models - // This is where the LangSmith logs (Metadata > Invocation Params) are set - temperature: getDefaultArguments(llmType).temperature, + if (!inferenceChatModelDisabled) { + return inference.getChatModel({ + request, + connectorId, + chatModelOptions: { + model: request.body.model, signal: abortSignal, - streaming: isStream, + temperature: getDefaultArguments(llmType).temperature, // prevents the agent from retrying on failure // failure could be due to bad connector, we should deliver that result to the client asap maxRetries: 0, - convertSystemMessageToHumanContent: false, - timeout, telemetryMetadata: { pluginId: 'security_ai_assistant', }, - }); + // TODO add timeout to inference once resolved https://github.com/elastic/kibana/issues/221318 + // timeout, + }, + }); + } + const connector = await actionsClient.get({ id: connectorId }); + const defaultModel = connector?.config?.defaultModel; + return new llmClass({ + actionsClient, + connectorId, + llmType, + logger, + // possible client model override, + // let this be undefined otherwise so the connector handles the model + model: request.body.model ?? defaultModel, + // ensure this is defined because we default to it in the language_models + // This is where the LangSmith logs (Metadata > Invocation Params) are set + temperature: getDefaultArguments(llmType).temperature, + signal: abortSignal, + streaming: isStream, + // prevents the agent from retrying on failure + // failure could be due to bad connector, we should deliver that result to the client asap + maxRetries: 0, + convertSystemMessageToHumanContent: false, + timeout, + telemetryMetadata: { + pluginId: 'security_ai_assistant', + }, + }); }; const anonymizationFieldsRes = @@ -165,7 +166,7 @@ export const callAssistantGraph: AgentExecutor = async ({ let description: string | undefined; try { description = await getPrompt({ - actionsClient, + getInferenceConnectorById: (id) => inference.getConnectorById(id, request), connectorId, localPrompts: localToolPrompts, model: getModelOrOss(llmType, isOssModel, request.body.model), @@ -213,7 +214,7 @@ export const callAssistantGraph: AgentExecutor = async ({ } const defaultSystemPrompt = await localGetPrompt({ - actionsClient, + getInferenceConnectorById: (id) => inference.getConnectorById(id, request), connectorId, model: getModelOrOss(llmType, isOssModel, request.body.model), promptId: promptDictionary.systemPrompt, @@ -230,7 +231,7 @@ export const callAssistantGraph: AgentExecutor = async ({ // we need to pass it like this or streaming does not work for bedrock createLlmInstance, logger, - actionsClient, + getInferenceConnectorById: (id) => inference.getConnectorById(id, request), savedObjectsClient, tools, // some chat models (bedrock) require a signal to be passed on agent invoke rather than the signal passed to the chat model @@ -258,7 +259,7 @@ export const callAssistantGraph: AgentExecutor = async ({ screenContextTimezone: screenContext?.timeZone, uiSettingsDateFormatTimezone, }), - actionsClient, + getInferenceConnectorById: (id) => inference.getConnectorById(id, request), savedObjectsClient, connectorId, llmType, @@ -280,7 +281,7 @@ export const callAssistantGraph: AgentExecutor = async ({ void (async () => { const model = await createLlmInstance(); await generateChatTitle({ - actionsClient, + getInferenceConnectorById: (id) => inference.getConnectorById(id, request), contentReferencesStore, conversationsDataClient: dataClients?.conversationsDataClient, logger, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts index 608299d403300..5ea7134146e8a 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts @@ -47,7 +47,7 @@ export interface GenerateChatTitleParams extends NodeParamsBase { } export async function generateChatTitle({ - actionsClient, + getInferenceConnectorById, conversationsDataClient, logger, savedObjectsClient, @@ -77,7 +77,7 @@ export async function generateChatTitle({ const outputParser = new StringOutputParser(); const prompt = await getPrompt({ - actionsClient, + getInferenceConnectorById, connectorId: state.connectorId, promptId: promptDictionary.chatTitle, promptGroupId: promptGroupId.aiAssistant, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.test.ts index c2811027cce0c..1fc6128996c18 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.test.ts @@ -8,10 +8,8 @@ import { DEFAULT_ASSISTANT_GRAPH_PROMPT_TEMPLATE, chatPromptFactory } from './prompts'; import { AIMessage, HumanMessage, SystemMessage } from '@langchain/core/messages'; import { loggingSystemMock } from '@kbn/core/server/mocks'; -import type { ActionsClient } from '@kbn/actions-plugin/server'; import type { SavedObjectsClientContract } from '@kbn/core/server'; import type { AIAssistantKnowledgeBaseDataClient } from '../../../../ai_assistant_data_clients/knowledge_base'; -import type { PublicMethodsOf } from '@kbn/utility-types'; import { newContentReferencesStore } from '@kbn/elastic-assistant-common'; import { newContentReferencesStoreMock } from '@kbn/elastic-assistant-common/impl/content_references/content_references_store/__mocks__/content_references_store.mock'; @@ -34,7 +32,6 @@ describe('chatPromptFactory', () => { }, ]), } as unknown as AIAssistantKnowledgeBaseDataClient; - const mockActionsClient = {} as unknown as PublicMethodsOf; const mockSavedObjectsClient = {} as unknown as SavedObjectsClientContract; const baseInputs = { @@ -49,7 +46,7 @@ describe('chatPromptFactory', () => { ], logger: loggingSystemMock.createLogger(), formattedTime: '2023-10-01T00:00:00Z', - actionsClient: mockActionsClient, + getInferenceConnectorById: jest.fn(), savedObjectsClient: mockSavedObjectsClient, connectorId: 'test-connector-id', llmType: 'gemini', diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts index 221bd05b1a6f8..b93c1e2bf767d 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/prompts.ts @@ -10,9 +10,8 @@ import type { BaseMessage } from '@langchain/core/messages'; import type { ContentReferencesStore, DocumentEntry } from '@kbn/elastic-assistant-common'; import { enrichDocument } from '@kbn/elastic-assistant-common'; import type { Logger } from '@kbn/logging'; -import type { PublicMethodsOf } from '@kbn/utility-types'; import type { SavedObjectsClientContract } from '@kbn/core/server'; -import type { ActionsClient } from '@kbn/actions-plugin/server'; +import type { InferenceConnector } from '@kbn/inference-common'; import type { ChatPromptValueInterface } from '@langchain/core/prompt_values'; import { enrichConversation } from '../../utils/enrich_graph_input_messages'; import type { AIAssistantKnowledgeBaseDataClient } from '../../../../ai_assistant_data_clients/knowledge_base'; @@ -32,7 +31,7 @@ interface Inputs { conversationMessages: BaseMessage[]; logger: Logger; formattedTime: string; - actionsClient: PublicMethodsOf; + getInferenceConnectorById: (id: string) => Promise; savedObjectsClient: SavedObjectsClientContract; connectorId: string; llmType: string | undefined; @@ -79,7 +78,7 @@ export const chatPromptFactory = async ( }); const enrichedMessages = await enrichConversation({ - actionsClient: inputs.actionsClient, + getInferenceConnectorById: inputs.getInferenceConnectorById, savedObjectsClient: inputs.savedObjectsClient, connectorId: inputs.connectorId, llmType: inputs.llmType, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts index 37faa231ef33e..136f6eccb9f82 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts @@ -8,9 +8,8 @@ import type { BaseMessage } from '@langchain/core/messages'; import type { Logger } from '@kbn/logging'; import type { ContentReferencesStore } from '@kbn/elastic-assistant-common'; -import type { PublicMethodsOf } from '@kbn/utility-types'; -import type { ActionsClient } from '@kbn/actions-plugin/server'; import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server'; +import type { InferenceConnector } from '@kbn/inference-common'; import type { AssistantStateAnnotation } from './state'; export interface GraphInputs { @@ -31,7 +30,7 @@ export interface GraphInputs { export type AgentState = typeof AssistantStateAnnotation.State; export interface NodeParamsBase { - actionsClient: PublicMethodsOf; + getInferenceConnectorById: (id: string) => Promise; logger: Logger; savedObjectsClient: SavedObjectsClientContract; contentReferencesStore: ContentReferencesStore; diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/utils/enrich_graph_input_messages.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/utils/enrich_graph_input_messages.ts index 803117d884f49..7708bd927fa80 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/utils/enrich_graph_input_messages.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/utils/enrich_graph_input_messages.ts @@ -8,8 +8,7 @@ import type { BaseMessage } from '@langchain/core/messages'; import { HumanMessage } from '@langchain/core/messages'; import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server'; -import type { PublicMethodsOf } from '@kbn/utility-types'; -import type { ActionsClient } from '@kbn/actions-plugin/server'; +import type { InferenceConnector } from '@kbn/inference-common'; import { promptGroupId } from '../../prompt/local_prompt_object'; import { getPrompt, promptDictionary } from '../../prompt'; @@ -17,7 +16,7 @@ interface Params { llmType?: string; connectorId: string; savedObjectsClient: SavedObjectsClientContract; - actionsClient: PublicMethodsOf; + getInferenceConnectorById: (id: string) => Promise; messages: BaseMessage[]; } @@ -33,13 +32,16 @@ export const enrichConversation = (params: Params) => { * Prepends the user prompt to the last message if the last message is a human message. */ const getUserPrompt = ( - params: Pick + params: Pick< + Params, + 'getInferenceConnectorById' | 'savedObjectsClient' | 'connectorId' | 'llmType' + > ) => { return async (messages: BaseMessage[]): Promise => { const userPrompt = params.llmType === 'gemini' ? await getPrompt({ - actionsClient: params.actionsClient, + getInferenceConnectorById: params.getInferenceConnectorById, connectorId: params.connectorId, promptId: promptDictionary.userPrompt, promptGroupId: promptGroupId.aiAssistant, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/prompt/get_prompt.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/prompt/get_prompt.ts index ea58e116ca475..7170238c1140b 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/prompt/get_prompt.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/prompt/get_prompt.ts @@ -12,6 +12,8 @@ import { type PromptArray, type GetPromptsByGroupIdArgs, } from '@kbn/security-ai-prompts'; +import type { KibanaRequest } from '@kbn/core/server'; +import type { InferenceServerStart } from '@kbn/inference-plugin/server'; import { localPrompts } from './local_prompt_object'; export const getPromptsByGroupId = async ( @@ -23,3 +25,7 @@ export const getPromptsByGroupId = async ( export const getPrompt = async (args: Omit): Promise => { return _getPrompt({ ...args, localPrompts }); }; + +export const getInferenceConnectorById = + (inference: InferenceServerStart, request: KibanaRequest) => (id: string) => + inference.getConnectorById(id, request); diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/prompt/index.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/prompt/index.ts index c27ef7803414b..1a39ab0ef40b1 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/prompt/index.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/prompt/index.ts @@ -5,5 +5,5 @@ * 2.0. */ -export { getPrompt, getPromptsByGroupId } from './get_prompt'; +export { getPrompt, getPromptsByGroupId, getInferenceConnectorById } from './get_prompt'; export { promptDictionary } from './local_prompt_object'; diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/alert_summary/find_route.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/alert_summary/find_route.ts index bb8c4977fa04e..a3433a3f2c7fb 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/alert_summary/find_route.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/alert_summary/find_route.ts @@ -16,7 +16,7 @@ import type { FindAlertSummaryResponse } from '@kbn/elastic-assistant-common/imp import { FindAlertSummaryRequestQuery } from '@kbn/elastic-assistant-common/impl/schemas'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; import _ from 'lodash'; -import { getPrompt, promptDictionary } from '../../lib/prompt'; +import { getPrompt, getInferenceConnectorById, promptDictionary } from '../../lib/prompt'; import type { ElasticAssistantPluginRouter } from '../../types'; import { buildResponse } from '../utils'; import type { EsAlertSummarySchema } from '../../ai_assistant_data_clients/alert_summary/types'; @@ -60,8 +60,6 @@ export const findAlertSummaryRoute = (router: ElasticAssistantPluginRouter, logg return checkResponse.response; } const dataClient = await ctx.elasticAssistant.getAlertSummaryDataClient(); - const actions = ctx.elasticAssistant.actions; - const actionsClient = await actions.getActionsClientWithRequest(request); const savedObjectsClient = ctx.elasticAssistant.savedObjectsClient; const result = await dataClient?.findDocuments({ perPage: query.per_page, @@ -72,7 +70,10 @@ export const findAlertSummaryRoute = (router: ElasticAssistantPluginRouter, logg fields: query.fields?.map((f) => _.snakeCase(f)), }); const prompt = await getPrompt({ - actionsClient, + getInferenceConnectorById: getInferenceConnectorById( + ctx.elasticAssistant.inference, + request + ), connectorId: query.connector_id, promptId: promptDictionary.alertSummary, promptGroupId: promptGroupId.aiForSoc, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/attack_discovery/public/post/helpers/invoke_attack_discovery_graph/index.tsx b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/attack_discovery/public/post/helpers/invoke_attack_discovery_graph/index.tsx index 51b06d88bc85d..52a4e0ae2bd61 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/attack_discovery/public/post/helpers/invoke_attack_discovery_graph/index.tsx +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/attack_discovery/public/post/helpers/invoke_attack_discovery_graph/index.tsx @@ -100,7 +100,6 @@ export const invokeAttackDiscoveryGraph = async ({ } const attackDiscoveryPrompts = await getAttackDiscoveryPrompts({ - actionsClient, connectorId: apiConfig.connectorId, // if in future oss has different prompt, add it as model here model, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts index 9121a032de431..322d1cb61e476 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts @@ -20,6 +20,7 @@ import { getFindAnonymizationFieldsResultWithSingleHit, } from '../../__mocks__/response'; import { defaultAssistantFeatures } from '@kbn/elastic-assistant-common'; +import { InferenceConnectorType } from '@kbn/inference-common'; import { chatCompleteRoute } from './chat_complete_route'; import { licensingMock } from '@kbn/licensing-plugin/server/mocks'; import { @@ -72,6 +73,22 @@ const mockContext = { getRegisteredFeatures: jest.fn(() => defaultAssistantFeatures), logger: loggingSystemMock.createLogger(), telemetry: { ...coreMock.createSetup().analytics, reportEvent }, + inference: { + getConnectorById: jest.fn().mockImplementation((id: string) => { + if (id === 'mock-connector-id') { + return Promise.resolve({ + connectorId: 'mock-connector-id', + type: InferenceConnectorType.OpenAI, + name: 'mock connector', + config: {}, + capabilities: {}, + isInferenceEndpoint: false, + isPreconfigured: false, + }); + } + return Promise.resolve(undefined); + }), + }, llmTasks: { retrieveDocumentationAvailable: jest.fn(), retrieveDocumentation: jest.fn() }, getCurrentUser: () => ({ username: 'user', @@ -403,7 +420,7 @@ describe('chatCompleteRoute', () => { expect(reportEvent).toHaveBeenCalledWith(INVOKE_ASSISTANT_ERROR_EVENT.eventType, { errorMessage: 'simulated error', errorLocation: 'chatCompleteRoute', - actionTypeId: '.gen-ai', + actionTypeId: '.inference', model: 'gpt-4', assistantStreamingEnabled: false, isEnabledKnowledgeBase: false, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index cd61aece0e798..a088ede212d20 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -22,6 +22,7 @@ import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/ import { getRequestAbortedSignal } from '@kbn/data-plugin/server'; import { defaultInferenceEndpoints } from '@kbn/inference-common'; import { v4 as uuidv4 } from 'uuid'; +import { getInferenceConnectorById } from '../../lib/prompt'; import { INVOKE_ASSISTANT_ERROR_EVENT } from '../../lib/telemetry/event_based_telemetry'; import type { ElasticAssistantPluginRouter } from '../../types'; import { buildResponse } from '../../lib/build_response'; @@ -124,10 +125,12 @@ export const chatCompleteRoute = ( // get the actions plugin start contract from the request context: const actions = ctx.elasticAssistant.actions; const actionsClient = await actions.getActionsClientWithRequest(request); - const connectors = await actionsClient.getBulk({ ids: [connectorId] }); - const connector = connectors.length > 0 ? connectors[0] : undefined; - actionTypeId = connector?.actionTypeId ?? '.gen-ai'; - const isOssModel = isOpenSourceModel(connector); + const inferenceConnector = await getInferenceConnectorById( + inference, + request + )(connectorId).catch(() => undefined); + actionTypeId = inferenceConnector?.type ?? '.inference'; + const isOssModel = isOpenSourceModel(inferenceConnector); const savedObjectsClient = ctx.elasticAssistant.savedObjectsClient; // replacements diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/defend_insights/helpers.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/defend_insights/helpers.ts index bdf8fe6f609d3..1338041ee3083 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/defend_insights/helpers.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/defend_insights/helpers.ts @@ -27,6 +27,7 @@ import type { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/i import type { ActionsClient } from '@kbn/actions-plugin/server'; import type { Moment } from 'moment'; import type { PublicMethodsOf } from '@kbn/utility-types'; +import type { InferenceConnector } from '@kbn/inference-common'; import moment from 'moment'; import { ActionsClientLlm } from '@kbn/langchain/server'; import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith'; @@ -421,6 +422,7 @@ export const invokeDefendInsightsGraph = async ({ insightType, endpointIds, actionsClient, + getInferenceConnectorById, anonymizationFields, apiConfig, connectorTimeout, @@ -439,6 +441,7 @@ export const invokeDefendInsightsGraph = async ({ insightType: DefendInsightType; endpointIds: string[]; actionsClient: PublicMethodsOf; + getInferenceConnectorById: (id: string) => Promise; anonymizationFields: AnonymizationFieldResponse[]; apiConfig: ApiConfig; connectorTimeout: number; @@ -495,7 +498,7 @@ export const invokeDefendInsightsGraph = async ({ const defendInsightsPrompts = await getDefendInsightsPrompt({ type: insightType, - actionsClient, + getInferenceConnectorById, connectorId: apiConfig.connectorId, model, provider: llmType, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/defend_insights/post_defend_insights.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/defend_insights/post_defend_insights.ts index 8ad7c75ca27ca..48965d158a291 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/defend_insights/post_defend_insights.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/defend_insights/post_defend_insights.ts @@ -147,6 +147,8 @@ export const postDefendInsightsRoute = (router: IRouter + assistantContext.inference.getConnectorById(id, request), anonymizationFields, apiConfig, connectorTimeout: CONNECTOR_TIMEOUT, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index 46b0a6a7092a0..ea0d89d621c4e 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -31,7 +31,7 @@ import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/ import { getDefaultArguments } from '@kbn/langchain/server'; import type { StructuredTool } from '@langchain/core/tools'; import { omit } from 'lodash/fp'; -import { defaultInferenceEndpoints } from '@kbn/inference-common'; +import { defaultInferenceEndpoints, getConnectorDefaultModel } from '@kbn/inference-common'; import { HumanMessage } from '@langchain/core/messages'; import { evaluateDefendInsights } from '../../lib/defend_insights/evaluation'; import { localToolPrompts, promptGroupId as toolsGroupId } from '../../lib/prompt/tool_prompts'; @@ -42,7 +42,11 @@ import { DEFAULT_ASSISTANT_GRAPH_PROMPT_TEMPLATE, chatPromptFactory, } from '../../lib/langchain/graphs/default_assistant_graph/prompts'; -import { getPrompt as localGetPrompt, promptDictionary } from '../../lib/prompt'; +import { + getPrompt as localGetPrompt, + getInferenceConnectorById, + promptDictionary, +} from '../../lib/prompt'; import { buildResponse } from '../../lib/build_response'; import type { AssistantDataClients } from '../../lib/langchain/executors/types'; import type { AssistantToolParams, ElasticAssistantRequestHandlerContext } from '../../types'; @@ -203,10 +207,9 @@ export const postEvaluateRoute = ( // Actions const actionsClient = await actions.getActionsClientWithRequest(request); - const connectors = await actionsClient.getBulk({ - ids: connectorIds, - throwIfSystemAction: false, - }); + const connectors = await Promise.all( + connectorIds.map(getInferenceConnectorById(inference, request)) + ); // Fetch any tools registered to the security assistant const assistantTools = assistantContext.getRegisteredTools(DEFAULT_PLUGIN_NAME); @@ -228,6 +231,7 @@ export const postEvaluateRoute = ( try { void evaluateDefendInsights({ actionsClient, + getInferenceConnectorById: getInferenceConnectorById(inference, request), defendInsightsGraphs, connectors, connectorTimeout: RESPONSE_TIMEOUT, @@ -257,9 +261,8 @@ export const postEvaluateRoute = ( const connectorsWithPrompts = await Promise.all( connectors.map(async (connector) => { const prompts = await getAttackDiscoveryPrompts({ - actionsClient, - connectorId: connector.id, - connector, + getInferenceConnectorById: getInferenceConnectorById(inference, request), + connectorId: connector.connectorId, savedObjectsClient, }); return { @@ -283,6 +286,7 @@ export const postEvaluateRoute = ( esClientInternalUser, evaluationId, evaluatorConnectorId, + getInferenceConnectorById: getInferenceConnectorById(inference, request), langSmithApiKey, langSmithProject, logger, @@ -308,14 +312,14 @@ export const postEvaluateRoute = ( contentReferencesStore: ContentReferencesStore; }> = await Promise.all( connectors.map(async (connector) => { - const llmType = getLlmType(connector.actionTypeId); + const llmType = getLlmType(connector.type); const isOssModel = isOpenSourceModel(connector); const llmClass = getLlmClass(llmType); const createLlmInstance = async () => !inferenceChatModelDisabled ? inference.getChatModel({ request, - connectorId: connector.id, + connectorId: connector.connectorId, chatModelOptions: { signal: abortSignal, temperature: getDefaultArguments(llmType).temperature, @@ -331,10 +335,10 @@ export const postEvaluateRoute = ( }) : new llmClass({ actionsClient, - connectorId: connector.id, + connectorId: connector.connectorId, llmType, logger, - model: connector.config?.defaultModel, + model: getConnectorDefaultModel(connector), temperature: getDefaultArguments(llmType).temperature, signal: abortSignal, streaming: false, @@ -399,7 +403,7 @@ export const postEvaluateRoute = ( replacements, contentReferencesStore, inference, - connectorId: connector.id, + connectorId: connector.connectorId, size, telemetry: ctx.elasticAssistant.telemetry, ...(productDocsAvailable ? { llmTasks: ctx.elasticAssistant.llmTasks } : {}), @@ -412,9 +416,8 @@ export const postEvaluateRoute = ( let description: string | undefined; try { description = await getPrompt({ - actionsClient, - connector, - connectorId: connector.id, + getInferenceConnectorById: getInferenceConnectorById(inference, request), + connectorId: connector.connectorId, model: getModelOrOss(llmType, isOssModel), localPrompts: localToolPrompts, promptId: tool.name, @@ -437,7 +440,7 @@ export const postEvaluateRoute = ( ).filter((e) => e != null) as StructuredTool[]; return { - connectorId: connector.id, + connectorId: connector.connectorId, name: `${runName} - ${connector.name}`, llmType, isOssModel, @@ -445,8 +448,8 @@ export const postEvaluateRoute = ( graph: await getDefaultAssistantGraph({ contentReferencesStore, createLlmInstance, + getInferenceConnectorById: getInferenceConnectorById(inference, request), logger, - actionsClient, savedObjectsClient, tools, checkpointSaver: await assistantContext.getCheckpointSaver(), @@ -464,7 +467,7 @@ export const postEvaluateRoute = ( logger.debug(`input:\n ${JSON.stringify(evaluationInput, null, 2)}`); const defaultSystemPrompt = await localGetPrompt({ - actionsClient, + getInferenceConnectorById: getInferenceConnectorById(inference, request), connectorId, model: getModelOrOss(llmType, isOssModel), promptId: promptDictionary.systemPrompt, @@ -487,7 +490,7 @@ export const postEvaluateRoute = ( ), screenContextTimezone: 'UTC', }), - actionsClient, + getInferenceConnectorById: getInferenceConnectorById(inference, request), savedObjectsClient, connectorId, llmType, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts index 50bef939f8b53..bb6817e9ed15e 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts @@ -24,7 +24,7 @@ import type { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common' import { defaultAssistantFeatures } from '@kbn/elastic-assistant-common'; import { licensingMock } from '@kbn/licensing-plugin/server/mocks'; import { appendAssistantMessageToConversation, langChainExecute } from './helpers'; -import { getPrompt } from '../lib/prompt'; +import { getPrompt, getInferenceConnectorById } from '../lib/prompt'; import { defaultInferenceEndpoints } from '@kbn/inference-common'; import expect from 'expect'; import { createMockConnector } from '@kbn/actions-plugin/server/application/connector/mocks'; @@ -36,6 +36,7 @@ jest.mock('../lib/build_response', () => ({ })); jest.mock('../lib/prompt'); const mockGetPrompt = getPrompt as jest.Mock; +const mockGetInferenceConnectorById = getInferenceConnectorById as jest.Mock; const mockStream = jest.fn().mockImplementation(() => new PassThrough()); const mockLangChainExecute = langChainExecute as jest.Mock; @@ -61,6 +62,9 @@ const mockContext = { actions: { getActionsClientWithRequest: jest.fn().mockResolvedValue(actionsClient), }, + inference: { + getConnectorById: jest.fn().mockResolvedValue(undefined), + }, llmTasks: { retrieveDocumentationAvailable: jest.fn(), retrieveDocumentation: jest.fn() }, getRegisteredTools: jest.fn(() => []), getRegisteredFeatures: jest.fn(() => defaultAssistantFeatures), @@ -188,6 +192,7 @@ describe('postActionsConnectorExecuteRoute', () => { }, }), ]); + mockGetInferenceConnectorById.mockReturnValue(() => Promise.resolve(undefined)); }); it('returns the expected response', async () => { diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 5ad0b199f8c08..feb9629805180 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -25,7 +25,7 @@ import { } from '@kbn/elastic-assistant-common'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; import { defaultInferenceEndpoints } from '@kbn/inference-common'; -import { getPrompt } from '../lib/prompt'; +import { getPrompt, getInferenceConnectorById } from '../lib/prompt'; import { INVOKE_ASSISTANT_ERROR_EVENT } from '../lib/telemetry/event_based_telemetry'; import { buildResponse } from '../lib/build_response'; import type { ElasticAssistantRequestHandlerContext } from '../types'; @@ -138,9 +138,11 @@ export const postActionsConnectorExecuteRoute = ( inferenceId: defaultInferenceEndpoints.ELSER, })) ?? false; const actionsClient = await actions.getActionsClientWithRequest(request); - const connectors = await actionsClient.getBulk({ ids: [connectorId] }); - const connector = connectors.length > 0 ? connectors[0] : undefined; - const isOssModel = isOpenSourceModel(connector); + const inferenceConnector = await getInferenceConnectorById( + inference, + request + )(connectorId).catch(() => undefined); + const isOssModel = isOpenSourceModel(inferenceConnector); const conversationsDataClient = await assistantContext.getAIAssistantConversationsDataClient({ @@ -206,7 +208,7 @@ export const postActionsConnectorExecuteRoute = ( } if (promptIds) { const additionalSystemPrompt = await getPrompt({ - actionsClient, + getInferenceConnectorById: getInferenceConnectorById(inference, request), connectorId, // promptIds is promptId and promptGroupId ...promptIds, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/security_ai_prompts/find_prompts.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/security_ai_prompts/find_prompts.ts index ae0fd5d1d497b..e316529d63843 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/security_ai_prompts/find_prompts.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/security_ai_prompts/find_prompts.ts @@ -15,7 +15,7 @@ import { FindSecurityAIPromptsResponse, } from '@kbn/elastic-assistant-common'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; -import { getPromptsByGroupId } from '../../lib/prompt'; +import { getPromptsByGroupId, getInferenceConnectorById } from '../../lib/prompt'; import type { ElasticAssistantPluginRouter } from '../../types'; import { buildResponse } from '../utils'; import { performChecks } from '../helpers'; @@ -64,12 +64,13 @@ export const findSecurityAIPromptsRoute = (router: ElasticAssistantPluginRouter, if (!checkResponse.isSuccess) { return checkResponse.response; } - const actions = ctx.elasticAssistant.actions; - const actionsClient = await actions.getActionsClientWithRequest(request); const savedObjectsClient = ctx.elasticAssistant.savedObjectsClient; const prompts = await getPromptsByGroupId({ - actionsClient, + getInferenceConnectorById: getInferenceConnectorById( + ctx.elasticAssistant.inference, + request + ), connectorId: query.connector_id, promptGroupId: query.prompt_group_id, promptIds: query.prompt_ids, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/utils.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/utils.test.ts index 0a33d8b8ed9c7..78414f43ece6f 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/utils.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/utils.test.ts @@ -5,9 +5,18 @@ * 2.0. */ -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; import { isOpenSourceModel } from './utils'; import { OPENAI_CHAT_URL, OpenAiProviderType } from '@kbn/connector-schemas/openai/constants'; +import { InferenceConnectorType } from '@kbn/inference-common'; +import type { InferenceConnector } from '@kbn/inference-common'; + +const baseConnector: Omit = { + name: 'test', + connectorId: 'test-id', + capabilities: {}, + isInferenceEndpoint: false, + isPreconfigured: false, +}; describe('Utils', () => { describe('isOpenSourceModel', () => { @@ -17,60 +26,120 @@ describe('Utils', () => { }); it('should return `false` when connector is a Bedrock', async () => { - const connector = { actionTypeId: '.bedrock' } as Connector; + const connector = { ...baseConnector, type: InferenceConnectorType.Bedrock, config: {} }; const isOpenModel = isOpenSourceModel(connector); expect(isOpenModel).toEqual(false); }); it('should return `false` when connector is a Gemini', async () => { - const connector = { actionTypeId: '.gemini' } as Connector; + const connector = { ...baseConnector, type: InferenceConnectorType.Gemini, config: {} }; const isOpenModel = isOpenSourceModel(connector); expect(isOpenModel).toEqual(false); }); it('should return `false` when connector is a OpenAI and API url is not specified', async () => { - const connector = { - actionTypeId: '.gen-ai', - } as unknown as Connector; + const connector = { ...baseConnector, type: InferenceConnectorType.OpenAI, config: {} }; const isOpenModel = isOpenSourceModel(connector); expect(isOpenModel).toEqual(false); }); it('should return `false` when connector is a OpenAI and OpenAI API url is specified', async () => { const connector = { - actionTypeId: '.gen-ai', + ...baseConnector, + type: InferenceConnectorType.OpenAI, config: { apiUrl: OPENAI_CHAT_URL }, - } as unknown as Connector; + }; const isOpenModel = isOpenSourceModel(connector); expect(isOpenModel).toEqual(false); }); it('should return `false` when connector is a AzureOpenAI', async () => { const connector = { - actionTypeId: '.gen-ai', + ...baseConnector, + type: InferenceConnectorType.OpenAI, config: { apiProvider: OpenAiProviderType.AzureAi }, - } as unknown as Connector; + }; const isOpenModel = isOpenSourceModel(connector); expect(isOpenModel).toEqual(false); }); it('should return `true` when connector is a OpenAI and non-OpenAI API url is specified', async () => { const connector = { - actionTypeId: '.gen-ai', + ...baseConnector, + type: InferenceConnectorType.OpenAI, config: { apiUrl: 'https://elastic.llm.com/llama/chat/completions' }, - } as unknown as Connector; + }; const isOpenModel = isOpenSourceModel(connector); expect(isOpenModel).toEqual(true); }); it('should return `true` when apiProvider of OpenAiProviderType.Other is specified', async () => { const connector = { - actionTypeId: '.gen-ai', + ...baseConnector, + type: InferenceConnectorType.OpenAI, + config: { apiUrl: OPENAI_CHAT_URL, apiProvider: OpenAiProviderType.Other }, + }; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(true); + }); + + it('should return `false` when connector is a .inference type with non-openai provider', async () => { + const connector = { + ...baseConnector, + type: InferenceConnectorType.Inference, + config: { provider: 'bedrock', providerConfig: { model_id: 'claude-3' } }, + }; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(false); + }); + + it('should return `false` when connector is a .inference type with openai provider but no custom URL', async () => { + const connector = { + ...baseConnector, + type: InferenceConnectorType.Inference, + config: { provider: 'openai', providerConfig: { model_id: 'gpt-4o' } }, + }; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(false); + }); + + it('should return `true` when connector is a .inference type with openai provider and a custom URL', async () => { + const connector = { + ...baseConnector, + type: InferenceConnectorType.Inference, + config: { + provider: 'openai', + providerConfig: { model_id: 'llama3', url: 'https://my-ollama.internal/v1' }, + }, + }; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(true); + }); + + it('should return `false` when connector is a .inference type with openai service (native endpoint) pointing to api.openai.com', async () => { + const connector = { + ...baseConnector, + type: InferenceConnectorType.Inference, + isInferenceEndpoint: true, + config: { + service: 'openai', + providerConfig: { model_id: 'gpt-4o', url: 'https://api.openai.com/v1' }, + }, + }; + const isOpenModel = isOpenSourceModel(connector); + expect(isOpenModel).toEqual(false); + }); + + it('should return `true` when connector is a .inference native endpoint with openai service and a custom URL', async () => { + const connector = { + ...baseConnector, + type: InferenceConnectorType.Inference, + isInferenceEndpoint: true, config: { - apiUrl: OPENAI_CHAT_URL, - apiProvider: OpenAiProviderType.Other, + service: 'openai', + providerConfig: { model_id: 'llama3', url: 'https://my-ollama.internal/v1' }, }, - } as unknown as Connector; + }; const isOpenModel = isOpenSourceModel(connector); expect(isOpenModel).toEqual(true); }); diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/utils.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/utils.ts index bb744ff8f8d41..a9dc7ad69a62d 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/utils.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/utils.ts @@ -18,8 +18,9 @@ import { ActionsClientChatBedrockConverse, ActionsClientChatVertexAI, } from '@kbn/langchain/server'; -import type { Connector } from '@kbn/actions-plugin/server/application/connector/types'; import { OPENAI_CHAT_URL, OpenAiProviderType } from '@kbn/connector-schemas/openai/constants'; +import { InferenceConnectorType } from '@kbn/inference-common'; +import type { InferenceConnector } from '@kbn/inference-common'; import { CustomHttpRequestError } from './custom_http_request_error'; export interface BulkError { @@ -201,31 +202,34 @@ export const getLlmClass = (llmType?: string) => { } }; -export const isOpenSourceModel = (connector?: Connector): boolean => { +export const isOpenSourceModel = (connector?: InferenceConnector): boolean => { if (connector == null) { return false; } - const llmType = getLlmType(connector.actionTypeId); - const isOpenAiType = llmType === 'openai'; - - if (!isOpenAiType) { - return false; - } - const connectorApiProvider = connector.config?.apiProvider - ? (connector.config?.apiProvider as OpenAiProviderType) - : undefined; - if (connectorApiProvider === OpenAiProviderType.Other) { - return true; + if (connector.type === InferenceConnectorType.OpenAI) { + const connectorApiProvider = connector.config?.apiProvider as OpenAiProviderType | undefined; + if (connectorApiProvider === OpenAiProviderType.Other) { + return true; + } + const connectorApiUrl = connector.config?.apiUrl as string | undefined; + return ( + !!connectorApiUrl && + connectorApiUrl !== OPENAI_CHAT_URL && + connectorApiProvider !== OpenAiProviderType.AzureAi + ); } - const connectorApiUrl = connector.config?.apiUrl - ? (connector.config.apiUrl as string) - : undefined; + if (connector.type === InferenceConnectorType.Inference) { + const provider: string | undefined = connector.config?.provider; + const service: string | undefined = connector.config?.service; + if (provider !== 'openai' && service !== 'openai') { + return false; + } + // A custom URL that doesn't point to api.openai.com indicates a self-hosted/OSS model + const url: string | undefined = connector.config?.providerConfig?.url; + return !!url && !url.includes('api.openai.com'); + } - return ( - !!connectorApiUrl && - connectorApiUrl !== OPENAI_CHAT_URL && - connectorApiProvider !== OpenAiProviderType.AzureAi - ); + return false; }; diff --git a/x-pack/solutions/security/plugins/security_solution/moon.yml b/x-pack/solutions/security/plugins/security_solution/moon.yml index 0986e8c237304..1718cefdded8c 100644 --- a/x-pack/solutions/security/plugins/security_solution/moon.yml +++ b/x-pack/solutions/security/plugins/security_solution/moon.yml @@ -76,6 +76,7 @@ dependsOn: - '@kbn/kibana-react-plugin' - '@kbn/ecs-data-quality-dashboard' - '@kbn/elastic-assistant' + - '@kbn/inference-connectors' - '@kbn/elastic-assistant-plugin' - '@kbn/data-views-plugin' - '@kbn/datemath' diff --git a/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/index.test.tsx b/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/index.test.tsx index a3ad98f2f894a..082ea18bce441 100644 --- a/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/index.test.tsx +++ b/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/index.test.tsx @@ -27,7 +27,7 @@ import { UpsellingProvider } from '../../common/components/upselling_provider'; import { mockFindAnonymizationFieldsResponse } from './mock/mock_find_anonymization_fields_response'; import { ATTACK_DISCOVERY_PAGE_TITLE } from './page_title/translations'; import { useAttackDiscovery } from './use_attack_discovery'; -import { useLoadConnectors } from '@kbn/elastic-assistant/impl/connectorland/use_load_connectors'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { SECURITY_UI_SHOW_PRIVILEGE } from '@kbn/security-solution-features/constants'; import { CALLOUT_TEST_DATA_ID } from './moving_attacks_callout'; import { useMovingAttacksCallout } from './moving_attacks_callout/use_moving_attacks_callout'; @@ -69,7 +69,7 @@ jest.mock( }) ); -jest.mock('@kbn/elastic-assistant/impl/connectorland/use_load_connectors', () => ({ +jest.mock('@kbn/inference-connectors', () => ({ useLoadConnectors: jest.fn(() => ({ isFetched: true, data: mockConnectors, diff --git a/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/index.tsx b/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/index.tsx index 4de9597b8fa4c..3b54dac45467f 100644 --- a/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/index.tsx +++ b/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/index.tsx @@ -19,8 +19,8 @@ import { QUERY_LOCAL_STORAGE_KEY, START_LOCAL_STORAGE_KEY, useAssistantContext, - useLoadConnectors, } from '@kbn/elastic-assistant'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import type { Filter, Query } from '@kbn/es-query'; import React, { useCallback, useEffect, useMemo, useState } from 'react'; import useLocalStorage from 'react-use/lib/useLocalStorage'; @@ -58,10 +58,10 @@ const AttackDiscoveryPageComponent: React.FC = () => { services: { uiSettings, settings }, } = useKibana(); - const { http, inferenceEnabled } = useAssistantContext(); + const { http } = useAssistantContext(); const { data: aiConnectors } = useLoadConnectors({ http, - inferenceEnabled, + featureId: 'attack_discovery', settings, }); diff --git a/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/settings_flyout/schedule/create_flyout/index.test.tsx b/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/settings_flyout/schedule/create_flyout/index.test.tsx index cddac181f3f79..51516114ce368 100644 --- a/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/settings_flyout/schedule/create_flyout/index.test.tsx +++ b/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/settings_flyout/schedule/create_flyout/index.test.tsx @@ -8,7 +8,7 @@ import React from 'react'; import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'; import { triggersActionsUiMock } from '@kbn/triggers-actions-ui-plugin/public/mocks'; -import { useLoadConnectors } from '@kbn/elastic-assistant/impl/connectorland/use_load_connectors'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { CreateFlyout } from '.'; import * as i18n from './translations'; @@ -18,7 +18,7 @@ import { TestProviders } from '../../../../../common/mock/test_providers'; import { useSourcererDataView } from '../../../../../sourcerer/containers'; import { useCreateAttackDiscoverySchedule } from '../logic/use_create_schedule'; -jest.mock('@kbn/elastic-assistant/impl/connectorland/use_load_connectors'); +jest.mock('@kbn/inference-connectors'); jest.mock('../logic/use_create_schedule'); jest.mock('../../../../../common/lib/kibana'); jest.mock('../../../../../sourcerer/containers'); diff --git a/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/settings_flyout/schedule/create_flyout/index.tsx b/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/settings_flyout/schedule/create_flyout/index.tsx index e27ee679acd8d..945757ae9a1b9 100644 --- a/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/settings_flyout/schedule/create_flyout/index.tsx +++ b/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/settings_flyout/schedule/create_flyout/index.tsx @@ -16,7 +16,8 @@ import { keys, useGeneratedHtmlId, } from '@elastic/eui'; -import { useAssistantContext, useLoadConnectors } from '@kbn/elastic-assistant'; +import { useAssistantContext } from '@kbn/elastic-assistant'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import React, { useCallback, useState } from 'react'; import { PageScope } from '../../../../../data_view_manager/constants'; @@ -61,6 +62,7 @@ export const CreateFlyout: React.FC = React.memo(({ onClose }) => { const { alertsIndexPattern, http, settings } = useAssistantContext(); const { data: aiConnectors, isLoading: isLoadingConnectors } = useLoadConnectors({ http, + featureId: 'attack_discovery', settings, }); diff --git a/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/settings_flyout/schedule/details_flyout/index.test.tsx b/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/settings_flyout/schedule/details_flyout/index.test.tsx index aff9a3e4971ac..b0d7412dbea86 100644 --- a/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/settings_flyout/schedule/details_flyout/index.test.tsx +++ b/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/settings_flyout/schedule/details_flyout/index.test.tsx @@ -8,7 +8,7 @@ import React from 'react'; import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'; import { triggersActionsUiMock } from '@kbn/triggers-actions-ui-plugin/public/mocks'; -import { useLoadConnectors } from '@kbn/elastic-assistant/impl/connectorland/use_load_connectors'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { DetailsFlyout } from '.'; @@ -21,7 +21,7 @@ import { mockAttackDiscoverySchedule } from '../../../mock/mock_attack_discovery import { ATTACK_DISCOVERY_FEATURE_ID } from '../../../../../../common/constants'; import { waitForEuiToolTipVisible } from '@elastic/eui/lib/test/rtl'; -jest.mock('@kbn/elastic-assistant/impl/connectorland/use_load_connectors'); +jest.mock('@kbn/inference-connectors'); jest.mock('../logic/use_update_schedule'); jest.mock('../logic/use_get_schedule'); jest.mock('../../../../../common/lib/kibana'); diff --git a/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/settings_flyout/schedule/details_flyout/index.tsx b/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/settings_flyout/schedule/details_flyout/index.tsx index e244bc26b0522..b0f3b63e7dd1b 100644 --- a/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/settings_flyout/schedule/details_flyout/index.tsx +++ b/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/settings_flyout/schedule/details_flyout/index.tsx @@ -22,7 +22,8 @@ import { } from '@elastic/eui'; import { css } from '@emotion/react'; import type { RuleAction } from '@kbn/alerting-types'; -import { useAssistantContext, useLoadConnectors } from '@kbn/elastic-assistant'; +import { useAssistantContext } from '@kbn/elastic-assistant'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { DEFAULT_END, DEFAULT_START } from '@kbn/elastic-assistant-common'; import type { Filter } from '@kbn/es-query'; @@ -77,6 +78,7 @@ export const DetailsFlyout: React.FC = React.memo(({ scheduleId, onClose const { alertsIndexPattern, http, settings } = useAssistantContext(); const { data: aiConnectors, isLoading: isLoadingConnectors } = useLoadConnectors({ http, + featureId: 'attack_discovery', settings, }); const { data: { schedule } = { schedule: undefined }, isLoading: isLoadingSchedule } = diff --git a/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/use_attack_discovery/index.test.tsx b/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/use_attack_discovery/index.test.tsx index b6e860c197ee3..0f8e1eebc4a88 100644 --- a/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/use_attack_discovery/index.test.tsx +++ b/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/use_attack_discovery/index.test.tsx @@ -49,6 +49,9 @@ jest.mock('@kbn/elastic-assistant', () => ({ latestAlerts: 20, }, }), +})); + +jest.mock('@kbn/inference-connectors', () => ({ useLoadConnectors: jest.fn(() => ({ isFetched: true, data: mockConnectors, diff --git a/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/use_attack_discovery/index.tsx b/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/use_attack_discovery/index.tsx index d95fa2d28a6fd..8954a360aa6a1 100644 --- a/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/use_attack_discovery/index.tsx +++ b/x-pack/solutions/security/plugins/security_solution/public/attack_discovery/pages/use_attack_discovery/index.tsx @@ -5,7 +5,8 @@ * 2.0. */ -import { useAssistantContext, useLoadConnectors } from '@kbn/elastic-assistant'; +import { useAssistantContext } from '@kbn/elastic-assistant'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { API_VERSIONS, ATTACK_DISCOVERY_GENERATE, @@ -59,6 +60,7 @@ export const useAttackDiscovery = ({ const { data: aiConnectors } = useLoadConnectors({ http, + featureId: 'attack_discovery', settings, }); diff --git a/x-pack/solutions/security/plugins/security_solution/public/detections/components/attacks/content.tsx b/x-pack/solutions/security/plugins/security_solution/public/detections/components/attacks/content.tsx index 55e3ffab147e0..b19912e14e1ae 100644 --- a/x-pack/solutions/security/plugins/security_solution/public/detections/components/attacks/content.tsx +++ b/x-pack/solutions/security/plugins/security_solution/public/detections/components/attacks/content.tsx @@ -19,7 +19,8 @@ import { noop } from 'lodash/fp'; import type { DataView } from '@kbn/data-views-plugin/common'; import { isEqual } from 'lodash'; -import { useAssistantContext, useLoadConnectors } from '@kbn/elastic-assistant'; +import { useAssistantContext } from '@kbn/elastic-assistant'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import type { Filter } from '@kbn/es-query'; import type { FilterGroupHandler } from '@kbn/alerts-ui-shared'; import { dataTableSelectors, tableDefaults, TableId } from '@kbn/securitysolution-data-table'; @@ -84,10 +85,10 @@ export const AttacksPageContent = React.memo(({ dataView }: AttacksPageContentPr } = useKibana(); const { euiTheme } = useEuiTheme(); - const { http, inferenceEnabled } = useAssistantContext(); + const { http } = useAssistantContext(); const { data: aiConnectors } = useLoadConnectors({ http, - inferenceEnabled, + featureId: 'attack_discovery', settings, }); const { from } = useGlobalTime(); diff --git a/x-pack/solutions/security/plugins/security_solution/public/management/pages/endpoint_hosts/view/details/components/insights/workflow_insights_scan.tsx b/x-pack/solutions/security/plugins/security_solution/public/management/pages/endpoint_hosts/view/details/components/insights/workflow_insights_scan.tsx index 75a473431065d..349f512be6368 100644 --- a/x-pack/solutions/security/plugins/security_solution/public/management/pages/endpoint_hosts/view/details/components/insights/workflow_insights_scan.tsx +++ b/x-pack/solutions/security/plugins/security_solution/public/management/pages/endpoint_hosts/view/details/components/insights/workflow_insights_scan.tsx @@ -25,9 +25,9 @@ import { DEFEND_INSIGHTS_STORAGE_KEY, ConnectorSelectorInline, DEFAULT_ASSISTANT_NAMESPACE, - useLoadConnectors, AssistantSpaceIdProvider, } from '@kbn/elastic-assistant'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { FormattedMessage } from '@kbn/i18n-react'; import { useUserPrivileges } from '../../../../../../../common/components/user_privileges'; @@ -62,6 +62,7 @@ export const WorkflowInsightsScanSection = ({ const { http, settings, docLinks } = useKibana().services; const { data: aiConnectors } = useLoadConnectors({ http, + featureId: 'defend_insights', settings, }); const { canWriteWorkflowInsights } = useUserPrivileges().endpointPrivileges; diff --git a/x-pack/solutions/security/plugins/security_solution/public/management/pages/endpoint_hosts/view/index.test.tsx b/x-pack/solutions/security/plugins/security_solution/public/management/pages/endpoint_hosts/view/index.test.tsx index 9a89df1d6b2f5..acf42b309e40b 100644 --- a/x-pack/solutions/security/plugins/security_solution/public/management/pages/endpoint_hosts/view/index.test.tsx +++ b/x-pack/solutions/security/plugins/security_solution/public/management/pages/endpoint_hosts/view/index.test.tsx @@ -161,6 +161,9 @@ const timepickerRanges = [ }, ]; +jest.mock('@kbn/inference-connectors', () => ({ + useLoadConnectors: jest.fn().mockReturnValue({ isLoading: false, data: [] }), +})); jest.mock('../../../../common/lib/kibana'); jest.mock('../../../../common/hooks/use_license'); jest.mock('../../../hooks/endpoint/use_get_endpoint_details'); diff --git a/x-pack/solutions/security/plugins/security_solution/server/lib/entity_analytics/entity_details/routes/entity_details_highlight.ts b/x-pack/solutions/security/plugins/security_solution/server/lib/entity_analytics/entity_details/routes/entity_details_highlight.ts index 5331a534f6a2c..27ef2c1548ea1 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/lib/entity_analytics/entity_details/routes/entity_details_highlight.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/lib/entity_analytics/entity_details/routes/entity_details_highlight.ts @@ -64,9 +64,8 @@ export const entityDetailsHighlightsRoute = ( const fromDate = request.body.from; const toDate = request.body.to; - const [coreStart] = await getStartServices(); + const [coreStart, { inference }] = await getStartServices(); const securitySolution = await context.securitySolution; - const actions = await context.actions; const esClient = coreStart.elasticsearch.client.asInternalUser; const spaceId = securitySolution.getSpaceId(); @@ -120,7 +119,7 @@ export const entityDetailsHighlightsRoute = ( ); const prompt = await getPrompt({ - actionsClient: actions.getActionsClient(), + getInferenceConnectorById: (id) => inference.getConnectorById(id, request), connectorId, promptId: promptDictionary.entityDetailsHighlights, promptGroupId: promptGroupId.aiForEntityAnalytics, diff --git a/x-pack/solutions/security/plugins/security_solution/tsconfig.json b/x-pack/solutions/security/plugins/security_solution/tsconfig.json index e2ee132f13a45..9162cb11b3b0d 100644 --- a/x-pack/solutions/security/plugins/security_solution/tsconfig.json +++ b/x-pack/solutions/security/plugins/security_solution/tsconfig.json @@ -78,6 +78,7 @@ "@kbn/kibana-react-plugin", "@kbn/ecs-data-quality-dashboard", "@kbn/elastic-assistant", + "@kbn/inference-connectors", "@kbn/elastic-assistant-plugin", "@kbn/data-views-plugin", "@kbn/datemath", diff --git a/x-pack/solutions/workplaceai/plugins/workplace_ai_app/kibana.jsonc b/x-pack/solutions/workplaceai/plugins/workplace_ai_app/kibana.jsonc index e3df9a36d1a0a..69764cae02708 100644 --- a/x-pack/solutions/workplaceai/plugins/workplace_ai_app/kibana.jsonc +++ b/x-pack/solutions/workplaceai/plugins/workplace_ai_app/kibana.jsonc @@ -15,6 +15,6 @@ "configPath": ["xpack", "workplaceAIApp"], "requiredPlugins": ["inference", "actions", "features", "dataCatalog", "spaces", "triggersActionsUi", "workflowsExtensions"], "optionalPlugins": ["cloud", "share"], - "requiredBundles": ["kibanaReact", "stackConnectors"] + "requiredBundles": ["kibanaReact"] } } diff --git a/x-pack/solutions/workplaceai/plugins/workplace_ai_app/moon.yml b/x-pack/solutions/workplaceai/plugins/workplace_ai_app/moon.yml index 60140f6fa2791..a1ab34eaf1219 100644 --- a/x-pack/solutions/workplaceai/plugins/workplace_ai_app/moon.yml +++ b/x-pack/solutions/workplaceai/plugins/workplace_ai_app/moon.yml @@ -42,7 +42,7 @@ dependsOn: - '@kbn/share-plugin' - '@kbn/spaces-plugin' - '@kbn/triggers-actions-ui-plugin' - - '@kbn/elastic-assistant' + - '@kbn/inference-connectors' - '@kbn/ai-assistant-connector-selector-action' - '@kbn/core-application-browser' - '@kbn/workflows-extensions' diff --git a/x-pack/solutions/workplaceai/plugins/workplace_ai_app/public/application/components/connector_selector/connector_selector.tsx b/x-pack/solutions/workplaceai/plugins/workplace_ai_app/public/application/components/connector_selector/connector_selector.tsx index f0f028ebbc8fb..9747b154c82dd 100644 --- a/x-pack/solutions/workplaceai/plugins/workplace_ai_app/public/application/components/connector_selector/connector_selector.tsx +++ b/x-pack/solutions/workplaceai/plugins/workplace_ai_app/public/application/components/connector_selector/connector_selector.tsx @@ -12,7 +12,7 @@ import { ConnectorSelectable, type ConnectorSelectableComponentProps, } from '@kbn/ai-assistant-connector-selector-action'; -import { useLoadConnectors } from '@kbn/elastic-assistant'; +import { useLoadConnectors } from '@kbn/inference-connectors'; import { STACK_CONNECTORS_MANAGEMENT_ID } from '../../../../common'; import { useKibana } from '../../hooks/use_kibana'; import { useNavigateToApp } from '../../hooks/use_navigate_to_app'; @@ -49,11 +49,11 @@ export const ConnectorSelector: React.FC = ({ const { data: dataConnectors, isLoading } = useLoadConnectors({ http, + featureId: 'workplace_ai', settings: { client: uiSettings, globalClient: uiSettings, }, - inferenceEnabled: true, }); const connectors = useMemo(() => dataConnectors ?? [], [dataConnectors]); diff --git a/x-pack/solutions/workplaceai/plugins/workplace_ai_app/tsconfig.json b/x-pack/solutions/workplaceai/plugins/workplace_ai_app/tsconfig.json index bea6539a9c6c4..b32e55e8060b6 100644 --- a/x-pack/solutions/workplaceai/plugins/workplace_ai_app/tsconfig.json +++ b/x-pack/solutions/workplaceai/plugins/workplace_ai_app/tsconfig.json @@ -37,7 +37,7 @@ "@kbn/share-plugin", "@kbn/spaces-plugin", "@kbn/triggers-actions-ui-plugin", - "@kbn/elastic-assistant", + "@kbn/inference-connectors", "@kbn/ai-assistant-connector-selector-action", "@kbn/core-application-browser", "@kbn/workflows-extensions", diff --git a/yarn.lock b/yarn.lock index 9df3148d73b4a..d94d0cc360b55 100644 --- a/yarn.lock +++ b/yarn.lock @@ -7332,6 +7332,10 @@ version "0.0.0" uid "" +"@kbn/inference-connectors@link:x-pack/platform/packages/shared/kbn-inference-connectors": + version "0.0.0" + uid "" + "@kbn/inference-endpoint-plugin@link:x-pack/platform/plugins/shared/inference_endpoint": version "0.0.0" uid ""