diff --git a/api_docs/kbn_elastic_assistant_common.devdocs.json b/api_docs/kbn_elastic_assistant_common.devdocs.json
index 4712a39d62552..2944f7c87b3af 100644
--- a/api_docs/kbn_elastic_assistant_common.devdocs.json
+++ b/api_docs/kbn_elastic_assistant_common.devdocs.json
@@ -6182,15 +6182,15 @@
       },
       {
         "parentPluginId": "@kbn/elastic-assistant-common",
-        "id": "def-common.INFERENCE_CHAT_MODEL_ENABLED_FEATURE_FLAG",
+        "id": "def-common.INFERENCE_CHAT_MODEL_DISABLED_FEATURE_FLAG",
         "type": "string",
         "tags": [],
-        "label": "INFERENCE_CHAT_MODEL_ENABLED_FEATURE_FLAG",
+        "label": "INFERENCE_CHAT_MODEL_DISABLED_FEATURE_FLAG",
         "description": [
-          "\nThis feature flag enables the InferenceChatModel feature.\n\nIt may be overridden via the following setting in `kibana.yml` or `kibana.dev.yml`:\n```\nfeature_flags.overrides:\n securitySolution.inferenceChatModelEnabled: true\n```"
+          "\nThis feature flag disables the InferenceChatModel feature.\n\nIt may be overridden via the following setting in `kibana.yml` or `kibana.dev.yml`:\n```\nfeature_flags.overrides:\n securitySolution.inferenceChatModelDisabled: true\n```"
         ],
         "signature": [
-          "\"securitySolution.inferenceChatModelEnabled\""
+          "\"securitySolution.inferenceChatModelDisabled\""
         ],
         "path": "x-pack/platform/packages/shared/kbn-elastic-assistant-common/constants.ts",
         "deprecated": false,
@@ -12173,4 +12173,4 @@
       }
     ]
   }
-}
\ No newline at end of file
+}
diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant-common/constants.ts b/x-pack/platform/packages/shared/kbn-elastic-assistant-common/constants.ts
index 55ef349a897da..eb2dd1065815a 100755
--- a/x-pack/platform/packages/shared/kbn-elastic-assistant-common/constants.ts
+++ b/x-pack/platform/packages/shared/kbn-elastic-assistant-common/constants.ts
@@ -128,13 +128,13 @@ export const ATTACK_DISCOVERY_ALERTS_COMMON_INDEX_PREFIX =
   '.alerts-security.attack.discovery.alerts' as const;
 
 /**
- * This feature flag enables the InferenceChatModel feature.
+ * This feature flag disables the InferenceChatModel feature.
  *
  * It may be overridden via the following setting in `kibana.yml` or `kibana.dev.yml`:
  * ```
  * feature_flags.overrides:
- *   securitySolution.inferenceChatModelEnabled: true
+ *   securitySolution.inferenceChatModelDisabled: true
  * ```
  */
-export const INFERENCE_CHAT_MODEL_ENABLED_FEATURE_FLAG =
-  'securitySolution.inferenceChatModelEnabled' as const;
+export const INFERENCE_CHAT_MODEL_DISABLED_FEATURE_FLAG =
+  'securitySolution.inferenceChatModelDisabled' as const;
diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/executors/types.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/executors/types.ts
index 1088dcbff6e6b..02adb6753f15c 100644
--- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/executors/types.ts
+++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/executors/types.ts
@@ -61,7 +61,7 @@
   llmType?: string;
   isOssModel?: boolean;
   inference: InferenceServerStart;
-  inferenceChatModelEnabled?: boolean;
+  inferenceChatModelDisabled?: boolean;
   logger: Logger;
   onNewReplacements?: (newReplacements: Replacements) => void;
   replacements: Replacements;
diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/agentRunnable.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/agentRunnable.ts
index 0f4bcd3fef1af..a6d54ec173fd7 100644
--- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/agentRunnable.ts
+++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/agentRunnable.ts
@@ -27,7 +27,7 @@ export const agentRunnableFactory = async ({
   llm,
   llmType,
   tools,
-  inferenceChatModelEnabled,
+  inferenceChatModelDisabled,
   isOpenAI,
   isStream,
   prompt,
@@ -37,7 +37,7 @@
     | ActionsClientChatVertexAI
     | ActionsClientChatOpenAI
     | InferenceChatModel;
-  inferenceChatModelEnabled: boolean;
+  inferenceChatModelDisabled: boolean;
   isOpenAI: boolean;
   llmType: string | undefined;
   tools: StructuredToolInterface[] | ToolDefinition[];
@@ -51,7 +51,7 @@
     prompt,
   } as const;
 
-  if (!inferenceChatModelEnabled && (isOpenAI || llmType === 'inference')) {
+  if (inferenceChatModelDisabled && (isOpenAI || llmType === 'inference')) {
     return createOpenAIToolsAgent(params);
   }
 
diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.test.ts
index fc4ceee290981..1802198340be8 100644
--- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.test.ts
+++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.test.ts
@@ -64,6 +64,7 @@
     logger: mockLogger,
     onLlmResponse: mockOnLlmResponse,
     request: mockRequest,
+    inferenceChatModelDisabled: true,
    isEnabledKnowledgeBase: false,
    telemetry: {
      reportEvent: jest.fn(),
@@ -80,154 +81,156 @@
    });
  });
  describe('OpenAI Function Agent streaming', () => {
-    it('should execute the graph in streaming mode - OpenAI + isOssModel = false', async () => {
-
mockStreamEvents.mockReturnValue({ - async *[Symbol.asyncIterator]() { - yield { - event: 'on_llm_stream', - data: { chunk: { message: { content: 'content' } } }, - tags: [AGENT_NODE_TAG], - }; - yield { - event: 'on_llm_end', - data: { - output: { - generations: [ - [{ generationInfo: { finish_reason: 'stop' }, text: 'final message' }], - ], + describe('Inference Chat Model Disabled', () => { + it('should execute the graph in streaming mode - OpenAI + isOssModel = false', async () => { + mockStreamEvents.mockReturnValue({ + async *[Symbol.asyncIterator]() { + yield { + event: 'on_llm_stream', + data: { chunk: { message: { content: 'content' } } }, + tags: [AGENT_NODE_TAG], + }; + yield { + event: 'on_llm_end', + data: { + output: { + generations: [ + [{ generationInfo: { finish_reason: 'stop' }, text: 'final message' }], + ], + }, }, - }, - tags: [AGENT_NODE_TAG], - }; - }, - }); + tags: [AGENT_NODE_TAG], + }; + }, + }); - const response = await streamGraph(requestArgs); + const response = await streamGraph(requestArgs); - expect(response).toBe(mockResponseWithHeaders); - await waitFor(() => { - expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' }); - expect(mockOnLlmResponse).toHaveBeenCalledWith( - 'final message', - { transactionId: 'transactionId', traceId: 'traceId' }, - false - ); + expect(response).toBe(mockResponseWithHeaders); + await waitFor(() => { + expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' }); + expect(mockOnLlmResponse).toHaveBeenCalledWith( + 'final message', + { transactionId: 'transactionId', traceId: 'traceId' }, + false + ); + }); }); - }); - it('on_llm_end events with finish_reason != stop should not end the stream', async () => { - mockStreamEvents.mockReturnValue({ - async *[Symbol.asyncIterator]() { - yield { - event: 'on_llm_stream', - data: { chunk: { message: { content: 'content' } } }, - tags: [AGENT_NODE_TAG], - }; - yield { - event: 'on_llm_end', - data: { - output: { - generations: [[{ generationInfo: { finish_reason: 'function_call' }, text: '' }]], + it('on_llm_end events with finish_reason != stop should not end the stream', async () => { + mockStreamEvents.mockReturnValue({ + async *[Symbol.asyncIterator]() { + yield { + event: 'on_llm_stream', + data: { chunk: { message: { content: 'content' } } }, + tags: [AGENT_NODE_TAG], + }; + yield { + event: 'on_llm_end', + data: { + output: { + generations: [[{ generationInfo: { finish_reason: 'function_call' }, text: '' }]], + }, }, - }, - tags: [AGENT_NODE_TAG], - }; - }, - }); + tags: [AGENT_NODE_TAG], + }; + }, + }); - const response = await streamGraph(requestArgs); + const response = await streamGraph(requestArgs); - expect(response).toBe(mockResponseWithHeaders); - await waitFor(() => { - expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' }); - expect(mockOnLlmResponse).not.toHaveBeenCalled(); + expect(response).toBe(mockResponseWithHeaders); + await waitFor(() => { + expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' }); + expect(mockOnLlmResponse).not.toHaveBeenCalled(); + }); }); - }); - it('on_llm_end events without a finish_reason should end the stream', async () => { - mockStreamEvents.mockReturnValue({ - async *[Symbol.asyncIterator]() { - yield { - event: 'on_llm_stream', - data: { chunk: { message: { content: 'content' } } }, - tags: [AGENT_NODE_TAG], - }; - yield { - event: 'on_llm_end', - data: { - output: { - generations: [[{ generationInfo: {}, text: 'final message' }]], + it('on_llm_end 
events without a finish_reason should end the stream', async () => { + mockStreamEvents.mockReturnValue({ + async *[Symbol.asyncIterator]() { + yield { + event: 'on_llm_stream', + data: { chunk: { message: { content: 'content' } } }, + tags: [AGENT_NODE_TAG], + }; + yield { + event: 'on_llm_end', + data: { + output: { + generations: [[{ generationInfo: {}, text: 'final message' }]], + }, }, - }, - tags: [AGENT_NODE_TAG], - }; - }, - }); + tags: [AGENT_NODE_TAG], + }; + }, + }); - const response = await streamGraph(requestArgs); + const response = await streamGraph(requestArgs); - expect(response).toBe(mockResponseWithHeaders); - await waitFor(() => { - expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' }); - expect(mockOnLlmResponse).toHaveBeenCalledWith( - 'final message', - { transactionId: 'transactionId', traceId: 'traceId' }, - false - ); + expect(response).toBe(mockResponseWithHeaders); + await waitFor(() => { + expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' }); + expect(mockOnLlmResponse).toHaveBeenCalledWith( + 'final message', + { transactionId: 'transactionId', traceId: 'traceId' }, + false + ); + }); }); - }); - it('on_llm_end events is called with chunks if there is no final text value', async () => { - mockStreamEvents.mockReturnValue({ - async *[Symbol.asyncIterator]() { - yield { - event: 'on_llm_stream', - data: { chunk: { message: { content: 'content' } } }, - tags: [AGENT_NODE_TAG], - }; - yield { - event: 'on_llm_end', - data: { - output: { - generations: [[{ generationInfo: {}, text: '' }]], + it('on_llm_end events is called with chunks if there is no final text value', async () => { + mockStreamEvents.mockReturnValue({ + async *[Symbol.asyncIterator]() { + yield { + event: 'on_llm_stream', + data: { chunk: { message: { content: 'content' } } }, + tags: [AGENT_NODE_TAG], + }; + yield { + event: 'on_llm_end', + data: { + output: { + generations: [[{ generationInfo: {}, text: '' }]], + }, }, - }, - tags: [AGENT_NODE_TAG], - }; - }, - }); + tags: [AGENT_NODE_TAG], + }; + }, + }); - const response = await streamGraph(requestArgs); + const response = await streamGraph(requestArgs); - expect(response).toBe(mockResponseWithHeaders); - await waitFor(() => { - expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' }); - expect(mockOnLlmResponse).toHaveBeenCalledWith( - 'content', - { transactionId: 'transactionId', traceId: 'traceId' }, - false - ); - }); - }); - it('on_llm_end does not call handleStreamEnd if generations is undefined', async () => { - mockStreamEvents.mockReturnValue({ - async *[Symbol.asyncIterator]() { - yield { - event: 'on_llm_stream', - data: { chunk: { message: { content: 'content' } } }, - tags: [AGENT_NODE_TAG], - }; - yield { - event: 'on_llm_end', - data: {}, - tags: [AGENT_NODE_TAG], - }; - }, + expect(response).toBe(mockResponseWithHeaders); + await waitFor(() => { + expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' }); + expect(mockOnLlmResponse).toHaveBeenCalledWith( + 'content', + { transactionId: 'transactionId', traceId: 'traceId' }, + false + ); + }); }); + it('on_llm_end does not call handleStreamEnd if generations is undefined', async () => { + mockStreamEvents.mockReturnValue({ + async *[Symbol.asyncIterator]() { + yield { + event: 'on_llm_stream', + data: { chunk: { message: { content: 'content' } } }, + tags: [AGENT_NODE_TAG], + }; + yield { + event: 'on_llm_end', + data: {}, + tags: [AGENT_NODE_TAG], + }; + }, + }); - const response = 
await streamGraph(requestArgs); + const response = await streamGraph(requestArgs); - expect(response).toBe(mockResponseWithHeaders); - await waitFor(() => { - expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' }); - expect(mockOnLlmResponse).not.toHaveBeenCalled(); + expect(response).toBe(mockResponseWithHeaders); + await waitFor(() => { + expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' }); + expect(mockOnLlmResponse).not.toHaveBeenCalled(); + }); }); }); }); @@ -337,5 +340,16 @@ describe('streamGraph', () => { }); await expectConditions(response, true); }); + it('should execute the graph in streaming mode - OpenAI + inferenceChatModelDisabled = false', async () => { + const mockAssistantGraphAsyncIterator = { + streamEvents: () => mockAsyncIterator, + } as unknown as DefaultAssistantGraph; + const response = await streamGraph({ + ...requestArgs, + inferenceChatModelDisabled: false, + assistantGraph: mockAssistantGraphAsyncIterator, + }); + await expectConditions(response); + }); }); }); diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts index 97dc0f3ac4fd4..9788f3df23a71 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/helpers.ts @@ -26,7 +26,7 @@ interface StreamGraphParams { apmTracer: APMTracer; assistantGraph: DefaultAssistantGraph; inputs: GraphInputs; - inferenceChatModelEnabled?: boolean; + inferenceChatModelDisabled?: boolean; isEnabledKnowledgeBase: boolean; logger: Logger; onLlmResponse?: OnLlmResponse; @@ -54,7 +54,7 @@ export const streamGraph = async ({ apmTracer, assistantGraph, inputs, - inferenceChatModelEnabled = false, + inferenceChatModelDisabled = false, isEnabledKnowledgeBase, logger, onLlmResponse, @@ -108,7 +108,7 @@ export const streamGraph = async ({ // Stream is from tool calling agent or structured chat agent if ( - inferenceChatModelEnabled || + !inferenceChatModelDisabled || inputs.isOssModel || inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini' diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.test.ts index 6fb9f505bfb6f..df3ee818b250c 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.test.ts @@ -67,6 +67,7 @@ describe('callAssistantGraph', () => { saved_objects: [], }); const mockLogger = loggerMock.create(); + const getChatModel = jest.fn(); const defaultParams = { actionsClient: actionsClientMock.create(), alertsIndexPattern: 'test-pattern', @@ -75,7 +76,10 @@ describe('callAssistantGraph', () => { conversationId: 'test-conversation', dataClients: mockDataClients, esClient: elasticsearchClientMock.createScopedClusterClient().asCurrentUser, - inference: {}, + inference: { + getChatModel, + }, + inferenceChatModelDisabled: true, langChainMessages: [{ content: 'test message' }], llmTasks: { retrieveDocumentationAvailable: jest.fn(), 
retrieveDocumentation: jest.fn() }, llmType: 'openai', @@ -121,217 +125,286 @@ describe('callAssistantGraph', () => { ); getPromptMock.mockResolvedValue('prompt'); }); - - it('calls invokeGraph with correct parameters for non-streaming', async () => { - const result = await callAssistantGraph(defaultParams); - - expect(invokeGraph).toHaveBeenCalledWith( - expect.objectContaining({ - inputs: expect.objectContaining({ - input: 'test message', - }), - }) - ); - expect(result.body).toEqual({ - connector_id: 'test-connector', - data: 'test-output', - trace_data: {}, - replacements: [], - status: 'ok', - conversationId: 'new-conversation-id', + describe('inferenceChatModelDisabled = true', () => { + it('calls invokeGraph with correct parameters for non-streaming', async () => { + const result = await callAssistantGraph(defaultParams); + + expect(invokeGraph).toHaveBeenCalledWith( + expect.objectContaining({ + inputs: expect.objectContaining({ + input: 'test message', + }), + }) + ); + expect(result.body).toEqual({ + connector_id: 'test-connector', + data: 'test-output', + trace_data: {}, + replacements: [], + status: 'ok', + conversationId: 'new-conversation-id', + }); }); - }); - it('calls streamGraph with correct parameters for streaming', async () => { - const params = { ...defaultParams, isStream: true }; - await callAssistantGraph(params); + it('calls streamGraph with correct parameters for streaming', async () => { + const params = { ...defaultParams, isStream: true }; + await callAssistantGraph(params); - expect(streamGraph).toHaveBeenCalledWith( - expect.objectContaining({ - inputs: expect.objectContaining({ - input: 'test message', - }), - }) - ); - }); + expect(streamGraph).toHaveBeenCalledWith( + expect.objectContaining({ + inputs: expect.objectContaining({ + input: 'test message', + }), + }) + ); + }); - it('calls getDefaultAssistantGraph without signal for openai', async () => { - await callAssistantGraph(defaultParams); - expect(getDefaultAssistantGraphMock.mock.calls[0][0]).not.toHaveProperty('signal'); - }); + it('calls getDefaultAssistantGraph without signal for openai', async () => { + await callAssistantGraph(defaultParams); + expect(getDefaultAssistantGraphMock.mock.calls[0][0]).not.toHaveProperty('signal'); + }); - it('calls getDefaultAssistantGraph with signal for bedrock', async () => { - await callAssistantGraph({ ...defaultParams, llmType: 'bedrock' }); - expect(getDefaultAssistantGraphMock.mock.calls[0][0]).toHaveProperty('signal'); - }); + it('calls getDefaultAssistantGraph with signal for bedrock', async () => { + await callAssistantGraph({ ...defaultParams, llmType: 'bedrock' }); + expect(getDefaultAssistantGraphMock.mock.calls[0][0]).toHaveProperty('signal'); + }); - it('handles error when anonymizationFieldsDataClient.findDocuments fails', async () => { - (mockDataClients?.anonymizationFieldsDataClient?.findDocuments as jest.Mock).mockRejectedValue( - new Error('test error') - ); + it('handles error when anonymizationFieldsDataClient.findDocuments fails', async () => { + ( + mockDataClients?.anonymizationFieldsDataClient?.findDocuments as jest.Mock + ).mockRejectedValue(new Error('test error')); - await expect(callAssistantGraph(defaultParams)).rejects.toThrow('test error'); - }); + await expect(callAssistantGraph(defaultParams)).rejects.toThrow('test error'); + }); - it('handles error when kbDataClient.isInferenceEndpointExists fails', async () => { - (mockDataClients?.kbDataClient?.isInferenceEndpointExists as jest.Mock).mockRejectedValue( - new Error('test 
error') - ); + it('handles error when kbDataClient.isInferenceEndpointExists fails', async () => { + (mockDataClients?.kbDataClient?.isInferenceEndpointExists as jest.Mock).mockRejectedValue( + new Error('test error') + ); - await expect(callAssistantGraph(defaultParams)).rejects.toThrow('test error'); - }); + await expect(callAssistantGraph(defaultParams)).rejects.toThrow('test error'); + }); - it('returns correct response when no conversationId is returned', async () => { - (invokeGraph as jest.Mock).mockResolvedValue({ output: 'test-output', traceData: {} }); + it('returns correct response when no conversationId is returned', async () => { + (invokeGraph as jest.Mock).mockResolvedValue({ output: 'test-output', traceData: {} }); - const result = await callAssistantGraph(defaultParams); + const result = await callAssistantGraph(defaultParams); - expect(result.body).toEqual({ - connector_id: 'test-connector', - data: 'test-output', - trace_data: {}, - replacements: [], - status: 'ok', + expect(result.body).toEqual({ + connector_id: 'test-connector', + data: 'test-output', + trace_data: {}, + replacements: [], + status: 'ok', + }); }); - }); - it('calls getPrompt for each tool and the default system prompt', async () => { - const params = { - ...defaultParams, - assistantTools: [ - { ...mockTool, name: 'test-tool' }, - { ...mockTool, name: 'test-tool2' }, - ], - }; - await callAssistantGraph(params); - - expect(getPromptMock).toHaveBeenCalledTimes(3); - expect(getPromptMock).toHaveBeenCalledWith( - expect.objectContaining({ - model: 'test-model', - provider: 'openai', - promptId: 'test-tool', - promptGroupId: toolsGroupId, - }) - ); - expect(getPromptMock).toHaveBeenCalledWith( - expect.objectContaining({ - model: 'test-model', - provider: 'openai', - promptId: 'test-tool2', - promptGroupId: toolsGroupId, - }) - ); - expect(getPromptMock).toHaveBeenCalledWith( - expect.objectContaining({ - model: 'test-model', - provider: 'openai', - promptId: promptDictionary.systemPrompt, - promptGroupId: promptGroupId.aiAssistant, - }) - ); + it('calls getPrompt for each tool and the default system prompt', async () => { + const params = { + ...defaultParams, + assistantTools: [ + { ...mockTool, name: 'test-tool' }, + { ...mockTool, name: 'test-tool2' }, + ], + }; + await callAssistantGraph(params); - expect(getTool).toHaveBeenCalledWith( - expect.objectContaining({ - description: 'prompt', - }) - ); - }); + expect(getPromptMock).toHaveBeenCalledTimes(3); + expect(getPromptMock).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'test-model', + provider: 'openai', + promptId: 'test-tool', + promptGroupId: toolsGroupId, + }) + ); + expect(getPromptMock).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'test-model', + provider: 'openai', + promptId: 'test-tool2', + promptGroupId: toolsGroupId, + }) + ); + expect(getPromptMock).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'test-model', + provider: 'openai', + promptId: promptDictionary.systemPrompt, + promptGroupId: promptGroupId.aiAssistant, + }) + ); + + expect(getTool).toHaveBeenCalledWith( + expect.objectContaining({ + description: 'prompt', + }) + ); + }); - it('Passes only Elastic tools, not custom, to Telemetry tracer', async () => { - await callAssistantGraph({ - ...defaultParams, - assistantTools: [ - { ...mockTool, name: 'test-tool', getTool: getTool.mockReturnValue({ name: 'test-tool' }) }, + it('Passes only Elastic tools, not custom, to Telemetry tracer', async () => { + await callAssistantGraph({ + 
...defaultParams, + assistantTools: [ + { + ...mockTool, + name: 'test-tool', + getTool: getTool.mockReturnValue({ name: 'test-tool' }), + }, + { + ...mockTool, + name: 'test-tool2', + getTool: getTool.mockReturnValue({ name: 'test-tool2' }), + }, + ], + }); + expect(telemetryTracerMock).toHaveBeenCalledWith( { - ...mockTool, - name: 'test-tool2', - getTool: getTool.mockReturnValue({ name: 'test-tool2' }), + elasticTools: ['test-tool2', 'test-tool2'], + telemetry: {}, + telemetryParams: {}, }, - ], + { + ...mockLogger, + context: ['defaultAssistantGraph'], + } + ); }); - expect(telemetryTracerMock).toHaveBeenCalledWith( - { - elasticTools: ['test-tool2', 'test-tool2'], - telemetry: {}, - telemetryParams: {}, - }, - { - ...mockLogger, - context: ['defaultAssistantGraph'], - } - ); - }); - describe('agentRunnable', () => { - it('creates OpenAIToolsAgent for openai llmType', async () => { - const params = { ...defaultParams, llmType: 'openai' }; - await callAssistantGraph(params); + describe('agentRunnable', () => { + it('creates OpenAIToolsAgent for openai llmType', async () => { + const params = { ...defaultParams, llmType: 'openai' }; + await callAssistantGraph(params); - expect(createOpenAIToolsAgent).toHaveBeenCalled(); - expect(createToolCallingAgent).not.toHaveBeenCalled(); - }); + expect(createOpenAIToolsAgent).toHaveBeenCalled(); + expect(createToolCallingAgent).not.toHaveBeenCalled(); + }); - it('creates OpenAIToolsAgent for inference llmType', async () => { - defaultParams.actionsClient.get = jest.fn().mockResolvedValue({ - config: { - provider: 'elastic', - providerConfig: { model_id: 'rainbow-sprinkles' }, - }, + it('creates OpenAIToolsAgent for inference llmType', async () => { + defaultParams.actionsClient.get = jest.fn().mockResolvedValue({ + config: { + provider: 'elastic', + providerConfig: { model_id: 'rainbow-sprinkles' }, + }, + }); + const params = { ...defaultParams, llmType: 'inference' }; + await callAssistantGraph(params); + + expect(createOpenAIToolsAgent).toHaveBeenCalled(); + expect(createToolCallingAgent).not.toHaveBeenCalled(); }); - const params = { ...defaultParams, llmType: 'inference' }; - await callAssistantGraph(params); - expect(createOpenAIToolsAgent).toHaveBeenCalled(); - expect(createToolCallingAgent).not.toHaveBeenCalled(); - }); + it('creates ToolCallingAgent for bedrock llmType', async () => { + const params = { ...defaultParams, llmType: 'bedrock' }; + await callAssistantGraph(params); - it('creates ToolCallingAgent for bedrock llmType', async () => { - const params = { ...defaultParams, llmType: 'bedrock' }; - await callAssistantGraph(params); + expect(createToolCallingAgent).toHaveBeenCalled(); + expect(createOpenAIToolsAgent).not.toHaveBeenCalled(); + }); - expect(createToolCallingAgent).toHaveBeenCalled(); - expect(createOpenAIToolsAgent).not.toHaveBeenCalled(); - }); + it('creates ToolCallingAgent for gemini llmType', async () => { + const params = { + ...defaultParams, + request: { + body: { model: 'gemini-1.5-flash' }, + } as unknown as AgentExecutorParams['request'], + llmType: 'gemini', + }; + await callAssistantGraph(params); + + expect(createToolCallingAgent).toHaveBeenCalled(); + expect(createOpenAIToolsAgent).not.toHaveBeenCalled(); + }); - it('creates ToolCallingAgent for gemini llmType', async () => { - const params = { - ...defaultParams, - request: { - body: { model: 'gemini-1.5-flash' }, - } as unknown as AgentExecutorParams['request'], - llmType: 'gemini', - }; - await callAssistantGraph(params); + it('creates ToolCallingAgent 
for oss model', async () => { + const params = { ...defaultParams, llmType: 'openai', isOssModel: true }; + await callAssistantGraph(params); - expect(createToolCallingAgent).toHaveBeenCalled(); - expect(createOpenAIToolsAgent).not.toHaveBeenCalled(); - }); + expect(createOpenAIToolsAgent).not.toHaveBeenCalled(); + expect(createToolCallingAgent).toHaveBeenCalled(); + }); + it('does not calls resolveProviderAndModel when llmType === openai', async () => { + const params = { ...defaultParams, llmType: 'openai' }; + await callAssistantGraph(params); - it('creates ToolCallingAgent for oss model', async () => { - const params = { ...defaultParams, llmType: 'openai', isOssModel: true }; - await callAssistantGraph(params); + expect(resolveProviderAndModelMock).not.toHaveBeenCalled(); + }); + it('calls resolveProviderAndModel when llmType === inference', async () => { + const params = { ...defaultParams, llmType: 'inference' }; + await callAssistantGraph(params); - expect(createOpenAIToolsAgent).not.toHaveBeenCalled(); - expect(createToolCallingAgent).toHaveBeenCalled(); - }); - it('does not calls resolveProviderAndModel when llmType === openai', async () => { - const params = { ...defaultParams, llmType: 'openai' }; - await callAssistantGraph(params); + expect(resolveProviderAndModelMock).toHaveBeenCalled(); + }); + it('calls resolveProviderAndModel when llmType === undefined', async () => { + const params = { ...defaultParams, llmType: undefined }; + await callAssistantGraph(params); - expect(resolveProviderAndModelMock).not.toHaveBeenCalled(); + expect(resolveProviderAndModelMock).toHaveBeenCalled(); + }); + }); + }); + describe('inferenceChatModelDisabled = false', () => { + const newDefaultParams = { + ...defaultParams, + inferenceChatModelDisabled: false, + }; + it('calls invokeGraph with correct parameters for non-streaming', async () => { + const result = await callAssistantGraph(newDefaultParams); + + expect(invokeGraph).toHaveBeenCalledWith( + expect.objectContaining({ + inputs: expect.objectContaining({ + input: 'test message', + }), + }) + ); + expect(result.body).toEqual({ + connector_id: 'test-connector', + data: 'test-output', + trace_data: {}, + replacements: [], + status: 'ok', + conversationId: 'new-conversation-id', + }); + expect(getChatModel).toHaveBeenCalled(); }); - it('calls resolveProviderAndModel when llmType === inference', async () => { - const params = { ...defaultParams, llmType: 'inference' }; + + it('calls streamGraph with correct parameters for streaming', async () => { + const params = { ...newDefaultParams, isStream: true }; await callAssistantGraph(params); - expect(resolveProviderAndModelMock).toHaveBeenCalled(); + expect(streamGraph).toHaveBeenCalledWith( + expect.objectContaining({ + inputs: expect.objectContaining({ + input: 'test message', + }), + }) + ); + expect(getChatModel).toHaveBeenCalled(); }); - it('calls resolveProviderAndModel when llmType === undefined', async () => { - const params = { ...defaultParams, llmType: undefined }; - await callAssistantGraph(params); - expect(resolveProviderAndModelMock).toHaveBeenCalled(); + describe('agentRunnable', () => { + it('creates ToolCallingAgent for openai llmType', async () => { + const params = { ...newDefaultParams, llmType: 'openai' }; + await callAssistantGraph(params); + + expect(createToolCallingAgent).toHaveBeenCalled(); + expect(createOpenAIToolsAgent).not.toHaveBeenCalled(); + }); + + it('creates ToolCallingAgent for inference llmType', async () => { + newDefaultParams.actionsClient.get = 
jest.fn().mockResolvedValue({ + config: { + provider: 'elastic', + providerConfig: { model_id: 'rainbow-sprinkles' }, + }, + }); + const params = { ...newDefaultParams, llmType: 'inference' }; + await callAssistantGraph(params); + + expect(createToolCallingAgent).toHaveBeenCalled(); + expect(createOpenAIToolsAgent).not.toHaveBeenCalled(); + }); }); }); }); diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts index d9db6541bb9ca..523b9e59892bb 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts @@ -42,7 +42,7 @@ export const callAssistantGraph: AgentExecutor = async ({ dataClients, esClient, inference, - inferenceChatModelEnabled = false, + inferenceChatModelDisabled = false, langChainMessages, llmTasks, llmType, @@ -76,7 +76,7 @@ export const callAssistantGraph: AgentExecutor = async ({ * the state unintentionally. For this reason, only call createLlmInstance at runtime */ const createLlmInstance = async () => - inferenceChatModelEnabled + !inferenceChatModelDisabled ? inference.getChatModel({ request, connectorId, @@ -231,7 +231,7 @@ export const callAssistantGraph: AgentExecutor = async ({ llm, llmType, tools, - inferenceChatModelEnabled, + inferenceChatModelDisabled, isOpenAI, isStream, prompt: chatPromptTemplate, @@ -306,7 +306,7 @@ export const callAssistantGraph: AgentExecutor = async ({ apmTracer, assistantGraph, inputs, - inferenceChatModelEnabled, + inferenceChatModelDisabled, isEnabledKnowledgeBase: telemetryParams?.isEnabledKnowledgeBase ?? false, logger, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts index e5634d9e4a870..f54da4384ac33 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/evaluate/post_evaluate.ts @@ -23,7 +23,7 @@ import { PostEvaluateBody, PostEvaluateResponse, DefendInsightType, - INFERENCE_CHAT_MODEL_ENABLED_FEATURE_FLAG, + INFERENCE_CHAT_MODEL_DISABLED_FEATURE_FLAG, } from '@kbn/elastic-assistant-common'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; import { getDefaultArguments } from '@kbn/langchain/server'; @@ -179,8 +179,8 @@ export const postEvaluateRoute = ( (await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable()) ?? false; const { featureFlags } = await context.core; - const inferenceChatModelEnabled = await featureFlags.getBooleanValue( - INFERENCE_CHAT_MODEL_ENABLED_FEATURE_FLAG, + const inferenceChatModelDisabled = await featureFlags.getBooleanValue( + INFERENCE_CHAT_MODEL_DISABLED_FEATURE_FLAG, false ); @@ -321,7 +321,7 @@ export const postEvaluateRoute = ( const isOpenAI = llmType === 'openai' && !isOssModel; const llmClass = getLlmClass(llmType); const createLlmInstance = async () => - inferenceChatModelEnabled + !inferenceChatModelDisabled ? 
inference.getChatModel({ request, connectorId: connector.id, @@ -467,7 +467,7 @@ export const postEvaluateRoute = ( const agentRunnable = await agentRunnableFactory({ llm: chatModel, llmType, - inferenceChatModelEnabled, + inferenceChatModelDisabled, isOpenAI, tools, isStream: false, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/helpers.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/helpers.ts index 3cbae2fc199cf..74236297c73f6 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/helpers.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/helpers.ts @@ -235,7 +235,7 @@ export interface LangChainExecuteParams { contentReferencesStore: ContentReferencesStore; llmTasks?: LlmTasksPluginStart; inference: InferenceServerStart; - inferenceChatModelEnabled?: boolean; + inferenceChatModelDisabled?: boolean; isOssModel?: boolean; conversationId?: string; context: AwaitedProperties< @@ -266,7 +266,7 @@ export const langChainExecute = async ({ actionTypeId, connectorId, contentReferencesStore, - inferenceChatModelEnabled, + inferenceChatModelDisabled, isOssModel, context, actionsClient, @@ -335,7 +335,7 @@ export const langChainExecute = async ({ esClient, llmTasks, inference, - inferenceChatModelEnabled, + inferenceChatModelDisabled, isStream, llmType: getLlmType(actionTypeId), isOssModel, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 1c7bbca8e1ca7..c0716d59d9878 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -19,7 +19,7 @@ import { pruneContentReferences, ExecuteConnectorRequestQuery, POST_ACTIONS_CONNECTOR_EXECUTE, - INFERENCE_CHAT_MODEL_ENABLED_FEATURE_FLAG, + INFERENCE_CHAT_MODEL_DISABLED_FEATURE_FLAG, } from '@kbn/elastic-assistant-common'; import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common'; import { getPrompt } from '../lib/prompt'; @@ -82,9 +82,9 @@ export const postActionsConnectorExecuteRoute = ( let onLlmResponse; const coreContext = await context.core; - const inferenceChatModelEnabled = + const inferenceChatModelDisabled = (await coreContext?.featureFlags?.getBooleanValue( - INFERENCE_CHAT_MODEL_ENABLED_FEATURE_FLAG, + INFERENCE_CHAT_MODEL_DISABLED_FEATURE_FLAG, false )) ?? false; @@ -200,7 +200,7 @@ export const postActionsConnectorExecuteRoute = ( connectorId, contentReferencesStore, isOssModel, - inferenceChatModelEnabled, + inferenceChatModelDisabled, conversationId, context: ctx, logger,