Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ jest.mock('../helpers', () => {
};
});
const mockAppendAssistantMessageToConversation = appendAssistantMessageToConversation as jest.Mock;
const mockCreateConversationWithUserInput = createConversationWithUserInput as jest.Mock;

const mockLangChainExecute = langChainExecute as jest.Mock;
const mockStream = jest.fn().mockImplementation(() => new PassThrough());
Expand Down Expand Up @@ -150,7 +151,7 @@ describe('chatCompleteRoute', () => {
jest.clearAllMocks();
mockAppendAssistantMessageToConversation.mockResolvedValue(true);
license.hasAtLeast.mockReturnValue(true);
(createConversationWithUserInput as jest.Mock).mockResolvedValue({ id: 'something' });
mockCreateConversationWithUserInput.mockResolvedValue({ id: 'something' });
mockLangChainExecute.mockImplementation(
async ({
connectorId,
Expand All @@ -166,12 +167,14 @@ describe('chatCompleteRoute', () => {
) => Promise<void>;
}) => {
if (!isStream && connectorId === 'mock-connector-id') {
onLlmResponse('Non-streamed test reply.', {}, false).catch(() => {});
return {
connector_id: 'mock-connector-id',
data: mockActionResponse,
status: 'ok',
};
} else if (isStream && connectorId === 'mock-connector-id') {
onLlmResponse('Streamed test reply.', {}, false).catch(() => {});
return mockStream;
} else {
onLlmResponse('simulated error', {}, true).catch(() => {});
Expand Down Expand Up @@ -399,4 +402,141 @@ describe('chatCompleteRoute', () => {
mockGetElser
);
});

it('should add assistant reply to existing conversation when `persist=true`', async () => {
const mockRouter = {
versioned: {
post: jest.fn().mockImplementation(() => {
return {
addVersion: jest.fn().mockImplementation(async (_, handler) => {
await handler(
mockContext,
{
...mockRequest,
body: {
...mockRequest.body,
conversationId: existingConversation.id,
},
},
mockResponse
);
expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledWith(
expect.objectContaining({
messageContent: 'Non-streamed test reply.',
isError: false,
})
);
expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(0);
}),
};
}),
},
};

chatCompleteRoute(
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
mockGetElser
);
});

it('should not add assistant reply to existing conversation when `persist=false`', async () => {
const mockRouter = {
versioned: {
post: jest.fn().mockImplementation(() => {
return {
addVersion: jest.fn().mockImplementation(async (_, handler) => {
await handler(
mockContext,
{
...mockRequest,
body: {
...mockRequest.body,
conversationId: existingConversation.id,
persist: false,
},
},
mockResponse
);
expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledTimes(0);
expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(0);
}),
};
}),
},
};

chatCompleteRoute(
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
mockGetElser
);
});

it('should add assistant reply to new conversation when `persist=true`', async () => {
const mockRouter = {
versioned: {
post: jest.fn().mockImplementation(() => {
return {
addVersion: jest.fn().mockImplementation(async (_, handler) => {
await handler(
mockContext,
{
...mockRequest,
body: {
...mockRequest.body,
conversationId: undefined,
persist: true,
},
},
mockResponse
);
expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledWith(
expect.objectContaining({
messageContent: 'Non-streamed test reply.',
isError: false,
})
);
expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(1);
}),
};
}),
},
};

chatCompleteRoute(
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
mockGetElser
);
});

it('should not create a new conversation when `persist=false`', async () => {
const mockRouter = {
versioned: {
post: jest.fn().mockImplementation(() => {
return {
addVersion: jest.fn().mockImplementation(async (_, handler) => {
await handler(
mockContext,
{
...mockRequest,
body: {
...mockRequest.body,
conversationId: undefined,
persist: false,
},
},
mockResponse
);
expect(mockAppendAssistantMessageToConversation).toHaveBeenCalledTimes(0);
expect(mockCreateConversationWithUserInput).toHaveBeenCalledTimes(0);
}),
};
}),
},
};

chatCompleteRoute(
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
mockGetElser
);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ export const chatCompleteRoute = (
await ctx.elasticAssistant.getAIAssistantAnonymizationFieldsDataClient();

let messages;
const conversationId = request.body.conversationId;
const existingConversationId = request.body.conversationId;
const connectorId = request.body.connectorId;

let latestReplacements: Replacements = {};
Expand Down Expand Up @@ -159,11 +159,10 @@ export const chatCompleteRoute = (
});

let newConversation: ConversationResponse | undefined | null;
if (conversationsDataClient && !conversationId && request.body.persist) {
if (conversationsDataClient && !existingConversationId && request.body.persist) {
newConversation = await createConversationWithUserInput({
actionTypeId,
connectorId,
conversationId,
conversationsDataClient,
promptId: request.body.promptId,
replacements: latestReplacements,
Expand All @@ -178,18 +177,23 @@ export const chatCompleteRoute = (
}));
}

// Do not persist conversation messages if `persist = false`
const conversationId = request.body.persist
? existingConversationId ?? newConversation?.id
: undefined;

const contentReferencesStore = newContentReferencesStore();

const onLlmResponse = async (
content: string,
traceData: Message['traceData'] = {},
isError = false
): Promise<void> => {
if (newConversation?.id && conversationsDataClient) {
if (conversationId && conversationsDataClient) {
const contentReferences = pruneContentReferences(content, contentReferencesStore);

await appendAssistantMessageToConversation({
conversationId: newConversation?.id,
conversationId,
conversationsDataClient,
messageContent: content,
replacements: latestReplacements,
Expand All @@ -207,7 +211,7 @@ export const chatCompleteRoute = (
actionTypeId,
connectorId,
isOssModel,
conversationId: conversationId ?? newConversation?.id,
conversationId,
context: ctx,
getElser,
logger,
Expand Down