Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for users to specify custom request settings, per model and optionally per provider #14535

Merged
merged 12 commits into from
Nov 28, 2024
Merged
41 changes: 40 additions & 1 deletion packages/ai-core/src/browser/ai-core-preferences.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { interfaces } from '@theia/core/shared/inversify';
export const AI_CORE_PREFERENCES_TITLE = '✨ AI Features [Experimental]';
export const PREFERENCE_NAME_ENABLE_EXPERIMENTAL = 'ai-features.AiEnable.enableAI';
export const PREFERENCE_NAME_PROMPT_TEMPLATES = 'ai-features.promptTemplates.promptTemplatesFolder';
export const PREFERENCE_NAME_REQUEST_SETTINGS = 'ai-features.modelSettings.requestSettings';

export const aiCorePreferenceSchema: PreferenceSchema = {
type: 'object',
Expand Down Expand Up @@ -55,13 +56,51 @@ export const aiCorePreferenceSchema: PreferenceSchema = {
canSelectMany: false
}
},

},
[PREFERENCE_NAME_REQUEST_SETTINGS]: {
title: 'Custom Request Settings',
markdownDescription: 'Allows specifying custom request settings for multiple models.\n\
Each object represents the configuration for a specific model. The `modelId` field specifies the model ID, `requestSettings` defines model-specific settings.\n\
The `providerId` field is optional and allows you to apply the settings to a specific provider. If not set, the settings will be applied to all providers.\n\
Example providerIds: huggingface, openai, ollama, llamafile.\n\
Refer to [our documentation](https://theia-ide.org/docs/user_ai/#custom-request-settings) for more information.',
type: 'array',
items: {
type: 'object',
properties: {
modelId: {
type: 'string',
description: 'The model id'
},
requestSettings: {
type: 'object',
additionalProperties: true,
description: 'Settings for the specific model ID.',
},
providerId: {
type: 'string',
description: 'The (optional) provider id to apply the settings to. If not set, the settings will be applied to all providers.',
},
},
},
default: [],
}
}
};
export interface AICoreConfiguration {
[PREFERENCE_NAME_ENABLE_EXPERIMENTAL]: boolean | undefined;
[PREFERENCE_NAME_PROMPT_TEMPLATES]: string | undefined;
[PREFERENCE_NAME_REQUEST_SETTINGS]: Array<{
modelId: string;
requestSettings?: { [key: string]: unknown };
providerId?: string;
}> | undefined;
}

/**
 * One user-provided entry of the `ai-features.modelSettings.requestSettings`
 * preference: request settings to apply to a specific model.
 */
export interface RequestSetting {
    /** The id of the model these settings apply to. */
    modelId: string;
    /** Model-specific request settings, passed through to the provider. */
    requestSettings?: { [key: string]: unknown };
    /** Optional provider id to restrict the settings to. If not set, the settings apply to all providers. */
    providerId?: string;
}

export const AICorePreferences = Symbol('AICorePreferences');
Expand Down
5 changes: 5 additions & 0 deletions packages/ai-core/src/common/language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ export interface LanguageModelMetaData {
readonly family?: string;
readonly maxInputTokens?: number;
readonly maxOutputTokens?: number;
/**
* Default request settings for the language model. These settings can be set via user preferences.
* Settings in a request will override these default settings.
*/
readonly defaultRequestSettings?: { [key: string]: unknown };
}

export namespace LanguageModelMetaData {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ import { FrontendApplicationContribution, PreferenceService } from '@theia/core/
import { inject, injectable } from '@theia/core/shared/inversify';
import { HuggingFaceLanguageModelsManager, HuggingFaceModelDescription } from '../common';
import { API_KEY_PREF, MODELS_PREF } from './huggingface-preferences';
import { PREFERENCE_NAME_REQUEST_SETTINGS, RequestSetting } from '@theia/ai-core/lib/browser/ai-core-preferences';

const HUGGINGFACE_PROVIDER_ID = 'huggingface';
@injectable()
export class HuggingFaceFrontendApplicationContribution implements FrontendApplicationContribution {

Expand All @@ -36,31 +38,58 @@ export class HuggingFaceFrontendApplicationContribution implements FrontendAppli
this.manager.setApiKey(apiKey);

const models = this.preferenceService.get<string[]>(MODELS_PREF, []);
this.manager.createOrUpdateLanguageModels(...models.map(createHuggingFaceModelDescription));
const requestSettings = this.preferenceService.get<RequestSetting[]>(PREFERENCE_NAME_REQUEST_SETTINGS, []);
this.manager.createOrUpdateLanguageModels(...models.map(modelId => this.createHuggingFaceModelDescription(modelId, requestSettings)));
this.prevModels = [...models];

this.preferenceService.onPreferenceChanged(event => {
if (event.preferenceName === API_KEY_PREF) {
this.manager.setApiKey(event.newValue);
} else if (event.preferenceName === MODELS_PREF) {
const oldModels = new Set(this.prevModels);
const newModels = new Set(event.newValue as string[]);

const modelsToRemove = [...oldModels].filter(model => !newModels.has(model));
const modelsToAdd = [...newModels].filter(model => !oldModels.has(model));

this.manager.removeLanguageModels(...modelsToRemove.map(model => `huggingface/${model}`));
this.manager.createOrUpdateLanguageModels(...modelsToAdd.map(createHuggingFaceModelDescription));
this.prevModels = [...event.newValue];
this.handleModelChanges(event.newValue as string[]);
} else if (event.preferenceName === PREFERENCE_NAME_REQUEST_SETTINGS) {
this.handleRequestSettingsChanges(event.newValue as RequestSetting[]);
}
});
});
}
}

