Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for users to specify custom request settings, model-specific and optionally provider-specific #14535

Merged
merged 12 commits into from
Nov 28, 2024
Merged
41 changes: 40 additions & 1 deletion packages/ai-core/src/browser/ai-core-preferences.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { interfaces } from '@theia/core/shared/inversify';
export const AI_CORE_PREFERENCES_TITLE = '✨ AI Features [Experimental]';
export const PREFERENCE_NAME_ENABLE_EXPERIMENTAL = 'ai-features.AiEnable.enableAI';
export const PREFERENCE_NAME_PROMPT_TEMPLATES = 'ai-features.promptTemplates.promptTemplatesFolder';
export const PREFERENCE_NAME_REQUEST_SETTINGS = 'ai-features.modelSettings.requestSettings';

export const aiCorePreferenceSchema: PreferenceSchema = {
type: 'object',
Expand Down Expand Up @@ -55,13 +56,51 @@ export const aiCorePreferenceSchema: PreferenceSchema = {
canSelectMany: false
}
},

},
[PREFERENCE_NAME_REQUEST_SETTINGS]: {
title: 'Custom Request Settings',
markdownDescription: 'Allows specifying custom request settings for multiple models.\n\
Each object represents the configuration for a specific model. The `modelId` field specifies the model ID, `requestSettings` defines model-specific settings.\n\
The `providerId` field is optional and allows you to apply the settings to a specific provider. If not set, the settings will be applied to all providers.\n\
Example providerIds: huggingface, openai, ollama, llamafile.\n\
Refer to [our documentation](https://theia-ide.org/docs/user_ai/#custom-request-settings) for more information.',
type: 'array',
items: {
type: 'object',
properties: {
modelId: {
type: 'string',
description: 'The model id'
},
requestSettings: {
type: 'object',
additionalProperties: true,
description: 'Settings for the specific model ID.',
},
providerId: {
type: 'string',
description: 'The (optional) provider id to apply the settings to. If not set, the settings will be applied to all providers.',
},
},
},
default: [],
}
}
};
/**
 * Typed view of the AI-core preference values, keyed by the preference-name constants.
 */
export interface AICoreConfiguration {
    [PREFERENCE_NAME_ENABLE_EXPERIMENTAL]: boolean | undefined;
    [PREFERENCE_NAME_PROMPT_TEMPLATES]: string | undefined;
    [PREFERENCE_NAME_REQUEST_SETTINGS]: {
        modelId: string;
        requestSettings?: Record<string, unknown>;
        providerId?: string;
    }[] | undefined;
}

/**
 * A single user-configured request-settings entry from the
 * `ai-features.modelSettings.requestSettings` preference.
 */
export interface RequestSetting {
    /** Id of the model the settings apply to. */
    modelId: string;
    /** Model-specific request options, passed through to the language model. */
    requestSettings?: Record<string, unknown>;
    /** Optional provider id; when omitted, the settings apply to all providers. */
    providerId?: string;
}

export const AICorePreferences = Symbol('AICorePreferences');
Expand Down
5 changes: 5 additions & 0 deletions packages/ai-core/src/common/language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ export interface LanguageModelMetaData {
readonly family?: string;
readonly maxInputTokens?: number;
readonly maxOutputTokens?: number;
/**
* Default request settings for the language model. These settings can be set via user preferences.
* Settings in a request will override these default settings.
*/
readonly defaultRequestSettings?: { [key: string]: unknown };
}

export namespace LanguageModelMetaData {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ import { FrontendApplicationContribution, PreferenceService } from '@theia/core/
import { inject, injectable } from '@theia/core/shared/inversify';
import { HuggingFaceLanguageModelsManager, HuggingFaceModelDescription } from '../common';
import { API_KEY_PREF, MODELS_PREF } from './huggingface-preferences';
import { PREFERENCE_NAME_REQUEST_SETTINGS, RequestSetting } from '@theia/ai-core/lib/browser/ai-core-preferences';

const HUGGINGFACE_PROVIDER_ID = 'huggingface';
@injectable()
export class HuggingFaceFrontendApplicationContribution implements FrontendApplicationContribution {

Expand All @@ -36,31 +38,58 @@ export class HuggingFaceFrontendApplicationContribution implements FrontendAppli
this.manager.setApiKey(apiKey);

const models = this.preferenceService.get<string[]>(MODELS_PREF, []);
this.manager.createOrUpdateLanguageModels(...models.map(createHuggingFaceModelDescription));
const requestSettings = this.preferenceService.get<RequestSetting[]>(PREFERENCE_NAME_REQUEST_SETTINGS, []);
this.manager.createOrUpdateLanguageModels(...models.map(modelId => this.createHuggingFaceModelDescription(modelId, requestSettings)));
this.prevModels = [...models];

this.preferenceService.onPreferenceChanged(event => {
if (event.preferenceName === API_KEY_PREF) {
this.manager.setApiKey(event.newValue);
} else if (event.preferenceName === MODELS_PREF) {
const oldModels = new Set(this.prevModels);
const newModels = new Set(event.newValue as string[]);

const modelsToRemove = [...oldModels].filter(model => !newModels.has(model));
const modelsToAdd = [...newModels].filter(model => !oldModels.has(model));

this.manager.removeLanguageModels(...modelsToRemove.map(model => `huggingface/${model}`));
this.manager.createOrUpdateLanguageModels(...modelsToAdd.map(createHuggingFaceModelDescription));
this.prevModels = [...event.newValue];
this.handleModelChanges(event.newValue as string[]);
} else if (event.preferenceName === PREFERENCE_NAME_REQUEST_SETTINGS) {
this.handleRequestSettingsChanges(event.newValue as RequestSetting[]);
}
});
});
}
}

