diff --git a/package-lock.json b/package-lock.json index ab717191..b4be8594 100644 --- a/package-lock.json +++ b/package-lock.json @@ -17,6 +17,7 @@ "@clack/prompts": "^0.6.1", "@dqbd/tiktoken": "^1.0.2", "@google/generative-ai": "^0.11.4", + "@mistralai/mistralai": "^1.3.5", "@octokit/webhooks-schemas": "^6.11.0", "@octokit/webhooks-types": "^6.11.0", "axios": "^1.3.4", @@ -28,7 +29,8 @@ "ini": "^3.0.1", "inquirer": "^9.1.4", "openai": "^4.57.0", - "punycode": "^2.3.1" + "punycode": "^2.3.1", + "zod": "^3.23.8" }, "bin": { "oco": "out/cli.cjs", @@ -1785,6 +1787,16 @@ "node": ">= 0.4" } }, + "node_modules/@mistralai/mistralai": { + "version": "1.3.5", + "resolved": "https://registry.npmjs.org/@mistralai/mistralai/-/mistralai-1.3.5.tgz", + "integrity": "sha512-yC91oJ5ScEPqbXmv3mJTwTFgu/ZtsYoOPOhaVXSsy6x4zXTqTI57yEC1flC9uiA8GpG/yhpn2BBUXF95+U9Blw==", + "peerDependencies": { + "react": "^18 || ^19", + "react-dom": "^18 || ^19", + "zod": ">= 3" + } + }, "node_modules/@nodelib/fs.scandir": { "version": "2.1.5", "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", @@ -6477,8 +6489,7 @@ "node_modules/js-tokens": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", - "integrity": "sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==", - "dev": true + "integrity": "sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==" }, "node_modules/js-yaml": { "version": "4.1.0", @@ -6678,6 +6689,19 @@ "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==" }, + "node_modules/loose-envify": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz", + "integrity": "sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==", + "license": "MIT", + "peer": true, + "dependencies": { + "js-tokens": "^3.0.0 || ^4.0.0" + }, + "bin": { + "loose-envify": "cli.js" + } + }, "node_modules/lru-cache": { "version": "5.1.1", "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-5.1.1.tgz", @@ -7407,6 +7431,33 @@ } ] }, + "node_modules/react": { + "version": "18.3.1", + "resolved": "https://registry.npmjs.org/react/-/react-18.3.1.tgz", + "integrity": "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==", + "license": "MIT", + "peer": true, + "dependencies": { + "loose-envify": "^1.1.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/react-dom": { + "version": "18.3.1", + "resolved": "https://registry.npmjs.org/react-dom/-/react-dom-18.3.1.tgz", + "integrity": "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==", + "license": "MIT", + "peer": true, + "dependencies": { + "loose-envify": "^1.1.0", + "scheduler": "^0.23.2" + }, + "peerDependencies": { + "react": "^18.3.1" + } + }, "node_modules/react-is": { "version": "18.2.0", "resolved": "https://registry.npmjs.org/react-is/-/react-is-18.2.0.tgz", @@ -7632,6 +7683,16 @@ "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==" }, + "node_modules/scheduler": { + "version": "0.23.2", + "resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.23.2.tgz", + "integrity": "sha512-UOShsPwz7NrMUqhR6t0hWjFduvOzbtv7toDH1/hIrfRNIDBnnBWd0CwJTGvTpngVlmwGCdP9/Zl/tVrDqcuYzQ==", + "license": "MIT", + "peer": true, + "dependencies": { + "loose-envify": "^1.1.0" + } + }, "node_modules/semver": { "version": "7.6.0", "resolved": "https://registry.npmjs.org/semver/-/semver-7.6.0.tgz", @@ -8446,6 +8507,15 @@ "funding": { "url": "https://github.com/sponsors/sindresorhus" } + }, + "node_modules/zod": { + "version": "3.23.8", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.23.8.tgz", + "integrity": "sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } } } } diff --git a/package.json b/package.json index 8557e7fe..9942732f 100644 --- a/package.json +++ b/package.json @@ -88,6 +88,7 @@ "@clack/prompts": "^0.6.1", "@dqbd/tiktoken": "^1.0.2", "@google/generative-ai": "^0.11.4", + "@mistralai/mistralai": "^1.3.5", "@octokit/webhooks-schemas": "^6.11.0", "@octokit/webhooks-types": "^6.11.0", "axios": "^1.3.4", @@ -99,6 +100,7 @@ "ini": "^3.0.1", "inquirer": "^9.1.4", "openai": "^4.57.0", - "punycode": "^2.3.1" + "punycode": "^2.3.1", + "zod": "^3.23.8" } } diff --git a/src/commands/config.ts b/src/commands/config.ts index eccb799e..5381a8d1 100644 --- a/src/commands/config.ts +++ b/src/commands/config.ts @@ -86,6 +86,48 @@ export const MODEL_LIST = { 'llama-3.1-70b-versatile', // Llama 3.1 70B (Preview) 'gemma-7b-it', // Gemma 7B 'gemma2-9b-it' // Gemma 2 9B + ], + + mistral: [ + 'ministral-3b-2410', + 'ministral-3b-latest', + 'ministral-8b-2410', + 'ministral-8b-latest', + 'open-mistral-7b', + 'mistral-tiny', + 'mistral-tiny-2312', + 'open-mistral-nemo', + 'open-mistral-nemo-2407', + 'mistral-tiny-2407', + 'mistral-tiny-latest', + 'open-mixtral-8x7b', + 'mistral-small', + 'mistral-small-2312', + 'open-mixtral-8x22b', + 'open-mixtral-8x22b-2404', + 'mistral-small-2402', + 'mistral-small-2409', + 'mistral-small-latest', + 'mistral-medium-2312', + 'mistral-medium', + 'mistral-medium-latest', + 'mistral-large-2402', + 'mistral-large-2407', + 'mistral-large-2411', + 'mistral-large-latest', + 'pixtral-large-2411', + 'pixtral-large-latest', + 'codestral-2405', + 'codestral-latest', + 'codestral-mamba-2407', + 'open-codestral-mamba', + 'codestral-mamba-latest', + 'pixtral-12b-2409', + 'pixtral-12b', + 'pixtral-12b-latest', + 'mistral-embed', + 'mistral-moderation-2411', + 'mistral-moderation-latest', ] }; @@ -101,6 +143,8 @@ const getDefaultModel = (provider: string | undefined): string => { return MODEL_LIST.gemini[0]; case 'groq': return MODEL_LIST.groq[0]; + case 'mistral': + return MODEL_LIST.mistral[0]; default: return MODEL_LIST.openai[0]; } @@ -257,14 +301,15 @@ export const configValidators = { CONFIG_KEYS.OCO_AI_PROVIDER, [ 'openai', + 'mistral', 'anthropic', 'gemini', 'azure', 'test', 'flowise', 'groq' - ].includes(value) || value.startsWith('ollama') || value.startsWith('mlx'), - `${value} is not supported yet, use 'ollama', 'mlx', anthropic', 'azure', 'gemini', 'flowise' or 'openai' (default)` + ].includes(value) || value.startsWith('ollama'), + `${value} is not supported yet, use 'ollama', 'mlx', 'anthropic', 'azure', 'gemini', 'flowise', 'mistral' or 'openai' (default)` ); return value; @@ -310,6 +355,7 @@ export enum OCO_AI_PROVIDER_ENUM { TEST = 'test', FLOWISE = 'flowise', GROQ = 'groq', + MISTRAL = 'mistral', MLX = 'mlx' } diff --git a/src/engine/Engine.ts b/src/engine/Engine.ts index 6a6f0238..19562271 100644 --- a/src/engine/Engine.ts +++ b/src/engine/Engine.ts @@ -3,6 +3,7 @@ import { OpenAIClient as AzureOpenAIClient } from '@azure/openai'; import { GoogleGenerativeAI as GeminiClient } from '@google/generative-ai'; import { AxiosInstance as RawAxiosClient } from 'axios'; import { OpenAI as OpenAIClient } from 'openai'; +import { Mistral as MistralClient } from '@mistralai/mistralai'; export interface AiEngineConfig { apiKey: string; @@ -17,7 +18,8 @@ type Client = | AzureOpenAIClient | AnthropicClient | RawAxiosClient - | GeminiClient; + | GeminiClient + | MistralClient; export interface AiEngine { config: AiEngineConfig; diff --git a/src/engine/mistral.ts b/src/engine/mistral.ts new file mode 100644 index 00000000..ce480f2e --- /dev/null +++ b/src/engine/mistral.ts @@ -0,0 +1,82 @@ +import axios from 'axios'; +import { Mistral } from '@mistralai/mistralai'; +import { OpenAI } from 'openai'; +import { GenerateCommitMessageErrorEnum } from '../generateCommitMessageFromGitDiff'; +import { tokenCount } from '../utils/tokenCount'; +import { AiEngine, AiEngineConfig } from './Engine'; +import { + AssistantMessage as MistralAssistantMessage, + SystemMessage as MistralSystemMessage, + ToolMessage as MistralToolMessage, + UserMessage as MistralUserMessage +} from '@mistralai/mistralai/models/components'; + +export interface MistralAiConfig extends AiEngineConfig {} +export type MistralCompletionMessageParam = Array< +| (MistralSystemMessage & { role: "system" }) +| (MistralUserMessage & { role: "user" }) +| (MistralAssistantMessage & { role: "assistant" }) +| (MistralToolMessage & { role: "tool" }) +> + +export class MistralAiEngine implements AiEngine { + config: MistralAiConfig; + client: Mistral; + + constructor(config: MistralAiConfig) { + this.config = config; + + if (!config.baseURL) { + this.client = new Mistral({ apiKey: config.apiKey }); + } else { + this.client = new Mistral({ apiKey: config.apiKey, serverURL: config.baseURL }); + } + } + + public generateCommitMessage = async ( + messages: Array + ): Promise => { + const params = { + model: this.config.model, + messages: messages as MistralCompletionMessageParam, + topP: 0.1, + maxTokens: this.config.maxTokensOutput + }; + + try { + const REQUEST_TOKENS = messages + .map((msg) => tokenCount(msg.content as string) + 4) + .reduce((a, b) => a + b, 0); + + if ( + REQUEST_TOKENS > + this.config.maxTokensInput - this.config.maxTokensOutput + ) + throw new Error(GenerateCommitMessageErrorEnum.tooMuchTokens); + + const completion = await this.client.chat.complete(params); + + if (!completion.choices) + throw Error('No completion choice available.') + + const message = completion.choices[0].message; + + if (!message || !message.content) + throw Error('No completion choice available.') + + return message.content as string; + } catch (error) { + const err = error as Error; + if ( + axios.isAxiosError<{ error?: { message: string } }>(error) && + error.response?.status === 401 + ) { + const mistralError = error.response.data.error; + + if (mistralError) throw new Error(mistralError.message); + } + + throw err; + } + }; +} diff --git a/src/utils/engine.ts b/src/utils/engine.ts index eba0d20d..481a9f9b 100644 --- a/src/utils/engine.ts +++ b/src/utils/engine.ts @@ -6,6 +6,7 @@ import { FlowiseEngine } from '../engine/flowise'; import { GeminiEngine } from '../engine/gemini'; import { OllamaEngine } from '../engine/ollama'; import { OpenAiEngine } from '../engine/openAi'; +import { MistralAiEngine } from '../engine/mistral'; import { TestAi, TestMockType } from '../engine/testAi'; import { GroqEngine } from '../engine/groq'; import { MLXEngine } from '../engine/mlx'; @@ -44,6 +45,9 @@ export function getEngine(): AiEngine { case OCO_AI_PROVIDER_ENUM.GROQ: return new GroqEngine(DEFAULT_CONFIG); + case OCO_AI_PROVIDER_ENUM.MISTRAL: + return new MistralAiEngine(DEFAULT_CONFIG); + case OCO_AI_PROVIDER_ENUM.MLX: return new MLXEngine(DEFAULT_CONFIG);