feat(chatModels): load model from localstorage
ItzCrazyKns committed May 2, 2024
1 parent ed9ff3c commit f618b71
Showing 16 changed files with 126 additions and 81 deletions.
10 changes: 4 additions & 6 deletions README.md
@@ -59,13 +59,11 @@ There are mainly 2 ways of installing Perplexica - With Docker, Without Docker.

4. Rename the `sample.config.toml` file to `config.toml`. For Docker setups, you need only fill in the following fields:

- `CHAT_MODEL`: The name of the LLM to use. Like `llama3:latest` (using Ollama), `gpt-3.5-turbo` (using OpenAI), etc.
- `CHAT_MODEL_PROVIDER`: The chat model provider, either `openai` or `ollama`. Depending upon which provider you use you would have to fill in the following fields:
- `OPENAI`: Your OpenAI API key. **You only need to fill this if you wish to use OpenAI's models**.
- `OLLAMA`: Your Ollama API URL. You should enter it as `http://host.docker.internal:PORT_NUMBER`. If you installed Ollama on port 11434, use `http://host.docker.internal:11434`. For other ports, adjust accordingly. **You need to fill this if you wish to use Ollama's models instead of OpenAI's**.
- `GROQ`: Your Groq API key. **You only need to fill this if you wish to use Groq's hosted models**

- `OPENAI`: Your OpenAI API key. **You only need to fill this if you wish to use OpenAI's models**.
- `OLLAMA`: Your Ollama API URL. You should enter it as `http://host.docker.internal:PORT_NUMBER`. If you installed Ollama on port 11434, use `http://host.docker.internal:11434`. For other ports, adjust accordingly. **You need to fill this if you wish to use Ollama's models instead of OpenAI's**.

**Note**: You can change these and use different models after running Perplexica as well from the settings page.
**Note**: You can change these after starting Perplexica from the settings dialog.

- `SIMILARITY_MEASURE`: The similarity measure to use (This is filled by default; you can leave it as is if you are unsure about it.)

2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
{
"name": "perplexica-backend",
"version": "1.0.0",
"version": "1.1.0",
"license": "MIT",
"author": "ItzCrazyKns",
"scripts": {
2 changes: 0 additions & 2 deletions sample.config.toml
@@ -1,8 +1,6 @@
[GENERAL]
PORT = 3001 # Port to run the server on
SIMILARITY_MEASURE = "cosine" # "cosine" or "dot"
CHAT_MODEL_PROVIDER = "openai" # "openai" or "ollama" or "groq"
CHAT_MODEL = "gpt-3.5-turbo" # Name of the model to use

[API_KEYS]
OPENAI = "" # OpenAI API key - sk-1234567890abcdef1234567890abcdef
17 changes: 4 additions & 13 deletions src/config.ts
@@ -8,8 +8,6 @@ interface Config {
GENERAL: {
PORT: number;
SIMILARITY_MEASURE: string;
CHAT_MODEL_PROVIDER: string;
CHAT_MODEL: string;
};
API_KEYS: {
OPENAI: string;
@@ -35,11 +33,6 @@ export const getPort = () => loadConfig().GENERAL.PORT;
export const getSimilarityMeasure = () =>
loadConfig().GENERAL.SIMILARITY_MEASURE;

export const getChatModelProvider = () =>
loadConfig().GENERAL.CHAT_MODEL_PROVIDER;

export const getChatModel = () => loadConfig().GENERAL.CHAT_MODEL;

export const getOpenaiApiKey = () => loadConfig().API_KEYS.OPENAI;

export const getGroqApiKey = () => loadConfig().API_KEYS.GROQ;
@@ -52,21 +45,19 @@ export const updateConfig = (config: RecursivePartial<Config>) => {
const currentConfig = loadConfig();

for (const key in currentConfig) {
/* if (currentConfig[key] && !config[key]) {
config[key] = currentConfig[key];
} */
if (!config[key]) config[key] = {};

if (currentConfig[key] && typeof currentConfig[key] === 'object') {
if (typeof currentConfig[key] === 'object' && currentConfig[key] !== null) {
for (const nestedKey in currentConfig[key]) {
if (
currentConfig[key][nestedKey] &&
!config[key][nestedKey] &&
currentConfig[key][nestedKey] &&
config[key][nestedKey] !== ''
) {
config[key][nestedKey] = currentConfig[key][nestedKey];
}
}
} else if (currentConfig[key] && !config[key] && config[key] !== '') {
} else if (currentConfig[key] && config[key] !== '') {
config[key] = currentConfig[key];
}
}
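The rewritten merge treats an explicit empty string as a deliberate clear while backfilling fields the client omitted. A minimal sketch of that behaviour with hypothetical values (not the route's actual payload):

```ts
// Hypothetical illustration of the merge semantics in updateConfig:
// an explicit empty string survives, an omitted field is backfilled.
const currentConfig: Record<string, Record<string, string>> = {
  API_KEYS: { OPENAI: 'sk-old', GROQ: 'gsk-old' },
};
const incoming: Record<string, Record<string, string>> = {
  API_KEYS: { OPENAI: '' }, // user cleared OPENAI, omitted GROQ
};

for (const key in currentConfig) {
  if (!incoming[key]) incoming[key] = {};
  for (const nestedKey in currentConfig[key]) {
    if (
      !incoming[key][nestedKey] &&
      currentConfig[key][nestedKey] &&
      incoming[key][nestedKey] !== ''
    ) {
      incoming[key][nestedKey] = currentConfig[key][nestedKey];
    }
  }
}

console.log(incoming.API_KEYS.OPENAI); // '' — explicit clear preserved
console.log(incoming.API_KEYS.GROQ); // 'gsk-old' — omitted field backfilled
```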
9 changes: 0 additions & 9 deletions src/routes/config.ts
@@ -1,8 +1,6 @@
import express from 'express';
import { getAvailableProviders } from '../lib/providers';
import {
getChatModel,
getChatModelProvider,
getGroqApiKey,
getOllamaApiEndpoint,
getOpenaiApiKey,
@@ -26,9 +24,6 @@ router.get('/', async (_, res) => {
config['providers'][provider] = Object.keys(providers[provider]);
}

config['selectedProvider'] = getChatModelProvider();
config['selectedChatModel'] = getChatModel();

config['openeaiApiKey'] = getOpenaiApiKey();
config['ollamaApiUrl'] = getOllamaApiEndpoint();
config['groqApiKey'] = getGroqApiKey();
@@ -40,10 +35,6 @@ router.post('/', async (req, res) => {
const config = req.body;

const updatedConfig = {
GENERAL: {
CHAT_MODEL_PROVIDER: config.selectedProvider,
CHAT_MODEL: config.selectedChatModel,
},
API_KEYS: {
OPENAI: config.openeaiApiKey,
GROQ: config.groqApiKey,
13 changes: 6 additions & 7 deletions src/routes/images.ts
@@ -2,15 +2,14 @@ import express from 'express';
import handleImageSearch from '../agents/imageSearchAgent';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { getAvailableProviders } from '../lib/providers';
import { getChatModel, getChatModelProvider } from '../config';
import { HumanMessage, AIMessage } from '@langchain/core/messages';
import logger from '../utils/logger';

const router = express.Router();

router.post('/', async (req, res) => {
try {
let { query, chat_history } = req.body;
let { query, chat_history, chat_model_provider, chat_model } = req.body;

chat_history = chat_history.map((msg: any) => {
if (msg.role === 'user') {
@@ -20,14 +19,14 @@ router.post('/', async (req, res) => {
}
});

const models = await getAvailableProviders();
const provider = getChatModelProvider();
const chatModel = getChatModel();
const chatModels = await getAvailableProviders();
const provider = chat_model_provider || Object.keys(chatModels)[0];
const chatModel = chat_model || Object.keys(chatModels[provider])[0];

let llm: BaseChatModel | undefined;

if (models[provider] && models[provider][chatModel]) {
llm = models[provider][chatModel] as BaseChatModel | undefined;
if (chatModels[provider] && chatModels[provider][chatModel]) {
llm = chatModels[provider][chatModel] as BaseChatModel | undefined;
}

if (!llm) {
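The route now takes the model selection from the request body and falls back to the first available provider and model when those fields are missing. A hedged sketch of a matching client call — the query and model names are placeholders, and `NEXT_PUBLIC_API_URL` is the same base URL the UI components use:

```ts
// Hypothetical request against the images route; chat_model_provider and
// chat_model may be omitted to accept the server-side defaults.
const res = await fetch(`${process.env.NEXT_PUBLIC_API_URL}/images`, {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    query: 'aurora borealis', // placeholder
    chat_history: [],
    chat_model_provider: 'ollama', // placeholder
    chat_model: 'llama3:latest', // placeholder
  }),
});
const data = await res.json();
```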
2 changes: 2 additions & 0 deletions src/routes/index.ts
@@ -2,11 +2,13 @@ import express from 'express';
import imagesRouter from './images';
import videosRouter from './videos';
import configRouter from './config';
import modelsRouter from './models';

const router = express.Router();

router.use('/images', imagesRouter);
router.use('/videos', videosRouter);
router.use('/config', configRouter);
router.use('/models', modelsRouter);

export default router;
18 changes: 18 additions & 0 deletions src/routes/models.ts
@@ -0,0 +1,18 @@
import express from 'express';
import logger from '../utils/logger';
import { getAvailableProviders } from '../lib/providers';

const router = express.Router();

router.get('/', async (req, res) => {
try {
const providers = await getAvailableProviders();

res.status(200).json({ providers });
} catch (err) {
res.status(500).json({ message: 'An error has occurred.' });
logger.error(err.message);
}
});

export default router;
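The new route exposes the available providers and their models so the frontend can pick defaults before opening a socket. A short sketch of the consuming side, mirroring the lookup `ChatWindow.tsx` performs below:

```ts
// Fetch the provider → models map from the new /models endpoint and pick
// the first provider and its first model as defaults.
const providers = await fetch(
  `${process.env.NEXT_PUBLIC_API_URL}/models`,
).then(async (res) => (await res.json())['providers']);

const defaultProvider = Object.keys(providers)[0];
const defaultModel = Object.keys(providers[defaultProvider])[0];
```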
13 changes: 6 additions & 7 deletions src/routes/videos.ts
@@ -1,7 +1,6 @@
import express from 'express';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { getAvailableProviders } from '../lib/providers';
import { getChatModel, getChatModelProvider } from '../config';
import { HumanMessage, AIMessage } from '@langchain/core/messages';
import logger from '../utils/logger';
import handleVideoSearch from '../agents/videoSearchAgent';
@@ -10,7 +9,7 @@ const router = express.Router();

router.post('/', async (req, res) => {
try {
let { query, chat_history } = req.body;
let { query, chat_history, chat_model_provider, chat_model } = req.body;

chat_history = chat_history.map((msg: any) => {
if (msg.role === 'user') {
@@ -20,14 +19,14 @@ router.post('/', async (req, res) => {
}
});

const models = await getAvailableProviders();
const provider = getChatModelProvider();
const chatModel = getChatModel();
const chatModels = await getAvailableProviders();
const provider = chat_model_provider || Object.keys(chatModels)[0];
const chatModel = chat_model || Object.keys(chatModels[provider])[0];

let llm: BaseChatModel | undefined;

if (models[provider] && models[provider][chatModel]) {
llm = models[provider][chatModel] as BaseChatModel | undefined;
if (chatModels[provider] && chatModels[provider][chatModel]) {
llm = chatModels[provider][chatModel] as BaseChatModel | undefined;
}

if (!llm) {
16 changes: 12 additions & 4 deletions src/websocket/connectionManager.ts
@@ -1,15 +1,23 @@
import { WebSocket } from 'ws';
import { handleMessage } from './messageHandler';
import { getChatModel, getChatModelProvider } from '../config';
import { getAvailableProviders } from '../lib/providers';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { Embeddings } from '@langchain/core/embeddings';
import type { IncomingMessage } from 'http';
import logger from '../utils/logger';

export const handleConnection = async (ws: WebSocket) => {
export const handleConnection = async (
ws: WebSocket,
request: IncomingMessage,
) => {
const searchParams = new URL(request.url, `http://${request.headers.host}`)
.searchParams;

const models = await getAvailableProviders();
const provider = getChatModelProvider();
const chatModel = getChatModel();
const provider =
searchParams.get('chatModelProvider') || Object.keys(models)[0];
const chatModel =
searchParams.get('chatModel') || Object.keys(models[provider])[0];

let llm: BaseChatModel | undefined;
let embeddings: Embeddings | undefined;
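With the handler reading `request.url`, the model selection travels as query parameters on the WebSocket URL, and missing parameters fall back to the first available provider and model. A hedged client-side sketch (host, port, and model names are illustrative; the sample config uses port 3001):

```ts
// Hypothetical client connection using the new query parameters.
const ws = new WebSocket(
  'ws://localhost:3001?chatModel=llama3:latest&chatModelProvider=ollama',
);
```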
4 changes: 1 addition & 3 deletions src/websocket/websocketServer.ts
@@ -10,9 +10,7 @@ export const initServer = (
const port = getPort();
const wss = new WebSocketServer({ server });

wss.on('connection', (ws) => {
handleConnection(ws);
});
wss.on('connection', handleConnection);

logger.info(`WebSocket server started on port ${port}`);
};
36 changes: 32 additions & 4 deletions ui/components/ChatWindow.tsx
@@ -19,14 +19,42 @@ const useSocket = (url: string) => {

useEffect(() => {
if (!ws) {
const ws = new WebSocket(url);
ws.onopen = () => {
console.log('[DEBUG] open');
setWs(ws);
const connectWs = async () => {
let chatModel = localStorage.getItem('chatModel');
let chatModelProvider = localStorage.getItem('chatModelProvider');

if (!chatModel || !chatModelProvider) {
const chatModelProviders = await fetch(
`${process.env.NEXT_PUBLIC_API_URL}/models`,
).then(async (res) => (await res.json())['providers']);

if (
!chatModelProviders ||
Object.keys(chatModelProviders).length === 0
)
return console.error('No chat models available');

chatModelProvider = Object.keys(chatModelProviders)[0];
chatModel = Object.keys(chatModelProviders[chatModelProvider])[0];

localStorage.setItem('chatModel', chatModel!);
localStorage.setItem('chatModelProvider', chatModelProvider);
}

const ws = new WebSocket(
`${url}?chatModel=${chatModel}&chatModelProvider=${chatModelProvider}`,
);
ws.onopen = () => {
console.log('[DEBUG] open');
setWs(ws);
};
};

connectWs();
}

return () => {
ws?.close();
console.log('[DEBUG] closed');
};
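Since the socket URL is built from `localStorage`, any code that writes the `chatModel` and `chatModelProvider` keys (a settings dialog, for instance) controls which model the next connection uses. A hypothetical helper using the key names from this diff:

```ts
// Hypothetical settings handler: persist the selection so the next
// WebSocket connection and the image/video routes pick it up.
const saveChatModelSelection = (provider: string, model: string) => {
  localStorage.setItem('chatModelProvider', provider);
  localStorage.setItem('chatModel', model);
};

saveChatModelSelection('openai', 'gpt-3.5-turbo'); // placeholder values
```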
6 changes: 6 additions & 0 deletions ui/components/SearchImages.tsx
@@ -29,6 +29,10 @@ const SearchImages = ({
<button
onClick={async () => {
setLoading(true);

const chatModelProvider = localStorage.getItem('chatModelProvider');
const chatModel = localStorage.getItem('chatModel');

const res = await fetch(
`${process.env.NEXT_PUBLIC_API_URL}/images`,
{
@@ -39,6 +43,8 @@ body: JSON.stringify({
body: JSON.stringify({
query: query,
chat_history: chat_history,
chat_model_provider: chatModelProvider,
chat_model: chatModel,
}),
},
);
6 changes: 6 additions & 0 deletions ui/components/SearchVideos.tsx
@@ -42,6 +42,10 @@ const Searchvideos = ({
<button
onClick={async () => {
setLoading(true);

const chatModelProvider = localStorage.getItem('chatModelProvider');
const chatModel = localStorage.getItem('chatModel');

const res = await fetch(
`${process.env.NEXT_PUBLIC_API_URL}/videos`,
{
@@ -52,6 +56,8 @@ body: JSON.stringify({
body: JSON.stringify({
query: query,
chat_history: chat_history,
chat_model_provider: chatModelProvider,
chat_model: chatModel,
}),
},
);