function createHuggingFaceModelDescription(modelId: string): HuggingFaceModelDescription {
return {
id: `huggingface/${modelId}`,
model: modelId
};
/**
 * Synchronizes the registered Hugging Face language models with an updated
 * model-id preference value: unregisters models that disappeared and registers
 * the newly added ones (with matching custom request settings attached).
 */
protected handleModelChanges(newModels: string[]): void {
    const previous = new Set(this.prevModels);
    const current = new Set(newModels);

    const removedIds = [...previous].filter(id => !current.has(id));
    const addedIds = [...current].filter(id => !previous.has(id));

    this.manager.removeLanguageModels(...removedIds.map(id => `${HUGGINGFACE_PROVIDER_ID}/${id}`));

    const requestSettings = this.preferenceService.get<RequestSetting[]>(PREFERENCE_NAME_REQUEST_SETTINGS, []);
    this.manager.createOrUpdateLanguageModels(
        ...addedIds.map(id => this.createHuggingFaceModelDescription(id, requestSettings))
    );

    this.prevModels = newModels;
}

/**
 * Re-registers every configured Hugging Face model so that changed custom
 * request settings take effect immediately.
 */
protected handleRequestSettingsChanges(newSettings: RequestSetting[]): void {
    const modelIds = this.preferenceService.get<string[]>(MODELS_PREF, []);
    const descriptions = modelIds.map(id => this.createHuggingFaceModelDescription(id, newSettings));
    this.manager.createOrUpdateLanguageModels(...descriptions);
}

/**
 * Builds the model description for a Hugging Face model id, attaching the
 * first request-settings entry that matches the model (entries without a
 * providerId apply to all providers, so they match as well).
 */
protected createHuggingFaceModelDescription(
    modelId: string,
    requestSettings: RequestSetting[]
): HuggingFaceModelDescription {
    const matches = requestSettings.filter(setting =>
        setting.modelId === modelId &&
        (!setting.providerId || setting.providerId === HUGGINGFACE_PROVIDER_ID)
    );
    if (matches.length > 1) {
        // Ambiguous configuration: keep deterministic behavior by taking the first entry.
        console.warn(
            `Multiple entries found for modelId "${modelId}". Using the first match and ignoring the rest.`
        );
    }
    return {
        id: `${HUGGINGFACE_PROVIDER_ID}/${modelId}`,
        model: modelId,
        defaultRequestSettings: matches[0]?.requestSettings
    };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ export interface HuggingFaceModelDescription {
* The model ID as used by the Hugging Face API.
*/
model: string;
/**
* Default request settings for the Hugging Face model.
*/
defaultRequestSettings?: { [key: string]: unknown };
}

export interface HuggingFaceLanguageModelsManager {
Expand Down
30 changes: 21 additions & 9 deletions packages/ai-hugging-face/src/node/huggingface-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,18 @@ export class HuggingFaceModel implements LanguageModel {
* @param model the model id as it is used by the Hugging Face API
* @param apiKey function to retrieve the API key for Hugging Face
*/
constructor(public readonly id: string, public model: string, public apiKey: () => string | undefined) {
}
/**
 * @param id unique identifier of this language model (e.g. `huggingface/<model>`)
 * @param model the model id as used by the Hugging Face API
 * @param apiKey function to retrieve the API key for Hugging Face
 * @param name optional display name of the model
 * @param vendor optional vendor of the model
 * @param version optional version of the model
 * @param family optional model family
 * @param maxInputTokens optional maximum number of input tokens
 * @param maxOutputTokens optional maximum number of output tokens
 * @param defaultRequestSettings optional default request settings (from user
 *   preferences); settings supplied on an individual request take precedence
 */
constructor(
public readonly id: string,
public model: string,
public apiKey: () => string | undefined,
public readonly name?: string,
public readonly vendor?: string,
public readonly version?: string,
public readonly family?: string,
public readonly maxInputTokens?: number,
public readonly maxOutputTokens?: number,
public defaultRequestSettings?: Record<string, unknown>
) { }

async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
const hfInference = this.initializeHfInference();
Expand All @@ -67,15 +77,16 @@ export class HuggingFaceModel implements LanguageModel {
}
}

protected getDefaultSettings(): Record<string, unknown> {
return {
max_new_tokens: 2024,
stop: ['<|endoftext|>', '<eos>']
};
/**
 * Resolves the effective request settings: settings given on the request win,
 * otherwise the model's default settings (user preference) apply, otherwise
 * an empty settings object is used.
 */
protected getSettings(request: LanguageModelRequest): Record<string, unknown> {
    return request.settings ?? this.defaultRequestSettings ?? {};
}

protected async handleNonStreamingRequest(hfInference: HfInference, request: LanguageModelRequest): Promise<LanguageModelTextResponse> {
const settings = request.settings || this.getDefaultSettings();
const settings = this.getSettings(request);

const response = await hfInference.textGeneration({
model: this.model,
Expand Down Expand Up @@ -104,7 +115,8 @@ export class HuggingFaceModel implements LanguageModel {
request: LanguageModelRequest,
cancellationToken?: CancellationToken
): Promise<LanguageModelResponse> {
const settings = request.settings || this.getDefaultSettings();

const settings = this.getSettings(request);

const stream = hfInference.textGenerationStream({
model: this.model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,22 @@ export class HuggingFaceLanguageModelsManagerImpl implements HuggingFaceLanguage
}
model.model = modelDescription.model;
model.apiKey = apiKeyProvider;
model.defaultRequestSettings = modelDescription.defaultRequestSettings;
} else {
this.languageModelRegistry.addLanguageModels([new HuggingFaceModel(modelDescription.id, modelDescription.model, apiKeyProvider)]);
this.languageModelRegistry.addLanguageModels([
new HuggingFaceModel(
modelDescription.id,
modelDescription.model,
apiKeyProvider,
undefined,
undefined,
undefined,
undefined,
undefined,
undefined,
modelDescription.defaultRequestSettings
)
]);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ import { AICommandHandlerFactory } from '@theia/ai-core/lib/browser/ai-command-h
import { CommandContribution, CommandRegistry, MessageService } from '@theia/core';
import { PreferenceService, QuickInputService } from '@theia/core/lib/browser';
import { inject, injectable } from '@theia/core/shared/inversify';
import { LlamafileEntry, LlamafileManager } from '../common/llamafile-manager';
import { LlamafileManager } from '../common/llamafile-manager';
import { PREFERENCE_LLAMAFILE } from './llamafile-preferences';
import { LlamafileEntry } from './llamafile-frontend-application-contribution';

export const StartLlamafileCommand = {
id: 'llamafile.start',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

import { FrontendApplicationContribution, PreferenceService } from '@theia/core/lib/browser';
import { inject, injectable } from '@theia/core/shared/inversify';
import { LlamafileEntry, LlamafileManager } from '../common/llamafile-manager';
import { LlamafileManager, LlamafileModelDescription } from '../common/llamafile-manager';
import { PREFERENCE_LLAMAFILE } from './llamafile-preferences';
import { PREFERENCE_NAME_REQUEST_SETTINGS, RequestSetting } from '@theia/ai-core/lib/browser/ai-core-preferences';

const LLAMAFILE_PROVIDER_ID = 'llamafile';
@injectable()
export class LlamafileFrontendApplicationContribution implements FrontendApplicationContribution {

Expand All @@ -33,27 +35,92 @@ export class LlamafileFrontendApplicationContribution implements FrontendApplica
onStart(): void {
    this.preferenceService.ready.then(() => {
        // Register all structurally valid llamafile entries found in the preferences.
        const configured = this.preferenceService.get<LlamafileEntry[]>(PREFERENCE_LLAMAFILE, []);
        const validLlamafiles = configured.filter(LlamafileEntry.is);

        this.llamafileManager.addLanguageModels(this.getLLamaFileModelDescriptions(validLlamafiles));
        validLlamafiles.forEach(entry => this._knownLlamaFiles.set(entry.name, entry));

        this.preferenceService.onPreferenceChanged(event => {
            if (event.preferenceName === PREFERENCE_LLAMAFILE) {
                // Only forward entries that actually are LlamafileEntry objects.
                const newModels = event.newValue.filter((entry: unknown) => LlamafileEntry.is(entry)) as LlamafileEntry[];
                this.handleLlamaFilePreferenceChange(newModels);
            } else if (event.preferenceName === PREFERENCE_NAME_REQUEST_SETTINGS) {
                this.handleRequestSettingsChange(event.newValue as RequestSetting[]);
            }
        });
    });
}

const llamafilesToAdd = newModels.filter(llamafile =>
!this._knownLlamaFiles.has(llamafile.name) || !LlamafileEntry.equals(this._knownLlamaFiles.get(llamafile.name)!, llamafile));
/**
 * Maps llamafile entries to model descriptions, attaching the first matching
 * custom request-settings entry (entries without a providerId apply to all
 * providers and therefore match too).
 */
protected getLLamaFileModelDescriptions(llamafiles: LlamafileEntry[]): LlamafileModelDescription[] {
    const requestSettings = this.preferenceService.get<RequestSetting[]>(PREFERENCE_NAME_REQUEST_SETTINGS, []);
    return llamafiles.map(llamafile => {
        const matches = requestSettings.filter(setting =>
            setting.modelId === llamafile.name &&
            (!setting.providerId || setting.providerId === LLAMAFILE_PROVIDER_ID)
        );
        if (matches.length > 1) {
            // Ambiguous configuration: keep deterministic behavior by taking the first entry.
            console.warn(`Multiple entries found for model "${llamafile.name}". Using the first match.`);
        }
        return {
            name: llamafile.name,
            uri: llamafile.uri,
            port: llamafile.port,
            defaultRequestSettings: matches[0]?.requestSettings
        };
    });
}

const llamafileIdsToRemove = [...this._knownLlamaFiles.values()].filter(llamafile =>
!newModels.find(a => LlamafileEntry.equals(a, llamafile))).map(a => a.name);
/**
 * Applies a change of the llamafile preference: removes models that disappeared
 * or changed, then (re-)registers models that are new or changed.
 */
protected handleLlamaFilePreferenceChange(newModels: LlamafileEntry[]): void {
    // An entry needs (re-)registration when it is unknown or its definition changed.
    const toAdd = newModels.filter(candidate => {
        const known = this._knownLlamaFiles.get(candidate.name);
        return known === undefined || !LlamafileEntry.equals(known, candidate);
    });

    // Known entries without an identical counterpart in the new list are stale.
    const idsToRemove = [...this._knownLlamaFiles.values()]
        .filter(known => !newModels.some(candidate => LlamafileEntry.equals(candidate, known)))
        .map(known => known.name);

    this.llamafileManager.removeLanguageModels(idsToRemove);
    idsToRemove.forEach(id => this._knownLlamaFiles.delete(id));

    this.llamafileManager.addLanguageModels(this.getLLamaFileModelDescriptions(toAdd));
    toAdd.forEach(entry => this._knownLlamaFiles.set(entry.name, entry));
}

protected handleRequestSettingsChange(newSettings: RequestSetting[]): void {
const llamafiles = Array.from(this._knownLlamaFiles.values());
const llamafileModelDescriptions = this.getLLamaFileModelDescriptions(llamafiles);
llamafileModelDescriptions.forEach(llamafileModelDescription => {
this.llamafileManager.updateRequestSettings(llamafileModelDescription.name, llamafileModelDescription.defaultRequestSettings);
});
}
}

/**
 * A llamafile model entry as configured in the llamafile preference.
 */
export interface LlamafileEntry {
    name: string;
    uri: string;
    port: number;
}

// Exported so that importers of the type (e.g. llamafile-command-contribution,
// which imports LlamafileEntry from this module) can also use the value-side
// helpers `is` and `equals`; without `export` only the interface is visible.
export namespace LlamafileEntry {
    /** Field-wise equality over all three properties of a llamafile entry. */
    export function equals(a: LlamafileEntry, b: LlamafileEntry): boolean {
        return (
            a.name === b.name &&
            a.uri === b.uri &&
            a.port === b.port
        );
    }

    /** Type guard: checks that `entry` structurally matches {@link LlamafileEntry}. */
    export function is(entry: unknown): entry is LlamafileEntry {
        // eslint-disable-next-line no-null/no-null
        if (typeof entry !== 'object' || entry === null) {
            return false;
        }
        const candidate = entry as Partial<LlamafileEntry>;
        return typeof candidate.name === 'string'
            && typeof candidate.uri === 'string'
            && typeof candidate.port === 'number';
    }
}
Loading
Loading