Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/app/(general)/_components/sidebar/chats/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ export const NavChats = () => {

const allChats = chats?.pages.flatMap((page) => page.items) ?? [];

if (isLoading || !chats || state === "collapsed") return null;
const isInitialLoading = isLoading && !chats;

if (isInitialLoading || state === "collapsed") return null;

const starredChats = allChats.filter((chat) => chat.starred);
const regularChats = allChats.filter((chat) => !chat.starred);
Expand Down
13 changes: 12 additions & 1 deletion src/app/(general)/_components/sidebar/chats/item.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -154,5 +154,16 @@ const PureChatItem: React.FC<Props> = ({

export const ChatItem = memo(PureChatItem, (prevProps, nextProps) => {
if (prevProps.isActive !== nextProps.isActive) return false;
return true;

const prevChat = prevProps.chat;
const nextChat = nextProps.chat;

return (
prevChat.id === nextChat.id &&
prevChat.title === nextChat.title &&
prevChat.starred === nextChat.starred &&
prevChat.visibility === nextChat.visibility &&
prevChat.parentChatId === nextChat.parentChatId &&
prevChat.workbenchId === nextChat.workbenchId
);
});
238 changes: 229 additions & 9 deletions src/app/(general)/_contexts/chat-context.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
useEffect,
useState,
useCallback,
useMemo,
} from "react";

import { useChat } from "@ai-sdk/react";
Expand Down Expand Up @@ -37,7 +38,7 @@ import type { ClientToolkit } from "@/toolkits/types";
import type { z } from "zod";
import type { SelectedToolkit } from "@/components/toolkit/types";
import type { Toolkits } from "@/toolkits/toolkits/shared";
import type { Workbench } from "@prisma/client";
import type { Chat, Workbench } from "@prisma/client";
import type { PersistedToolkit } from "@/lib/cookies/types";
import type { ImageModel } from "@/ai/image/types";
import type { LanguageModel } from "@/ai/language/types";
Expand Down Expand Up @@ -175,9 +176,92 @@ export function ChatProvider({

return [];
});
const [hasInvalidated, setHasInvalidated] = useState(false);
const [hasSyncedInitialChat, setHasSyncedInitialChat] = useState(false);
const [hasInsertedPlaceholder, setHasInsertedPlaceholder] = useState(false);
const [streamStopped, setStreamStopped] = useState(false);

const chatListInput = useMemo(
() =>
({
limit: 10,
workbenchId: workbench?.id ?? null,
}) as const,
[workbench?.id],
);

const syncChatList = useCallback(
(attempt = 0) => {
void (async () => {
try {
const latestChat = await utils.chats.getChat.fetch(id);
if (!latestChat) return;

const { workbench: _ignoredWorkbench, ...chatWithoutWorkbench } =
latestChat;
const normalizedChat = chatWithoutWorkbench as Chat;

utils.chats.getChats.setInfiniteData(chatListInput, (cache) => {
if (!cache) {
return {
pageParams: [null],
pages: [
{
items: [normalizedChat],
hasMore: false,
nextCursor: undefined,
},
],
};
}

return {
...cache,
pages: cache.pages.map((page, index) => {
if (index === 0) {
const existingIndex = page.items.findIndex(
(item) => item.id === normalizedChat.id,
);

if (existingIndex !== -1) {
const nextItems = [...page.items];
nextItems[existingIndex] = normalizedChat;
return {
...page,
items: nextItems,
};
}

const filteredItems = page.items.filter(
(item) => item.id !== normalizedChat.id,
);
const nextItems = [
normalizedChat,
...filteredItems,
].slice(0, chatListInput.limit);
return {
...page,
items: nextItems,
};
}
return page;
}),
};
});

setHasInsertedPlaceholder(false);
setHasSyncedInitialChat(true);

if (normalizedChat.title === "New Chat" && attempt < 5) {
setTimeout(() => syncChatList(attempt + 1), 1000);
}
} catch (error) {
console.error("Failed to sync chat list:", error);
}
})();
},
[chatListInput, id, utils],
);

// Wrapper functions that also save to cookies
const setSelectedChatModel = (model: LanguageModel) => {
setSelectedChatModelState(model);
Expand Down Expand Up @@ -213,7 +297,7 @@ export function ChatProvider({
handleSubmit: originalHandleSubmit,
input,
setInput,
append,
append: originalAppend,
status,
stop,
reload,
Expand Down Expand Up @@ -249,14 +333,31 @@ export function ChatProvider({
onFinish: () => {
setStreamStopped(false);
void utils.messages.getMessagesForChat.invalidate({ chatId: id });
if (initialMessages.length === 0 && !hasInvalidated) {
setHasInvalidated(true);
void utils.chats.getChats.invalidate({
workbenchId: workbench?.id,
});
if (initialMessages.length === 0 && !hasSyncedInitialChat) {
syncChatList();
}
},
onError: (error) => {
if (
initialMessages.length === 0 &&
hasInsertedPlaceholder &&
!hasSyncedInitialChat
) {
utils.chats.getChats.setInfiniteData(chatListInput, (cache) => {
if (!cache) return cache;

return {
...cache,
pages: cache.pages.map((page) => ({
...page,
items: page.items.filter((chat) => chat.id !== id),
})),
};
});
setHasInsertedPlaceholder(false);
setHasSyncedInitialChat(false);
}

if (error instanceof ChatSDKError) {
toast.error(error.message);
} else {
Expand All @@ -282,15 +383,134 @@ export function ChatProvider({
onStreamError,
});

const promoteChatInCache = useCallback(() => {
utils.chats.getChats.setInfiniteData(chatListInput, (cache) => {
if (!cache) {
return cache;
}

const pages = cache.pages.map((page) => ({
...page,
items: [...page.items],
}));

let promotedChat: Chat | undefined;

for (const page of pages) {
const existingIndex = page.items.findIndex((chat) => chat.id === id);
if (existingIndex !== -1) {
const [chat] = page.items.splice(existingIndex, 1);
if (!promotedChat) {
promotedChat = chat;
}
}
}

if (!promotedChat) {
return cache;
}

let carry: Chat | undefined = promotedChat;

for (let pageIndex = 0; pageIndex < pages.length && carry; pageIndex += 1) {
const page = pages[pageIndex];
page.items.unshift(carry);

if (page.items.length > chatListInput.limit) {
carry = page.items.pop();
} else {
carry = undefined;
}
}

if (carry) {
const lastPage = pages[pages.length - 1];
if (lastPage) {
lastPage.items.push(carry);
}
}

return {
...cache,
pages,
};
});
}, [chatListInput, id, utils]);

const handleSubmit: UseChatHelpers["handleSubmit"] = (
event,
chatRequestOptions,
) => {
// Reset stream stopped flag when submitting new message
setStreamStopped(false);

promoteChatInCache();

if (initialMessages.length === 0 && !hasInsertedPlaceholder) {
setHasInsertedPlaceholder(true);
const placeholderChat: Chat = {
id,
title: "New Chat",
userId: "",
visibility: initialVisibilityType,
parentChatId: null,
workbenchId: workbench?.id ?? null,
starred: false,
createdAt: new Date(),
};

utils.chats.getChats.setInfiniteData(chatListInput, (cache) => {
if (!cache) {
return {
pageParams: [null],
pages: [
{
items: [placeholderChat],
hasMore: false,
nextCursor: undefined,
},
],
};
}

const alreadyExists = cache.pages.some((page) =>
page.items.some((chat) => chat.id === id),
);

if (alreadyExists) {
return cache;
}

return {
...cache,
pages: cache.pages.map((page, index) => {
if (index === 0) {
const nextItems = [placeholderChat, ...page.items].slice(
0,
chatListInput.limit,
);
return {
...page,
items: nextItems,
};
}
return page;
}),
};
});
}

originalHandleSubmit(event, chatRequestOptions);
};

const appendWithPromotion = useCallback<UseChatHelpers["append"]>(
async (message, options) => {
promoteChatInCache();
await originalAppend(message, options);
},
[originalAppend, promoteChatInCache],
);

useEffect(() => {
if (
selectedChatModel?.capabilities?.includes(
Expand Down Expand Up @@ -319,7 +539,7 @@ export function ChatProvider({
handleSubmit,
stop,
reload,
append,
append: appendWithPromotion,
imageGenerationModel,
setImageGenerationModel,
toolkits,
Expand Down