diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/chat_send/use_chat_send.test.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/chat_send/use_chat_send.test.tsx index a9266f14ad69a..55245cc307a14 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/chat_send/use_chat_send.test.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/chat_send/use_chat_send.test.tsx @@ -145,4 +145,34 @@ describe('use chat send', () => { }); }); }); + it('retries getConversation up to 5 times if title is empty, and stops when title is found', async () => { + const promptText = 'test prompt'; + const getConversationMock = jest.fn(); + // First 3 calls return empty title, 4th returns non-empty + getConversationMock + .mockResolvedValueOnce({ title: '' }) + .mockResolvedValueOnce({ title: '' }) + .mockResolvedValueOnce({ title: '' }) + .mockResolvedValueOnce({ title: 'Final Title' }); + (useConversation as jest.Mock).mockReturnValue({ + removeLastMessage, + clearConversation, + getConversation: getConversationMock, + createConversation: jest.fn(), + }); + const { result } = renderHook( + () => + useChatSend({ + ...testProps, + currentConversation: { ...emptyWelcomeConvo, id: 'convo-id', title: '' }, + }), + { wrapper: TestProviders } + ); + await act(async () => { + await result.current.handleChatSend(promptText); + }); + // Should call getConversation 4 times (until non-empty title) + expect(getConversationMock).toHaveBeenCalledTimes(4); + expect(getConversationMock).toHaveBeenLastCalledWith('convo-id'); + }); }); diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/chat_send/use_chat_send.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/chat_send/use_chat_send.tsx index c4f575d5a3cfb..e2ae89f0b1d30 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/chat_send/use_chat_send.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/chat_send/use_chat_send.tsx @@ -61,12 +61,15 @@ export const useChatSend = ({ const { setLastConversation } = useAssistantLastConversation({ spaceId }); const [userPrompt, setUserPrompt] = useState(null); - const { isLoading, sendMessage, abortStream } = useSendMessage(); + const { sendMessage, abortStream } = useSendMessage(); const { clearConversation, createConversation, getConversation, removeLastMessage } = useConversation(); const { data: kbStatus } = useKnowledgeBaseStatus({ http, enabled: isAssistantEnabled }); const isSetupComplete = kbStatus?.elser_exists && kbStatus?.security_labs_exists; + // Local loading state that persists until the entire message flow is complete + const [isLoadingChatSend, setIsLoadingChatSend] = useState(false); + // Handles sending latest user prompt to API const handleSendMessage = useCallback( async (promptText: string) => { @@ -81,79 +84,102 @@ export const useChatSend = ({ ); return; } - const apiConfig = currentConversation.apiConfig; - let newConvo; - if (currentConversation.id === '') { - // create conversation with empty title, GENERATE_CHAT_TITLE graph step will properly title - newConvo = await createConversation(currentConversation); - if (newConvo?.id) { - setLastConversation({ - id: newConvo.id, - }); + + setIsLoadingChatSend(true); + + try { + const apiConfig = currentConversation.apiConfig; + let newConvo; + if (currentConversation.id === '') { + // create conversation with empty title, GENERATE_CHAT_TITLE graph step will properly title + newConvo = await createConversation(currentConversation); + if (newConvo?.id) { + setLastConversation({ + id: newConvo.id, + }); + } } + const convo: Conversation = { ...currentConversation, ...(newConvo ?? {}) }; + const userMessage = getCombinedMessage({ + currentReplacements: convo.replacements, + promptText, + selectedPromptContexts, + }); + + const baseReplacements: Replacements = userMessage.replacements ?? convo.replacements; + + const selectedPromptContextsReplacements = Object.values( + selectedPromptContexts + ).reduce((acc, context) => ({ ...acc, ...context.replacements }), {}); + + const replacements: Replacements = { + ...baseReplacements, + ...selectedPromptContextsReplacements, + }; + const updatedMessages = [...convo.messages, userMessage].map((m) => ({ + ...m, + content: m.content ?? '', + })); + setCurrentConversation({ + ...convo, + replacements, + messages: updatedMessages, + }); + + // Reset prompt context selection and preview before sending: + setSelectedPromptContexts({}); + + const rawResponse = await sendMessage({ + apiConfig, + http, + message: userMessage.content ?? '', + conversationId: convo.id, + replacements, + }); + + assistantTelemetry?.reportAssistantMessageSent({ + role: userMessage.role, + actionTypeId: apiConfig.actionTypeId, + model: apiConfig.model, + provider: apiConfig.provider, + isEnabledKnowledgeBase: isSetupComplete ?? false, + }); + + const responseMessage: ClientMessage = getMessageFromRawResponse(rawResponse); + if (convo.title === '') { + // Retry getConversation up to 5 times if title is empty + let retryCount = 0; + const maxRetries = 5; + while (retryCount < maxRetries) { + const conversation = await getConversation(convo.id); + convo.title = conversation?.title ?? ''; + + if (convo.title !== '') { + break; // Title found, exit retry loop + } + + retryCount++; + if (retryCount < maxRetries) { + // Wait 1 second before next retry + await new Promise((resolve) => setTimeout(resolve, 1000)); + } + } + } + setCurrentConversation({ + ...convo, + replacements, + messages: [...updatedMessages, responseMessage], + }); + assistantTelemetry?.reportAssistantMessageSent({ + role: responseMessage.role, + actionTypeId: apiConfig.actionTypeId, + model: apiConfig.model, + provider: apiConfig.provider, + isEnabledKnowledgeBase: isSetupComplete ?? false, + }); + } finally { + setIsLoadingChatSend(false); } - const convo: Conversation = { ...currentConversation, ...(newConvo ?? {}) }; - const userMessage = getCombinedMessage({ - currentReplacements: convo.replacements, - promptText, - selectedPromptContexts, - }); - - const baseReplacements: Replacements = userMessage.replacements ?? convo.replacements; - - const selectedPromptContextsReplacements = Object.values( - selectedPromptContexts - ).reduce((acc, context) => ({ ...acc, ...context.replacements }), {}); - - const replacements: Replacements = { - ...baseReplacements, - ...selectedPromptContextsReplacements, - }; - const updatedMessages = [...convo.messages, userMessage].map((m) => ({ - ...m, - content: m.content ?? '', - })); - setCurrentConversation({ - ...convo, - replacements, - messages: updatedMessages, - }); - - // Reset prompt context selection and preview before sending: - setSelectedPromptContexts({}); - - const rawResponse = await sendMessage({ - apiConfig, - http, - message: userMessage.content ?? '', - conversationId: convo.id, - replacements, - }); - - assistantTelemetry?.reportAssistantMessageSent({ - role: userMessage.role, - actionTypeId: apiConfig.actionTypeId, - model: apiConfig.model, - provider: apiConfig.provider, - isEnabledKnowledgeBase: isSetupComplete ?? false, - }); - - const responseMessage: ClientMessage = getMessageFromRawResponse(rawResponse); - if (convo.title === '') { - convo.title = (await getConversation(convo.id))?.title ?? ''; - } - setCurrentConversation({ - ...convo, - replacements, - messages: [...updatedMessages, responseMessage], - }); - assistantTelemetry?.reportAssistantMessageSent({ - role: responseMessage.role, - actionTypeId: apiConfig.actionTypeId, - model: apiConfig.model, - provider: apiConfig.provider, - isEnabledKnowledgeBase: isSetupComplete ?? false, - }); }, [ assistantTelemetry, @@ -241,7 +267,7 @@ export const useChatSend = ({ handleChatSend, abortStream, handleRegenerateResponse, - isLoading, + isLoading: isLoadingChatSend, userPrompt, setUserPrompt, }; diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/__mocks__/conversations_schema.mock.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/__mocks__/conversations_schema.mock.ts index b1accf5d74d6f..d101a78eaf7dc 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/__mocks__/conversations_schema.mock.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/__mocks__/conversations_schema.mock.ts @@ -244,3 +244,35 @@ export const getEsCreateConversationSchemaMock = ( namespace: 'default', ...rest, }); + +export const getEsConversationSchemaMock = ( + rest?: Partial +): EsConversationSchema => ({ + '@timestamp': '2020-04-20T15:25:31.830Z', + created_at: '2020-04-20T15:25:31.830Z', + title: 'title-1', + updated_at: '2020-04-20T15:25:31.830Z', + messages: [], + id: '1', + namespace: 'default', + exclude_from_last_conversation_storage: false, + api_config: { + action_type_id: '.gen-ai', + connector_id: 'c1', + default_system_prompt_id: 'prompt-1', + model: 'test', + provider: 'Azure OpenAI', + }, + summary: { + content: 'test', + }, + category: 'assistant', + users: [ + { + id: '1111', + name: 'elastic', + }, + ], + replacements: undefined, + ...rest, +}); diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/append_conversation_messages.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/append_conversation_messages.test.ts new file mode 100644 index 0000000000000..ec7d8e6a91e20 --- /dev/null +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/append_conversation_messages.test.ts @@ -0,0 +1,416 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { loggingSystemMock } from '@kbn/core-logging-server-mocks'; +import type { ConversationResponse, Message } from '@kbn/elastic-assistant-common'; +import type { + DocumentsDataWriter, + WriterBulkResponse, +} from '../../lib/data_stream/documents_data_writer'; +import { + appendConversationMessages, + transformToUpdateScheme, +} from './append_conversation_messages'; +import type { EsConversationSchema } from './types'; +import { authenticatedUser } from '../../__mocks__/user'; +import { + getConversationMock, + getAppendConversationMessagesSchemaMock, + getQueryConversationParams, + getEsConversationSchemaMock, +} from '../../__mocks__/conversations_schema.mock'; +import { transformESToConversations } from './transforms'; +import { getUpdateScript } from './helpers'; + +jest.mock('./transforms', () => ({ + transformESToConversations: jest.fn(), +})); + +jest.mock('./helpers', () => ({ + getUpdateScript: jest.fn(), +})); + +const mockUser = authenticatedUser; + +// Reusable mock helpers to keep tests DRY +const createMockDataWriter = (): jest.Mocked => + ({ + bulk: jest.fn(), + getFilterByUser: jest.fn(), + getFilterByConversationUser: jest.fn(), + } as unknown as jest.Mocked); +// Use existing mocks to keep tests DRY +const createMockConversation = ( + overrides?: Partial +): ConversationResponse => { + const baseParams = getQueryConversationParams(false); + return getConversationMock({ + ...baseParams, + ...overrides, + }); +}; + +const createMockMessage = (overrides?: Partial): Message => { + const baseMessage = getAppendConversationMessagesSchemaMock().messages[0]; + return { + ...baseMessage, + ...overrides, + }; +}; + +// Use shared mock from conversations_schema.mock.ts +const createMockEsConversation = (): EsConversationSchema => getEsConversationSchemaMock(); + +const createSuccessfulBulkResponse = ( + docsUpdated: EsConversationSchema[] +): WriterBulkResponse => ({ + errors: [], + docs_created: [], + docs_deleted: [], + docs_updated: docsUpdated, + took: 10, +}); + +const createErrorBulkResponse = (): WriterBulkResponse => ({ + errors: [ + { + message: 'Document update failed', + document: { id: '1' }, + }, + ], + docs_created: [], + docs_deleted: [], + docs_updated: [], + took: 5, +}); + +describe('appendConversationMessages', () => { + let logger: ReturnType; + let dataWriter: jest.Mocked; + let existingConversation: ConversationResponse; + let newMessages: Message[]; + + // Test helper functions to reduce repetition + const setupSuccessfulTest = () => { + const mockEsConversation = createMockEsConversation(); + const bulkResponse = createSuccessfulBulkResponse([mockEsConversation]); + const expectedConversation = createMockConversation(); + + dataWriter.bulk.mockResolvedValue(bulkResponse); + (transformESToConversations as jest.Mock).mockReturnValue([expectedConversation]); + + return { mockEsConversation, bulkResponse, expectedConversation }; + }; + + const callAppendConversationMessages = async (messages: Message[] = newMessages) => { + return appendConversationMessages({ + dataWriter, + logger, + existingConversation, + messages, + authenticatedUser: mockUser, + }); + }; + + const expectBulkCallWithMessages = (expectedMessages: Message[]) => { + expect(dataWriter.bulk).toHaveBeenCalledWith( + expect.objectContaining({ + documentsToUpdate: expect.arrayContaining([ + expect.objectContaining({ + id: existingConversation.id, + updated_at: '2024-01-01T01:00:00.000Z', + messages: expect.arrayContaining( + expectedMessages.map((msg) => + expect.objectContaining({ + content: msg.content, + role: msg.role, + '@timestamp': msg.timestamp, + }) + ) + ), + }), + ]), + getUpdateScript: expect.any(Function), + authenticatedUser: mockUser, + }) + ); + }; + + beforeEach(() => { + jest.clearAllMocks(); + logger = loggingSystemMock.createLogger(); + dataWriter = createMockDataWriter(); + existingConversation = createMockConversation(); + newMessages = [createMockMessage()]; + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + beforeAll(() => { + jest.useFakeTimers(); + jest.setSystemTime(new Date('2024-01-01T01:00:00.000Z')); + }); + + afterAll(() => { + jest.useRealTimers(); + }); + + it('returns updated conversation when bulk operation succeeds', async () => { + setupSuccessfulTest(); + const expectedResult = createMockConversation({ + messages: [...existingConversation.messages!, ...newMessages], + updatedAt: '2024-01-01T01:00:00.000Z', + }); + (transformESToConversations as jest.Mock).mockReturnValue([expectedResult]); + + const result = await callAppendConversationMessages(); + + expect(result).toEqual(expectedResult); + }); + + it('calls dataWriter.bulk with correct parameters', async () => { + setupSuccessfulTest(); + (getUpdateScript as jest.Mock).mockReturnValue({ script: { source: 'test' } }); + + await callAppendConversationMessages(); + + expectBulkCallWithMessages([...existingConversation.messages!, ...newMessages]); + }); + + it('returns null when bulk operation has errors', async () => { + dataWriter.bulk.mockResolvedValue(createErrorBulkResponse()); + + const result = await callAppendConversationMessages(); + + expect(result).toBeNull(); + expect(logger.error).toHaveBeenCalledWith( + 'Error appending conversation messages: Document update failed for conversation by ID: 04128c15-0d1b-4716-a4c5-46997ac7f3bd' + ); + }); + + it('handles empty messages array', async () => { + const { expectedConversation } = setupSuccessfulTest(); + + const result = await callAppendConversationMessages([]); + + expect(result).toEqual(expectedConversation); + expectBulkCallWithMessages(existingConversation.messages!); + }); + + it('handles conversation without existing messages', async () => { + const conversationWithoutMessages = createMockConversation({ messages: undefined }); + const { expectedConversation } = setupSuccessfulTest(); + + const result = await appendConversationMessages({ + dataWriter, + logger, + existingConversation: conversationWithoutMessages, + messages: newMessages, + authenticatedUser: mockUser, + }); + + expect(result).toEqual(expectedConversation); + expectBulkCallWithMessages(newMessages); + }); + + it('preserves all existing conversation fields in transformation', async () => { + const conversationWithAllFields = createMockConversation({ + title: 'Complex Conversation', + apiConfig: { + actionTypeId: '.custom-ai', + connectorId: 'custom-connector', + defaultSystemPromptId: 'custom-prompt', + model: 'custom-model', + provider: 'OpenAI', + }, + excludeFromLastConversationStorage: true, + replacements: { + key1: 'value1', + key2: 'value2', + }, + }); + + const mockEsConversation = createMockEsConversation(); + const bulkResponse = createSuccessfulBulkResponse([mockEsConversation]); + const expectedConversation = createMockConversation(); + + dataWriter.bulk.mockResolvedValue(bulkResponse); + (transformESToConversations as jest.Mock).mockReturnValue([expectedConversation]); + + await appendConversationMessages({ + dataWriter, + logger, + existingConversation: conversationWithAllFields, + messages: newMessages, + authenticatedUser: mockUser, + }); + + expect(dataWriter.bulk).toHaveBeenCalledWith( + expect.objectContaining({ + documentsToUpdate: expect.arrayContaining([ + expect.objectContaining({ + title: 'Complex Conversation', + api_config: { + action_type_id: '.custom-ai', + connector_id: 'custom-connector', + default_system_prompt_id: 'custom-prompt', + model: 'custom-model', + provider: 'OpenAI', + }, + exclude_from_last_conversation_storage: true, + replacements: [ + { uuid: 'key1', value: 'value1' }, + { uuid: 'key2', value: 'value2' }, + ], + }), + ]), + }) + ); + }); + + it('handles message with metadata and trace data', async () => { + const messageWithMetadata = createMockMessage({ + metadata: { + contentReferences: { + 'ref-1': { + id: 'ref-1', + type: 'KnowledgeBaseEntry', + knowledgeBaseEntryId: 'kb-1', + knowledgeBaseEntryName: 'Reference 1', + }, + }, + }, + traceData: { + traceId: 'trace-123', + transactionId: 'transaction-456', + }, + }); + + setupSuccessfulTest(); + + await callAppendConversationMessages([messageWithMetadata]); + + expect(dataWriter.bulk).toHaveBeenCalledWith( + expect.objectContaining({ + documentsToUpdate: expect.arrayContaining([ + expect.objectContaining({ + messages: expect.arrayContaining([ + expect.objectContaining({ + metadata: { + content_references: { + 'ref-1': { + id: 'ref-1', + type: 'KnowledgeBaseEntry', + knowledgeBaseEntryId: 'kb-1', + knowledgeBaseEntryName: 'Reference 1', + }, + }, + }, + trace_data: { + trace_id: 'trace-123', + transaction_id: 'transaction-456', + }, + }), + ]), + }), + ]), + }) + ); + }); +}); + +describe('transformToUpdateScheme', () => { + let existingConversation: ConversationResponse; + + beforeEach(() => { + existingConversation = createMockConversation(); + }); + + const testCases = [ + { + name: 'transforms conversation to update schema correctly', + conversation: () => existingConversation, + messages: () => [createMockMessage()], + expectedFields: (conv: ConversationResponse, msgs: Message[]) => ({ + id: conv.id, + title: conv.title, + api_config: { + action_type_id: conv.apiConfig!.actionTypeId, + connector_id: conv.apiConfig!.connectorId, + default_system_prompt_id: conv.apiConfig!.defaultSystemPromptId, + model: conv.apiConfig!.model, + provider: conv.apiConfig!.provider, + }, + exclude_from_last_conversation_storage: conv.excludeFromLastConversationStorage, + messages: expect.arrayContaining([ + expect.objectContaining({ + '@timestamp': msgs[0].timestamp, + content: msgs[0].content, + role: msgs[0].role, + trace_data: expect.objectContaining({ + trace_id: '1', + transaction_id: '2', + }), + }), + ]), + }), + }, + { + name: 'handles conversation without optional fields', + conversation: () => + createMockConversation({ + title: undefined, + apiConfig: undefined, + excludeFromLastConversationStorage: undefined, + replacements: undefined, + }), + messages: () => [createMockMessage()], + expectedFields: (conv: ConversationResponse) => ({ + id: conv.id, + messages: expect.any(Array), + }), + shouldNotHaveFields: [ + 'title', + 'api_config', + 'exclude_from_last_conversation_storage', + 'replacements', + ], + }, + { + name: 'handles empty messages array', + conversation: () => existingConversation, + messages: () => [], + expectedFields: (conv: ConversationResponse) => ({ + id: conv.id, + messages: [], + }), + }, + ]; + + testCases.forEach(({ name, conversation, messages, expectedFields, shouldNotHaveFields }) => { + it(name, () => { + const conv = conversation(); + const msgs = messages(); + const updatedAt = '2024-01-01T01:00:00.000Z'; + + const result = transformToUpdateScheme(updatedAt, msgs, conv); + + expect(result).toMatchObject({ + updated_at: updatedAt, + ...expectedFields(conv, msgs), + }); + + if (shouldNotHaveFields) { + shouldNotHaveFields.forEach((field) => { + expect(result).not.toHaveProperty(field); + }); + } + }); + }); +}); diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/append_conversation_messages.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/append_conversation_messages.ts index 87671b97a7be9..f5c096665b322 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/append_conversation_messages.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/append_conversation_messages.ts @@ -5,128 +5,102 @@ * 2.0. */ -import { ElasticsearchClient, Logger } from '@kbn/core/server'; +import type { AuthenticatedUser, Logger } from '@kbn/core/server'; -import { ConversationResponse, Message } from '@kbn/elastic-assistant-common'; -import { getConversation } from './get_conversation'; +import type { ConversationResponse, Message } from '@kbn/elastic-assistant-common'; +import type { + DocumentsDataWriter, + BulkOperationError, +} from '../../lib/data_stream/documents_data_writer'; +import { transformESToConversations } from './transforms'; +import type { EsConversationSchema } from './types'; +import type { UpdateConversationSchema } from './update_conversation'; +import { getUpdateScript } from './helpers'; export interface AppendConversationMessagesParams { - esClient: ElasticsearchClient; + dataWriter: DocumentsDataWriter; logger: Logger; - conversationIndex: string; existingConversation: ConversationResponse; messages: Message[]; + authenticatedUser?: AuthenticatedUser; } export const appendConversationMessages = async ({ - esClient, + dataWriter, logger, - conversationIndex, existingConversation, messages, + authenticatedUser, }: AppendConversationMessagesParams): Promise => { const updatedAt = new Date().toISOString(); + const params = transformToUpdateScheme( + updatedAt, + [...(existingConversation.messages ?? []), ...messages], + existingConversation + ); - const params = transformToUpdateScheme(updatedAt, [ - ...(existingConversation.messages ?? []), - ...messages, - ]); - - const maxRetries = 3; - let attempt = 0; - let response; - while (attempt < maxRetries) { - try { - response = await esClient.updateByQuery({ - conflicts: 'proceed', - index: conversationIndex, - query: { - ids: { - values: [existingConversation.id ?? ''], - }, - }, - refresh: true, - script: { - lang: 'painless', - params: { - ...params, - }, - source: ` - if (params.assignEmpty == true || params.containsKey('messages')) { - def messages = []; - for (message in params.messages) { - def newMessage = [:]; - newMessage['@timestamp'] = message['@timestamp']; - newMessage.content = message.content; - newMessage.is_error = message.is_error; - newMessage.reader = message.reader; - newMessage.role = message.role; - if (message.trace_data != null) { - newMessage.trace_data = message.trace_data; - } - if (message.metadata != null) { - newMessage.metadata = message.metadata; - } - messages.add(newMessage); - } - ctx._source.messages = messages; - } - ctx._source.updated_at = params.updated_at; - `, - }, - }); - if ( - (response?.updated && response?.updated > 0) || - (response?.failures && response?.failures.length > 0) - ) { - break; - } - if ( - response?.version_conflicts && - response?.version_conflicts > 0 && - response?.updated === 0 - ) { - attempt++; - if (attempt < maxRetries) { - logger.warn( - `Version conflict detected, retrying appendConversationMessages (attempt ${ - attempt + 1 - }) for conversation ID: ${existingConversation.id}` - ); - await new Promise((resolve) => setTimeout(resolve, 100 * attempt)); // Exponential backoff - } - } else { - break; - } - } catch (err) { - logger.error( - `Error appending conversation messages: ${err} for conversation by ID: ${existingConversation.id}` - ); - throw err; - } - } + const { errors, docs_updated: docsUpdated } = await dataWriter.bulk< + UpdateConversationSchema, + never + >({ + documentsToUpdate: [params], + getUpdateScript: (document: UpdateConversationSchema) => + getUpdateScript({ conversation: document }), + authenticatedUser, + }); - if (response && response?.failures && response?.failures.length > 0) { + if (errors && errors.length > 0) { logger.error( - `Error appending conversation messages: ${response?.failures.map( - (f) => f.id + `Error appending conversation messages: ${errors.map( + (err: BulkOperationError) => err.message )} for conversation by ID: ${existingConversation.id}` ); return null; } - const updatedConversation = await getConversation({ - esClient, - conversationIndex, - id: existingConversation.id, - logger, - }); + const updatedConversation = transformESToConversations( + docsUpdated as EsConversationSchema[] + )?.[0]; + return updatedConversation; }; -export const transformToUpdateScheme = (updatedAt: string, messages: Message[]) => { +export const transformToUpdateScheme = ( + updatedAt: string, + messages: Message[], + existingConversation: ConversationResponse +) => { return { + id: existingConversation.id, updated_at: updatedAt, + // Preserve all existing conversation fields + ...(existingConversation.title ? { title: existingConversation.title } : {}), + ...(existingConversation.apiConfig + ? { + api_config: { + action_type_id: existingConversation.apiConfig.actionTypeId, + connector_id: existingConversation.apiConfig.connectorId, + default_system_prompt_id: existingConversation.apiConfig.defaultSystemPromptId, + model: existingConversation.apiConfig.model, + provider: existingConversation.apiConfig.provider, + }, + } + : {}), + ...(existingConversation.excludeFromLastConversationStorage != null + ? { + exclude_from_last_conversation_storage: + existingConversation.excludeFromLastConversationStorage, + } + : {}), + ...(existingConversation.replacements + ? { + replacements: Object.keys(existingConversation.replacements).map((key) => ({ + uuid: key, + value: existingConversation.replacements?.[key] ?? '', + })), + } + : {}), + // Update messages with the new combined list messages: messages?.map((message) => ({ '@timestamp': message.timestamp, content: message.content, diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/get_conversation.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/get_conversation.test.ts index f3a3af6050a42..6e49dd7f72d6f 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/get_conversation.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/get_conversation.test.ts @@ -12,7 +12,8 @@ import { estypes } from '@elastic/elasticsearch'; import { EsConversationSchema } from './types'; import { authenticatedUser } from '../../__mocks__/user'; import { loggingSystemMock } from '@kbn/core-logging-server-mocks'; -import { ConversationResponse } from '@kbn/elastic-assistant-common'; +import type { ConversationResponse } from '@kbn/elastic-assistant-common'; +import { getEsConversationSchemaMock } from '../../__mocks__/conversations_schema.mock'; export const getConversationResponseMock = (): ConversationResponse => ({ createdAt: '2020-04-20T15:25:31.830Z', @@ -30,6 +31,9 @@ export const getConversationResponseMock = (): ConversationResponse => ({ model: 'test', provider: 'Azure OpenAI', }, + summary: { + content: 'test', + }, category: 'assistant', users: [ { @@ -56,34 +60,7 @@ export const getSearchConversationMock = (): estypes.SearchResponse => { - const esClient = await this.options.elasticsearchClientPromise; + const dataWriter = await this.getWriter(); return appendConversationMessages({ - esClient, + dataWriter, logger: this.options.logger, - conversationIndex: this.indexTemplateAndPattern.alias, existingConversation, messages, + authenticatedUser: authenticatedUser ?? this.options.currentUser ?? undefined, }); }; diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/transforms.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/transforms.ts index 05eaf2da96388..3fba0b0e07357 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/transforms.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/conversations/transforms.ts @@ -28,6 +28,7 @@ export const transformESToConversation = ( })) ?? [], title: conversationSchema.title, category: conversationSchema.category, + summary: conversationSchema.summary, ...(conversationSchema.api_config ? { apiConfig: { diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/find.test.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/find.test.ts index a9ad152770040..1e850fd5b90dc 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/find.test.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/ai_assistant_data_clients/find.test.ts @@ -10,7 +10,8 @@ import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-m import { estypes } from '@elastic/elasticsearch'; import { loggingSystemMock } from '@kbn/core-logging-server-mocks'; import { findDocuments } from './find'; -import { EsConversationSchema } from './conversations/types'; +import type { EsConversationSchema } from './conversations/types'; +import { getEsConversationSchemaMock } from '../__mocks__/conversations_schema.mock'; export const getSearchConversationMock = (): estypes.SearchResponse => ({ _scroll_id: '123', @@ -26,34 +27,7 @@ export const getSearchConversationMock = (): estypes.SearchResponse { errors: BulkOperationError[]; docs_created: string[]; docs_deleted: string[]; - docs_updated: unknown[]; + docs_updated: TUpdated[]; took: number; } @@ -44,7 +44,7 @@ interface BulkParams { export interface DocumentsDataWriter { bulk: ( params: BulkParams - ) => Promise; + ) => Promise>; } interface DocumentsDataWriterOptions { @@ -60,14 +60,20 @@ export class DocumentsDataWriter implements DocumentsDataWriter { public bulk = async ( params: BulkParams - ) => { + ): Promise> => { try { if ( !params.documentsToCreate?.length && !params.documentsToUpdate?.length && !params.documentsToDelete?.length ) { - return { errors: [], docs_created: [], docs_deleted: [], docs_updated: [], took: 0 }; + return { + errors: [], + docs_created: [], + docs_deleted: [], + docs_updated: [], + took: 0, + } as WriterBulkResponse; } const { errors, items, took } = await this.options.esClient.bulk( @@ -91,9 +97,9 @@ export class DocumentsDataWriter implements DocumentsDataWriter { .map((item) => item.delete?._id), docs_updated: items .filter((item) => item.update?.status === 201 || item.update?.status === 200) - .map((item) => item.update?.get?._source), + .map((item) => item.update?.get?._source) as TUpdateParams[], took, - } as WriterBulkResponse; + } as WriterBulkResponse; } catch (e) { this.options.logger.error(`Error bulk actions for documents: ${e.message}`); return { @@ -109,7 +115,7 @@ export class DocumentsDataWriter implements DocumentsDataWriter { docs_deleted: [], docs_updated: [], took: 0, - } as WriterBulkResponse; + } as WriterBulkResponse; } }; diff --git a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/user_conversations/append_conversation_messages_route.ts b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/user_conversations/append_conversation_messages_route.ts index ffbaed812cfab..f9d7c6e6a2bd6 100644 --- a/x-pack/solutions/security/plugins/elastic_assistant/server/routes/user_conversations/append_conversation_messages_route.ts +++ b/x-pack/solutions/security/plugins/elastic_assistant/server/routes/user_conversations/append_conversation_messages_route.ts @@ -63,10 +63,10 @@ export const appendConversationMessageRoute = (router: ElasticAssistantPluginRou statusCode: 404, }); } - const conversation = await dataClient?.appendConversationMessages({ existingConversation, messages: request.body.messages, + authenticatedUser, }); if (conversation == null) { return assistantResponse.error({