Skip to content

Commit

Permalink
feat: add the ability to set the default model through an env config (#70)
Browse files Browse the repository at this point in the history
* feat: use zod enums for models list

* feat: use env to set default LLM models

* feat: make models input optional
  • Loading branch information
samhwang authored Apr 29, 2024
1 parent f300def commit f61823d
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 20 deletions.
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# AI SERVER
# AI SERVER CONFIGS
AI_SERVER_URL="http://localhost:4000"
DEFAULT_MODEL="tinydolphin"

# BOT CONFIGS
CLIENT_ID=CLIENT_ID_GOES_HERE
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ dist

# Env
.env*
!.env.sample
!.env.example
18 changes: 13 additions & 5 deletions src/autocompletes/select-model/autocomplete.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import { SUPPORTED_MODELS_MAP } from '../../llm/utils';
import type { AutocompleteHandler } from '../builder';

// Discord caps autocomplete responses at 25 choices, so every respond()
// payload must be truncated to that limit.
const MAX_AUTOCOMPLETE_CHOICES = 25;

// Returns at most the first 25 entries of a model choice list.
function getFirst25Options(models: typeof SUPPORTED_MODELS_MAP) {
  return models.slice(0, MAX_AUTOCOMPLETE_CHOICES);
}

/**
 * Autocomplete handler for the `model` option.
 *
 * With no (or empty) input it offers the first 25 supported models;
 * otherwise it offers the models whose name contains the trimmed,
 * lowercased search term.
 */
export const selectModelAutocomplete: AutocompleteHandler = async (interaction) => {
  let searchTerm = interaction.options.getString('model', false);
  if (!searchTerm) {
    // The model option is optional now — fall back to the full (truncated) list.
    await interaction.respond(getFirst25Options(SUPPORTED_MODELS_MAP));
    return;
  }

  searchTerm = searchTerm.trim().toLowerCase();
  const filtered = SUPPORTED_MODELS_MAP.filter((model) => model.name.includes(searchTerm));
  // Await so a failed respond() surfaces here instead of as an unhandled rejection.
  await interaction.respond(getFirst25Options(filtered));
};
17 changes: 9 additions & 8 deletions src/llm/utils.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import { Result } from 'oxide.ts';
import { z } from 'zod';
import { logger } from '../utils/logger';
import { getClient } from './client';

const DISCORD_MESSAGE_MAX_CHARACTERS = 2000;
const QUESTION_CUT_OFF_LENGTH = 150;
const RESERVED_LENGTH = 50; // for other additional strings. E.g. number `(1/4)`, `Q: `, `A: `, etc.

const SUPPORTED_MODELS = ['gpt-3.5-turbo', 'gpt-4', 'phi', 'phi3', 'tinydolphin', 'mistral', 'mixtral', 'llama3', 'llama3-70b'] as const;
type SupportedModel = (typeof SUPPORTED_MODELS)[number];
export const SUPPORTED_MODELS_MAP = SUPPORTED_MODELS.map((model) => ({
export const SupportedModel = z.enum(['gpt-3.5-turbo', 'gpt-4', 'phi', 'phi3', 'tinydolphin', 'mistral', 'mixtral', 'llama3', 'llama3-70b']);
export type SupportedModel = z.infer<typeof SupportedModel>;
export const SUPPORTED_MODELS_MAP = SupportedModel.options.map((model) => ({
name: model,
value: model,
}));
Expand Down Expand Up @@ -84,10 +85,10 @@ function addNumber(chunks: string[]): string[] {
return output;
}

export function findModel(model: string): SupportedModel {
const hasModel = SUPPORTED_MODELS.find((option) => option.toLowerCase() === model);
if (!hasModel) {
throw new Error(`Model ${model} is not supported. Supported models are: ${SUPPORTED_MODELS.join(', ')}`);
export function findModel(input: string): SupportedModel {
const model = SupportedModel.safeParse(input);
if (!model.success) {
throw new Error(`Model ${input} is not supported. Supported models are: ${SupportedModel.options.join(', ')}`);
}
return model as SupportedModel;
return model.data;
}
4 changes: 2 additions & 2 deletions src/slash-commands/ask-llm/command.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ import type { SlashCommand, SlashCommandHandler } from '../builder';
// Slash-command definition for /ask. The model option is optional and
// autocompleted; when omitted, the handler falls back to the DEFAULT_MODEL
// env config.
const data = new SlashCommandBuilder()
  .setName('ask')
  .setDescription('Ask an LLM to answer anything')
  .addStringOption((option) => option.setName('model').setDescription('Choose an LLM model').setRequired(false).setAutocomplete(true))
  .addStringOption((option) => option.setName('question').setDescription('Enter your prompt').setRequired(true).setMinLength(10));

export const execute: SlashCommandHandler = async (interaction) => {
await interaction.deferReply();
const model = interaction.options.getString('model', true).trim().toLowerCase();
const model = (interaction.options.getString('model', false) || process.env.DEFAULT_MODEL).trim().toLowerCase();
const question = interaction.options.getString('question', true).trim();
logger.info(`[ask]: Asking ${model} model with prompt: ${question}`);

Expand Down
4 changes: 2 additions & 2 deletions src/slash-commands/review-resume/command.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ import { getOutputFileName } from './utils';
// Slash-command definition for /review-resume. The model option is optional
// and autocompleted; when omitted, the handler falls back to the
// DEFAULT_MODEL env config.
const data = new SlashCommandBuilder()
  .setName('review-resume')
  .setDescription('Review a resume from a generic PDF URL or Google Drive')
  .addStringOption((option) => option.setName('model').setDescription('Choose an LLM model').setRequired(false).setAutocomplete(true))
  .addStringOption((option) => option.setName('url').setDescription('PDF or Google Drive URL').setRequired(true));

export const execute: SlashCommandHandler = async (interaction) => {
await interaction.deferReply();
const model = interaction.options.getString('model', true).trim().toLowerCase();
const model = (interaction.options.getString('model', false) || process.env.DEFAULT_MODEL).trim().toLowerCase();
const url = interaction.options.getString('url', true);
logger.info(`[review-resume]: Reviewing resume from URL: ${url}`);

Expand Down
4 changes: 3 additions & 1 deletion src/utils/load-env.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dotenv from 'dotenv';
import { z } from 'zod';
import { SupportedModel } from '../llm/utils';
import { logger } from './logger';

const configSchema = z.object({
Expand All @@ -10,8 +11,9 @@ const configSchema = z.object({
CLIENT_ID: z.string(),
GUILD_ID: z.string().optional(),

// AI Server URL
// AI Server config
AI_SERVER_URL: z.string().url(),
DEFAULT_MODEL: SupportedModel,
});
type ConfigSchema = z.infer<typeof configSchema>;

Expand Down

0 comments on commit f61823d

Please sign in to comment.