Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,34 @@ describe('use chat send', () => {
});
});
});
it('retries getConversation up to 5 times if title is empty, and stops when title is found', async () => {
const promptText = 'test prompt';
const getConversationMock = jest.fn();
// First 3 calls return empty title, 4th returns non-empty
getConversationMock
.mockResolvedValueOnce({ title: '' })
.mockResolvedValueOnce({ title: '' })
.mockResolvedValueOnce({ title: '' })
.mockResolvedValueOnce({ title: 'Final Title' });
(useConversation as jest.Mock).mockReturnValue({
removeLastMessage,
clearConversation,
getConversation: getConversationMock,
createConversation: jest.fn(),
});
const { result } = renderHook(
() =>
useChatSend({
...testProps,
currentConversation: { ...emptyWelcomeConvo, id: 'convo-id', title: '' },
}),
{ wrapper: TestProviders }
);
await act(async () => {
await result.current.handleChatSend(promptText);
});
// Should call getConversation 4 times (until non-empty title)
expect(getConversationMock).toHaveBeenCalledTimes(4);
expect(getConversationMock).toHaveBeenLastCalledWith('convo-id');
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,15 @@ export const useChatSend = ({
const { setLastConversation } = useAssistantLastConversation({ spaceId });
const [userPrompt, setUserPrompt] = useState<string | null>(null);

const { isLoading, sendMessage, abortStream } = useSendMessage();
const { sendMessage, abortStream } = useSendMessage();
const { clearConversation, createConversation, getConversation, removeLastMessage } =
useConversation();
const { data: kbStatus } = useKnowledgeBaseStatus({ http, enabled: isAssistantEnabled });
const isSetupComplete = kbStatus?.elser_exists && kbStatus?.security_labs_exists;

// Local loading state that persists until the entire message flow is complete
const [isLoadingChatSend, setIsLoadingChatSend] = useState(false);

// Handles sending latest user prompt to API
const handleSendMessage = useCallback(
async (promptText: string) => {
Expand All @@ -81,79 +84,102 @@ export const useChatSend = ({
);
return;
}
const apiConfig = currentConversation.apiConfig;
let newConvo;
if (currentConversation.id === '') {
// create conversation with empty title, GENERATE_CHAT_TITLE graph step will properly title
newConvo = await createConversation(currentConversation);
if (newConvo?.id) {
setLastConversation({
id: newConvo.id,
});

setIsLoadingChatSend(true);

try {
const apiConfig = currentConversation.apiConfig;
let newConvo;
if (currentConversation.id === '') {
// create conversation with empty title, GENERATE_CHAT_TITLE graph step will properly title
newConvo = await createConversation(currentConversation);
if (newConvo?.id) {
setLastConversation({
id: newConvo.id,
});
}
}
const convo: Conversation = { ...currentConversation, ...(newConvo ?? {}) };
const userMessage = getCombinedMessage({
currentReplacements: convo.replacements,
promptText,
selectedPromptContexts,
});

const baseReplacements: Replacements = userMessage.replacements ?? convo.replacements;

const selectedPromptContextsReplacements = Object.values(
selectedPromptContexts
).reduce<Replacements>((acc, context) => ({ ...acc, ...context.replacements }), {});

const replacements: Replacements = {
...baseReplacements,
...selectedPromptContextsReplacements,
};
const updatedMessages = [...convo.messages, userMessage].map((m) => ({
...m,
content: m.content ?? '',
}));
setCurrentConversation({
...convo,
replacements,
messages: updatedMessages,
});

// Reset prompt context selection and preview before sending:
setSelectedPromptContexts({});

const rawResponse = await sendMessage({
apiConfig,
http,
message: userMessage.content ?? '',
conversationId: convo.id,
replacements,
});

assistantTelemetry?.reportAssistantMessageSent({
role: userMessage.role,
actionTypeId: apiConfig.actionTypeId,
model: apiConfig.model,
provider: apiConfig.provider,
isEnabledKnowledgeBase: isSetupComplete ?? false,
});

const responseMessage: ClientMessage = getMessageFromRawResponse(rawResponse);
if (convo.title === '') {
// Retry getConversation up to 5 times if title is empty
let retryCount = 0;
const maxRetries = 5;
while (retryCount < maxRetries) {
const conversation = await getConversation(convo.id);
convo.title = conversation?.title ?? '';

if (convo.title !== '') {
break; // Title found, exit retry loop
}

retryCount++;
if (retryCount < maxRetries) {
// Wait 1 second before next retry
await new Promise((resolve) => setTimeout(resolve, 1000));
}
}
}
setCurrentConversation({
...convo,
replacements,
messages: [...updatedMessages, responseMessage],
});
assistantTelemetry?.reportAssistantMessageSent({
role: responseMessage.role,
actionTypeId: apiConfig.actionTypeId,
model: apiConfig.model,
provider: apiConfig.provider,
isEnabledKnowledgeBase: isSetupComplete ?? false,
});
} finally {
setIsLoadingChatSend(false);
}
const convo: Conversation = { ...currentConversation, ...(newConvo ?? {}) };
const userMessage = getCombinedMessage({
currentReplacements: convo.replacements,
promptText,
selectedPromptContexts,
});

const baseReplacements: Replacements = userMessage.replacements ?? convo.replacements;

const selectedPromptContextsReplacements = Object.values(
selectedPromptContexts
).reduce<Replacements>((acc, context) => ({ ...acc, ...context.replacements }), {});

const replacements: Replacements = {
...baseReplacements,
...selectedPromptContextsReplacements,
};
const updatedMessages = [...convo.messages, userMessage].map((m) => ({
...m,
content: m.content ?? '',
}));
setCurrentConversation({
...convo,
replacements,
messages: updatedMessages,
});

// Reset prompt context selection and preview before sending:
setSelectedPromptContexts({});

const rawResponse = await sendMessage({
apiConfig,
http,
message: userMessage.content ?? '',
conversationId: convo.id,
replacements,
});

assistantTelemetry?.reportAssistantMessageSent({
role: userMessage.role,
actionTypeId: apiConfig.actionTypeId,
model: apiConfig.model,
provider: apiConfig.provider,
isEnabledKnowledgeBase: isSetupComplete ?? false,
});

const responseMessage: ClientMessage = getMessageFromRawResponse(rawResponse);
if (convo.title === '') {
convo.title = (await getConversation(convo.id))?.title ?? '';
}
setCurrentConversation({
...convo,
replacements,
messages: [...updatedMessages, responseMessage],
});
assistantTelemetry?.reportAssistantMessageSent({
role: responseMessage.role,
actionTypeId: apiConfig.actionTypeId,
model: apiConfig.model,
provider: apiConfig.provider,
isEnabledKnowledgeBase: isSetupComplete ?? false,
});
},
[
assistantTelemetry,
Expand Down Expand Up @@ -241,7 +267,7 @@ export const useChatSend = ({
handleChatSend,
abortStream,
handleRegenerateResponse,
isLoading,
isLoading: isLoadingChatSend,
userPrompt,
setUserPrompt,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,35 @@ export const getEsCreateConversationSchemaMock = (
namespace: 'default',
...rest,
});

export const getEsConversationSchemaMock = (
rest?: Partial<EsConversationSchema>
): EsConversationSchema => ({
'@timestamp': '2020-04-20T15:25:31.830Z',
created_at: '2020-04-20T15:25:31.830Z',
title: 'title-1',
updated_at: '2020-04-20T15:25:31.830Z',
messages: [],
id: '1',
namespace: 'default',
exclude_from_last_conversation_storage: false,
api_config: {
action_type_id: '.gen-ai',
connector_id: 'c1',
default_system_prompt_id: 'prompt-1',
model: 'test',
provider: 'Azure OpenAI',
},
summary: {
content: 'test',
},
category: 'assistant',
users: [
{
id: '1111',
name: 'elastic',
},
],
replacements: undefined,
...rest,
});
Loading