function createHuggingFaceModelDescription(modelId: string): HuggingFaceModelDescription {
return {
id: `huggingface/${modelId}`,
model: modelId
};
protected handleModelChanges(newModels: string[]): void {
const oldModels = new Set(this.prevModels);
const updatedModels = new Set(newModels);

const modelsToRemove = [...oldModels].filter(model => !updatedModels.has(model));
const modelsToAdd = [...updatedModels].filter(model => !oldModels.has(model));

this.manager.removeLanguageModels(...modelsToRemove.map(model => `${HUGGINGFACE_PROVIDER_ID}/${model}`));
const requestSettings = this.preferenceService.get<RequestSetting[]>(PREFERENCE_NAME_REQUEST_SETTINGS, []);
this.manager.createOrUpdateLanguageModels(...modelsToAdd.map(modelId => this.createHuggingFaceModelDescription(modelId, requestSettings)));
this.prevModels = newModels;
}

/**
 * Re-registers every configured model so that changed request settings take effect.
 */
protected handleRequestSettingsChanges(newSettings: RequestSetting[]): void {
    const configuredModels = this.preferenceService.get<string[]>(MODELS_PREF, []);
    const descriptions = configuredModels.map(modelId => this.createHuggingFaceModelDescription(modelId, newSettings));
    this.manager.createOrUpdateLanguageModels(...descriptions);
}

/**
 * Builds the model description for a Hugging Face model, attaching the
 * user-configured request settings that match this model (and, if given,
 * this provider). When several entries match, only the first is used.
 */
protected createHuggingFaceModelDescription(
    modelId: string,
    requestSettings: RequestSetting[]
): HuggingFaceModelDescription {
    const matches = requestSettings.filter(
        setting => setting.modelId === modelId && (!setting.providerId || setting.providerId === HUGGINGFACE_PROVIDER_ID)
    );
    if (matches.length > 1) {
        console.warn(
            `Multiple entries found for modelId "${modelId}". Using the first match and ignoring the rest.`
        );
    }
    return {
        id: `${HUGGINGFACE_PROVIDER_ID}/${modelId}`,
        model: modelId,
        defaultRequestSettings: matches[0]?.requestSettings
    };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ export interface HuggingFaceModelDescription {
* The model ID as used by the Hugging Face API.
*/
model: string;
/**
* Default request settings for the Hugging Face model.
*/
defaultRequestSettings?: { [key: string]: unknown };
}

export interface HuggingFaceLanguageModelsManager {
Expand Down
30 changes: 21 additions & 9 deletions packages/ai-hugging-face/src/node/huggingface-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,18 @@ export class HuggingFaceModel implements LanguageModel {
* @param model the model id as it is used by the Hugging Face API
* @param apiKey function to retrieve the API key for Hugging Face
*/
constructor(public readonly id: string, public model: string, public apiKey: () => string | undefined) {
}
constructor(
    public readonly id: string,                      // unique model id, e.g. `huggingface/<model>`
    public model: string,                            // the model id as used by the Hugging Face API
    public apiKey: () => string | undefined,         // function to retrieve the API key for Hugging Face
    public readonly name?: string,
    public readonly vendor?: string,
    public readonly version?: string,
    public readonly family?: string,
    public readonly maxInputTokens?: number,
    public readonly maxOutputTokens?: number,
    public defaultRequestSettings?: Record<string, unknown> // user-configured defaults; overridden by per-request settings
) { }

async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
const hfInference = this.initializeHfInference();
Expand All @@ -67,15 +77,16 @@ export class HuggingFaceModel implements LanguageModel {
}
}

protected getDefaultSettings(): Record<string, unknown> {
return {
max_new_tokens: 2024,
stop: ['<|endoftext|>', '<eos>']
};
/**
 * Resolves the settings to use for a request: explicit request settings win,
 * otherwise the model's user-configured defaults, otherwise an empty object.
 */
protected getSettings(request: LanguageModelRequest): Record<string, unknown> {
    return request.settings ?? this.defaultRequestSettings ?? {};
}

protected async handleNonStreamingRequest(hfInference: HfInference, request: LanguageModelRequest): Promise<LanguageModelTextResponse> {
const settings = request.settings || this.getDefaultSettings();
const settings = this.getSettings(request);

const response = await hfInference.textGeneration({
model: this.model,
Expand Down Expand Up @@ -104,7 +115,8 @@ export class HuggingFaceModel implements LanguageModel {
request: LanguageModelRequest,
cancellationToken?: CancellationToken
): Promise<LanguageModelResponse> {
const settings = request.settings || this.getDefaultSettings();

const settings = this.getSettings(request);

const stream = hfInference.textGenerationStream({
model: this.model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,22 @@ export class HuggingFaceLanguageModelsManagerImpl implements HuggingFaceLanguage
}
model.model = modelDescription.model;
model.apiKey = apiKeyProvider;
model.defaultRequestSettings = modelDescription.defaultRequestSettings;
} else {
this.languageModelRegistry.addLanguageModels([new HuggingFaceModel(modelDescription.id, modelDescription.model, apiKeyProvider)]);
this.languageModelRegistry.addLanguageModels([
new HuggingFaceModel(
modelDescription.id,
modelDescription.model,
apiKeyProvider,
undefined,
undefined,
undefined,
undefined,
undefined,
undefined,
modelDescription.defaultRequestSettings
)
]);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ import { FrontendApplicationContribution, PreferenceService } from '@theia/core/
import { inject, injectable } from '@theia/core/shared/inversify';
import { LlamafileEntry, LlamafileManager } from '../common/llamafile-manager';
import { PREFERENCE_LLAMAFILE } from './llamafile-preferences';
import { PREFERENCE_NAME_REQUEST_SETTINGS, RequestSetting } from '@theia/ai-core/lib/browser/ai-core-preferences';

const LLAMAFILE_PROVIDER_ID = 'llamafile';
@injectable()
export class LlamafileFrontendApplicationContribution implements FrontendApplicationContribution {

Expand All @@ -33,27 +35,63 @@ export class LlamafileFrontendApplicationContribution implements FrontendApplica
onStart(): void {
this.preferenceService.ready.then(() => {
const llamafiles = this.preferenceService.get<LlamafileEntry[]>(PREFERENCE_LLAMAFILE, []);
JonasHelming marked this conversation as resolved.
Show resolved Hide resolved
this.llamafileManager.addLanguageModels(llamafiles);
llamafiles.forEach(model => this._knownLlamaFiles.set(model.name, model));
const requestSettings = this.preferenceService.get<RequestSetting[]>(PREFERENCE_NAME_REQUEST_SETTINGS, []);
const modelsWithSettings = llamafiles.map(model => this.applyRequestSettingsToLlamaFile(model, requestSettings));

this.llamafileManager.addLanguageModels(modelsWithSettings);
modelsWithSettings.forEach(model => this._knownLlamaFiles.set(model.name, model));

this.preferenceService.onPreferenceChanged(event => {
if (event.preferenceName === PREFERENCE_LLAMAFILE) {
// only new models which are actual LLamaFileEntries
const newModels = event.newValue.filter((llamafileEntry: unknown) => LlamafileEntry.is(llamafileEntry)) as LlamafileEntry[];
this.handleLlamaFilePreferenceChange(event.newValue as LlamafileEntry[]);
} else if (event.preferenceName === PREFERENCE_NAME_REQUEST_SETTINGS) {
this.handleRequestSettingsChange(event.newValue as RequestSetting[]);
}
});
});
}

const llamafilesToAdd = newModels.filter(llamafile =>
!this._knownLlamaFiles.has(llamafile.name) || !LlamafileEntry.equals(this._knownLlamaFiles.get(llamafile.name)!, llamafile));
protected handleLlamaFilePreferenceChange(newModels: LlamafileEntry[]): void {
const requestSettings = this.preferenceService.get<RequestSetting[]>(PREFERENCE_NAME_REQUEST_SETTINGS, []);

const llamafileIdsToRemove = [...this._knownLlamaFiles.values()].filter(llamafile =>
!newModels.find(a => LlamafileEntry.equals(a, llamafile))).map(a => a.name);
// Models to add or update
const llamafilesToAdd = newModels.filter(llamafile =>
!this._knownLlamaFiles.has(llamafile.name) ||
!LlamafileEntry.equals(this._knownLlamaFiles.get(llamafile.name)!, llamafile));

this.llamafileManager.removeLanguageModels(llamafileIdsToRemove);
llamafileIdsToRemove.forEach(model => this._knownLlamaFiles.delete(model));
// Models to remove
const llamafileIdsToRemove = [...this._knownLlamaFiles.values()].filter(llamafile =>
!newModels.find(newModel => LlamafileEntry.equals(newModel, llamafile)))
.map(llamafile => llamafile.name);

this.llamafileManager.addLanguageModels(llamafilesToAdd);
llamafilesToAdd.forEach(model => this._knownLlamaFiles.set(model.name, model));
}
});
});
// Update the manager
this.llamafileManager.removeLanguageModels(llamafileIdsToRemove);
llamafileIdsToRemove.forEach(id => this._knownLlamaFiles.delete(id));

const modelsWithSettings = llamafilesToAdd.map(model => this.applyRequestSettingsToLlamaFile(model, requestSettings));
this.llamafileManager.addLanguageModels(modelsWithSettings);
modelsWithSettings.forEach(model => this._knownLlamaFiles.set(model.name, model));
}

protected handleRequestSettingsChange(newSettings: RequestSetting[]): void {
const llamafiles = this.preferenceService.get<LlamafileEntry[]>(PREFERENCE_LLAMAFILE, []);
JonasHelming marked this conversation as resolved.
Show resolved Hide resolved
const modelsWithSettings = llamafiles.map(model => this.applyRequestSettingsToLlamaFile(model, newSettings));

this.llamafileManager.addLanguageModels(modelsWithSettings);
modelsWithSettings.forEach(model => this._knownLlamaFiles.set(model.name, model));
}

/**
 * Returns a copy of the given llamafile entry enriched with the matching
 * user-configured default request settings, if any. When several entries
 * match, only the first is used.
 */
protected applyRequestSettingsToLlamaFile(model: LlamafileEntry, requestSettings: RequestSetting[]): LlamafileEntry {
    const matches = requestSettings.filter(
        setting => setting.modelId === model.name && (!setting.providerId || setting.providerId === LLAMAFILE_PROVIDER_ID)
    );
    if (matches.length > 1) {
        console.warn(`Multiple entries found for model "${model.name}". Using the first match.`);
    }
    return {
        ...model,
        defaultRequestSettings: matches[0]?.requestSettings
    };
}
}
32 changes: 26 additions & 6 deletions packages/ai-llamafile/src/common/llamafile-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,37 @@ export class LlamafileLanguageModel implements LanguageModel {
readonly providerId = 'llamafile';
readonly vendor: string = 'Mozilla';

constructor(readonly name: string, readonly uri: string, readonly port: number) {
}
/**
 * @param name the unique name for this language model. It will be used to identify the model in the UI
 *   and also serves as the model's id.
 * @param uri the URI pointing to the Llamafile model location.
 * @param port the port on which the Llamafile model server operates.
 * @param defaultRequestSettings optional default settings for requests made using this model;
 *   per-request settings take precedence over these.
 */
constructor(
    public readonly name: string,
    public readonly uri: string,
    public readonly port: number,
    public defaultRequestSettings?: { [key: string]: unknown }
) { }

// Llamafile models are identified by their unique name.
get id(): string {
    return this.name;
}
protected getSettings(request: LanguageModelRequest): Record<string, unknown> {
const settings = request.settings ? request.settings : this.defaultRequestSettings;
if (!settings) {
return {
n_predict: 200,
stream: true,
stop: ['</s>', 'Llama:', 'User:', '<|eot_id|>'],
cache_prompt: true,
};
}
return settings;
}

async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
const settings = this.getSettings(request);
try {
let prompt = request.messages.map(message => {
switch (message.actor) {
Expand All @@ -48,10 +71,7 @@ export class LlamafileLanguageModel implements LanguageModel {
},
body: JSON.stringify({
prompt: prompt,
n_predict: 200,
stream: true,
stop: ['</s>', 'Llama:', 'User:', '<|eot_id|>'],
cache_prompt: true,
...settings
}),
});

Expand Down
Loading
Loading