diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts index 8cd2f0fd801d0..d129ae58ca658 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.test.ts @@ -44,6 +44,7 @@ jest.mock('../helpers', () => { }; }); const mockAppendAssistantMessageToConversation = appendAssistantMessageToConversation as jest.Mock; +const mockCreateConversationWithUserInput = createConversationWithUserInput as jest.Mock; const mockLangChainExecute = langChainExecute as jest.Mock; const mockStream = jest.fn().mockImplementation(() => new PassThrough()); @@ -150,7 +151,7 @@ describe('chatCompleteRoute', () => { jest.clearAllMocks(); mockAppendAssistantMessageToConversation.mockResolvedValue(true); license.hasAtLeast.mockReturnValue(true); - (createConversationWithUserInput as jest.Mock).mockResolvedValue({ id: 'something' }); + mockCreateConversationWithUserInput.mockResolvedValue({ id: 'something' }); mockLangChainExecute.mockImplementation( async ({ connectorId, @@ -166,12 +167,14 @@ describe('chatCompleteRoute', () => { ) => Promise; }) => { if (!isStream && connectorId === 'mock-connector-id') { + onLlmResponse('Non-streamed test reply.', {}, false).catch(() => {}); return { connector_id: 'mock-connector-id', data: mockActionResponse, status: 'ok', }; } else if (isStream && connectorId === 'mock-connector-id') { + onLlmResponse('Streamed test reply.', {}, false).catch(() => {}); return mockStream; } else { onLlmResponse('simulated error', {}, true).catch(() => {}); @@ -399,4 +402,141 @@ describe('chatCompleteRoute', () => { mockGetElser ); }); + + it('should add assistant reply to existing conversation when `persist=true`', async () => { + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler( + mockContext, + { + ...mockRequest, + body: { + ...mockRequest.body, + conversationId: existingConversation.id, + }, + }, + mockResponse + ); + expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledWith( + expect.objectContaining({ + messageContent: 'Non-streamed test reply.', + isError: false, + }) + ); + expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(0); + }), + }; + }), + }, + }; + + chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('should not add assistant reply to existing conversation when `persist=false`', async () => { + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler( + mockContext, + { + ...mockRequest, + body: { + ...mockRequest.body, + conversationId: existingConversation.id, + persist: false, + }, + }, + mockResponse + ); + expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledTimes(0); + expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(0); + }), + }; + }), + }, + }; + + chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('should add assistant reply to new conversation when `persist=true`', async () => { + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler( + mockContext, + { + ...mockRequest, + body: { + ...mockRequest.body, + conversationId: undefined, + persist: true, + }, + }, + mockResponse + ); + expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledWith( + expect.objectContaining({ + messageContent: 'Non-streamed test reply.', + isError: false, + }) + ); + expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(1); + }), + }; + }), + }, + }; + + chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); + + it('should not create a new conversation when `persist=false`', async () => { + const mockRouter = { + versioned: { + post: jest.fn().mockImplementation(() => { + return { + addVersion: jest.fn().mockImplementation(async (_, handler) => { + await handler( + mockContext, + { + ...mockRequest, + body: { + ...mockRequest.body, + conversationId: undefined, + persist: false, + }, + }, + mockResponse + ); + expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledTimes(0); + expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(0); + }), + }; + }), + }, + }; + + chatCompleteRoute( + mockRouter as unknown as IRouter, + mockGetElser + ); + }); }); diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts index 726ea224692ee..cc39ce04add17 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/chat/chat_complete_route.ts @@ -93,7 +93,7 @@ export const chatCompleteRoute = ( await ctx.elasticAssistant.getAIAssistantAnonymizationFieldsDataClient(); let messages; - const conversationId = request.body.conversationId; + const existingConversationId = request.body.conversationId; const connectorId = request.body.connectorId; let latestReplacements: Replacements = {}; @@ -159,11 +159,10 @@ export const chatCompleteRoute = ( }); let newConversation: ConversationResponse | undefined | null; - if (conversationsDataClient && !conversationId && request.body.persist) { + if (conversationsDataClient && !existingConversationId && request.body.persist) { newConversation = await createConversationWithUserInput({ actionTypeId, connectorId, - conversationId, conversationsDataClient, promptId: request.body.promptId, replacements: latestReplacements, @@ -178,6 +177,11 @@ export const chatCompleteRoute = ( })); } + // Do not persist conversation messages if `persist = false` + const conversationId = request.body.persist + ? existingConversationId ?? newConversation?.id + : undefined; + const contentReferencesStore = newContentReferencesStore(); const onLlmResponse = async ( @@ -185,11 +189,11 @@ export const chatCompleteRoute = ( traceData: Message['traceData'] = {}, isError = false ): Promise => { - if (newConversation?.id && conversationsDataClient) { + if (conversationId && conversationsDataClient) { const contentReferences = pruneContentReferences(content, contentReferencesStore); await appendAssistantMessageToConversation({ - conversationId: newConversation?.id, + conversationId, conversationsDataClient, messageContent: content, replacements: latestReplacements, @@ -207,7 +211,7 @@ export const chatCompleteRoute = ( actionTypeId, connectorId, isOssModel, - conversationId: conversationId ?? newConversation?.id, + conversationId, context: ctx, getElser, logger,