From f618b713afdfa2a1596ea8c258a47e6e70becb1a Mon Sep 17 00:00:00 2001
From: ItzCrazyKns
Date: Thu, 2 May 2024 12:14:26 +0530
Subject: [PATCH] feat(chatModels): load model from localstorage
---
README.md | 10 +++---
package.json | 2 +-
sample.config.toml | 2 --
src/config.ts | 17 +++-------
src/routes/config.ts | 9 ------
src/routes/images.ts | 13 ++++----
src/routes/index.ts | 2 ++
src/routes/models.ts | 18 +++++++++++
src/routes/videos.ts | 13 ++++----
src/websocket/connectionManager.ts | 16 +++++++---
src/websocket/websocketServer.ts | 4 +--
ui/components/ChatWindow.tsx | 36 ++++++++++++++++++---
ui/components/SearchImages.tsx | 6 ++++
ui/components/SearchVideos.tsx | 6 ++++
ui/components/SettingsDialog.tsx | 51 ++++++++++++++++--------------
ui/package.json | 2 +-
16 files changed, 126 insertions(+), 81 deletions(-)
create mode 100644 src/routes/models.ts
diff --git a/README.md b/README.md
index 50e0e1d0..bb7171ba 100644
--- a/README.md
+++ b/README.md
@@ -59,13 +59,11 @@ There are mainly 2 ways of installing Perplexica - With Docker, Without Docker.
4. Rename the `sample.config.toml` file to `config.toml`. For Docker setups, you need only fill in the following fields:
- - `CHAT_MODEL`: The name of the LLM to use. Like `llama3:latest` (using Ollama), `gpt-3.5-turbo` (using OpenAI), etc.
- - `CHAT_MODEL_PROVIDER`: The chat model provider, either `openai` or `ollama`. Depending upon which provider you use you would have to fill in the following fields:
+ - `OPENAI`: Your OpenAI API key. **You only need to fill this if you wish to use OpenAI's models**.
+ - `OLLAMA`: Your Ollama API URL. You should enter it as `http://host.docker.internal:PORT_NUMBER`. If you installed Ollama on port 11434, use `http://host.docker.internal:11434`. For other ports, adjust accordingly. **You need to fill this if you wish to use Ollama's models instead of OpenAI's**.
+ - `GROQ`: Your Groq API key. **You only need to fill this if you wish to use Groq's hosted models**
- - `OPENAI`: Your OpenAI API key. **You only need to fill this if you wish to use OpenAI's models**.
- - `OLLAMA`: Your Ollama API URL. You should enter it as `http://host.docker.internal:PORT_NUMBER`. If you installed Ollama on port 11434, use `http://host.docker.internal:11434`. For other ports, adjust accordingly. **You need to fill this if you wish to use Ollama's models instead of OpenAI's**.
-
- **Note**: You can change these and use different models after running Perplexica as well from the settings page.
+ **Note**: You can change these after starting Perplexica from the settings dialog.
- `SIMILARITY_MEASURE`: The similarity measure to use (This is filled by default; you can leave it as is if you are unsure about it.)
diff --git a/package.json b/package.json
index a4b91e1a..94345696 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
{
"name": "perplexica-backend",
- "version": "1.0.0",
+ "version": "1.1.0",
"license": "MIT",
"author": "ItzCrazyKns",
"scripts": {
diff --git a/sample.config.toml b/sample.config.toml
index e2838269..7bc8880f 100644
--- a/sample.config.toml
+++ b/sample.config.toml
@@ -1,8 +1,6 @@
[GENERAL]
PORT = 3001 # Port to run the server on
SIMILARITY_MEASURE = "cosine" # "cosine" or "dot"
-CHAT_MODEL_PROVIDER = "openai" # "openai" or "ollama" or "groq"
-CHAT_MODEL = "gpt-3.5-turbo" # Name of the model to use
[API_KEYS]
OPENAI = "" # OpenAI API key - sk-1234567890abcdef1234567890abcdef
diff --git a/src/config.ts b/src/config.ts
index 25dcbf4d..7c0c7f14 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -8,8 +8,6 @@ interface Config {
GENERAL: {
PORT: number;
SIMILARITY_MEASURE: string;
- CHAT_MODEL_PROVIDER: string;
- CHAT_MODEL: string;
};
API_KEYS: {
OPENAI: string;
@@ -35,11 +33,6 @@ export const getPort = () => loadConfig().GENERAL.PORT;
export const getSimilarityMeasure = () =>
loadConfig().GENERAL.SIMILARITY_MEASURE;
-export const getChatModelProvider = () =>
- loadConfig().GENERAL.CHAT_MODEL_PROVIDER;
-
-export const getChatModel = () => loadConfig().GENERAL.CHAT_MODEL;
-
export const getOpenaiApiKey = () => loadConfig().API_KEYS.OPENAI;
export const getGroqApiKey = () => loadConfig().API_KEYS.GROQ;
@@ -52,21 +45,19 @@ export const updateConfig = (config: RecursivePartial) => {
const currentConfig = loadConfig();
for (const key in currentConfig) {
- /* if (currentConfig[key] && !config[key]) {
- config[key] = currentConfig[key];
- } */
+ if (!config[key]) config[key] = {};
- if (currentConfig[key] && typeof currentConfig[key] === 'object') {
+ if (typeof currentConfig[key] === 'object' && currentConfig[key] !== null) {
for (const nestedKey in currentConfig[key]) {
if (
- currentConfig[key][nestedKey] &&
!config[key][nestedKey] &&
+ currentConfig[key][nestedKey] &&
config[key][nestedKey] !== ''
) {
config[key][nestedKey] = currentConfig[key][nestedKey];
}
}
- } else if (currentConfig[key] && !config[key] && config[key] !== '') {
+ } else if (currentConfig[key] && config[key] !== '') {
config[key] = currentConfig[key];
}
}
diff --git a/src/routes/config.ts b/src/routes/config.ts
index 4d22ec56..9518c5f0 100644
--- a/src/routes/config.ts
+++ b/src/routes/config.ts
@@ -1,8 +1,6 @@
import express from 'express';
import { getAvailableProviders } from '../lib/providers';
import {
- getChatModel,
- getChatModelProvider,
getGroqApiKey,
getOllamaApiEndpoint,
getOpenaiApiKey,
@@ -26,9 +24,6 @@ router.get('/', async (_, res) => {
config['providers'][provider] = Object.keys(providers[provider]);
}
- config['selectedProvider'] = getChatModelProvider();
- config['selectedChatModel'] = getChatModel();
-
config['openeaiApiKey'] = getOpenaiApiKey();
config['ollamaApiUrl'] = getOllamaApiEndpoint();
config['groqApiKey'] = getGroqApiKey();
@@ -40,10 +35,6 @@ router.post('/', async (req, res) => {
const config = req.body;
const updatedConfig = {
- GENERAL: {
- CHAT_MODEL_PROVIDER: config.selectedProvider,
- CHAT_MODEL: config.selectedChatModel,
- },
API_KEYS: {
OPENAI: config.openeaiApiKey,
GROQ: config.groqApiKey,
diff --git a/src/routes/images.ts b/src/routes/images.ts
index 066a3ee4..39066897 100644
--- a/src/routes/images.ts
+++ b/src/routes/images.ts
@@ -2,7 +2,6 @@ import express from 'express';
import handleImageSearch from '../agents/imageSearchAgent';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { getAvailableProviders } from '../lib/providers';
-import { getChatModel, getChatModelProvider } from '../config';
import { HumanMessage, AIMessage } from '@langchain/core/messages';
import logger from '../utils/logger';
@@ -10,7 +9,7 @@ const router = express.Router();
router.post('/', async (req, res) => {
try {
- let { query, chat_history } = req.body;
+ let { query, chat_history, chat_model_provider, chat_model } = req.body;
chat_history = chat_history.map((msg: any) => {
if (msg.role === 'user') {
@@ -20,14 +19,14 @@ router.post('/', async (req, res) => {
}
});
- const models = await getAvailableProviders();
- const provider = getChatModelProvider();
- const chatModel = getChatModel();
+ const chatModels = await getAvailableProviders();
+ const provider = chat_model_provider || Object.keys(chatModels)[0];
+ const chatModel = chat_model || Object.keys(chatModels[provider])[0];
let llm: BaseChatModel | undefined;
- if (models[provider] && models[provider][chatModel]) {
- llm = models[provider][chatModel] as BaseChatModel | undefined;
+ if (chatModels[provider] && chatModels[provider][chatModel]) {
+ llm = chatModels[provider][chatModel] as BaseChatModel | undefined;
}
if (!llm) {
diff --git a/src/routes/index.ts b/src/routes/index.ts
index bcfc3d36..04390cd3 100644
--- a/src/routes/index.ts
+++ b/src/routes/index.ts
@@ -2,11 +2,13 @@ import express from 'express';
import imagesRouter from './images';
import videosRouter from './videos';
import configRouter from './config';
+import modelsRouter from './models';
const router = express.Router();
router.use('/images', imagesRouter);
router.use('/videos', videosRouter);
router.use('/config', configRouter);
+router.use('/models', modelsRouter);
export default router;
diff --git a/src/routes/models.ts b/src/routes/models.ts
new file mode 100644
index 00000000..f2332f4b
--- /dev/null
+++ b/src/routes/models.ts
@@ -0,0 +1,18 @@
+import express from 'express';
+import logger from '../utils/logger';
+import { getAvailableProviders } from '../lib/providers';
+
+const router = express.Router();
+
+router.get('/', async (req, res) => {
+ try {
+ const providers = await getAvailableProviders();
+
+ res.status(200).json({ providers });
+ } catch (err) {
+ res.status(500).json({ message: 'An error has occurred.' });
+ logger.error(err.message);
+ }
+});
+
+export default router;
diff --git a/src/routes/videos.ts b/src/routes/videos.ts
index bfd5fa84..fecd8745 100644
--- a/src/routes/videos.ts
+++ b/src/routes/videos.ts
@@ -1,7 +1,6 @@
import express from 'express';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { getAvailableProviders } from '../lib/providers';
-import { getChatModel, getChatModelProvider } from '../config';
import { HumanMessage, AIMessage } from '@langchain/core/messages';
import logger from '../utils/logger';
import handleVideoSearch from '../agents/videoSearchAgent';
@@ -10,7 +9,7 @@ const router = express.Router();
router.post('/', async (req, res) => {
try {
- let { query, chat_history } = req.body;
+ let { query, chat_history, chat_model_provider, chat_model } = req.body;
chat_history = chat_history.map((msg: any) => {
if (msg.role === 'user') {
@@ -20,14 +19,14 @@ router.post('/', async (req, res) => {
}
});
- const models = await getAvailableProviders();
- const provider = getChatModelProvider();
- const chatModel = getChatModel();
+ const chatModels = await getAvailableProviders();
+ const provider = chat_model_provider || Object.keys(chatModels)[0];
+ const chatModel = chat_model || Object.keys(chatModels[provider])[0];
let llm: BaseChatModel | undefined;
- if (models[provider] && models[provider][chatModel]) {
- llm = models[provider][chatModel] as BaseChatModel | undefined;
+ if (chatModels[provider] && chatModels[provider][chatModel]) {
+ llm = chatModels[provider][chatModel] as BaseChatModel | undefined;
}
if (!llm) {
diff --git a/src/websocket/connectionManager.ts b/src/websocket/connectionManager.ts
index afaaf443..c2f37980 100644
--- a/src/websocket/connectionManager.ts
+++ b/src/websocket/connectionManager.ts
@@ -1,15 +1,23 @@
import { WebSocket } from 'ws';
import { handleMessage } from './messageHandler';
-import { getChatModel, getChatModelProvider } from '../config';
import { getAvailableProviders } from '../lib/providers';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { Embeddings } from '@langchain/core/embeddings';
+import type { IncomingMessage } from 'http';
import logger from '../utils/logger';
-export const handleConnection = async (ws: WebSocket) => {
+export const handleConnection = async (
+ ws: WebSocket,
+ request: IncomingMessage,
+) => {
+ const searchParams = new URL(request.url, `http://${request.headers.host}`)
+ .searchParams;
+
const models = await getAvailableProviders();
- const provider = getChatModelProvider();
- const chatModel = getChatModel();
+ const provider =
+ searchParams.get('chatModelProvider') || Object.keys(models)[0];
+ const chatModel =
+ searchParams.get('chatModel') || Object.keys(models[provider])[0];
let llm: BaseChatModel | undefined;
let embeddings: Embeddings | undefined;
diff --git a/src/websocket/websocketServer.ts b/src/websocket/websocketServer.ts
index bc84f527..3ab0b519 100644
--- a/src/websocket/websocketServer.ts
+++ b/src/websocket/websocketServer.ts
@@ -10,9 +10,7 @@ export const initServer = (
const port = getPort();
const wss = new WebSocketServer({ server });
- wss.on('connection', (ws) => {
- handleConnection(ws);
- });
+ wss.on('connection', handleConnection);
logger.info(`WebSocket server started on port ${port}`);
};
diff --git a/ui/components/ChatWindow.tsx b/ui/components/ChatWindow.tsx
index 4c138ffd..68a2ba0d 100644
--- a/ui/components/ChatWindow.tsx
+++ b/ui/components/ChatWindow.tsx
@@ -19,14 +19,42 @@ const useSocket = (url: string) => {
useEffect(() => {
if (!ws) {
- const ws = new WebSocket(url);
- ws.onopen = () => {
- console.log('[DEBUG] open');
- setWs(ws);
+ const connectWs = async () => {
+ let chatModel = localStorage.getItem('chatModel');
+ let chatModelProvider = localStorage.getItem('chatModelProvider');
+
+ if (!chatModel || !chatModelProvider) {
+ const chatModelProviders = await fetch(
+ `${process.env.NEXT_PUBLIC_API_URL}/models`,
+ ).then(async (res) => (await res.json())['providers']);
+
+ if (
+ !chatModelProviders ||
+ Object.keys(chatModelProviders).length === 0
+ )
+ return console.error('No chat models available');
+
+ chatModelProvider = Object.keys(chatModelProviders)[0];
+ chatModel = Object.keys(chatModelProviders[chatModelProvider])[0];
+
+ localStorage.setItem('chatModel', chatModel!);
+ localStorage.setItem('chatModelProvider', chatModelProvider);
+ }
+
+ const ws = new WebSocket(
+ `${url}?chatModel=${chatModel}&chatModelProvider=${chatModelProvider}`,
+ );
+ ws.onopen = () => {
+ console.log('[DEBUG] open');
+ setWs(ws);
+ };
};
+
+ connectWs();
}
return () => {
+ 1;
ws?.close();
console.log('[DEBUG] closed');
};
diff --git a/ui/components/SearchImages.tsx b/ui/components/SearchImages.tsx
index 137571c4..aa70c96d 100644
--- a/ui/components/SearchImages.tsx
+++ b/ui/components/SearchImages.tsx
@@ -29,6 +29,10 @@ const SearchImages = ({
)}
- {config.selectedProvider && (
+ {selectedChatModelProvider && (
Chat Model