Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ export const getDefaultBody = (config?: Config) => {
return config.defaultModel ? DEFAULT_BODY_OTHER(config.defaultModel) : DEFAULT_BODY;
}
if (config?.apiProvider === OpenAiProviderType.AzureAi) {
// update sample data if AzureAi
return DEFAULT_BODY_AZURE;
// update sample data if AzureAi; include model when defaultModel is configured (e.g. for APIM endpoints)
return config.defaultModel ? DEFAULT_BODY_OTHER(config.defaultModel) : DEFAULT_BODY_AZURE;
}
// default to OpenAiProviderType.OpenAi sample data
return DEFAULT_BODY;
Expand Down Expand Up @@ -155,6 +155,18 @@ export const azureAiConfig: ConfigFieldSchema[] = [
/>
),
},
{
id: 'defaultModel',
label: i18n.DEFAULT_MODEL_LABEL,
isRequired: false,
labelAppend: OptionalFieldLabel,
helpText: (
<FormattedMessage
defaultMessage="Required for Azure API Management (APIM) or proxy endpoints that do not infer the model from the deployment URL."
id="xpack.stackConnectors.components.genAi.azureAiDefaultModel"
/>
),
},
contextWindowLengthField,
temperatureField,
];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,22 @@ describe('Gen AI Params Fields renders', () => {
);
});
});

describe('getDefaultBody', () => {
  // Shared Azure OpenAI config; individual cases add/omit defaultModel.
  const azureBaseConfig = {
    apiProvider: OpenAiProviderType.AzureAi,
    apiUrl: 'https://my-resource.openai.azure.com',
  };

  it('returns Azure body without model when Azure config has no defaultModel', () => {
    // Without a configured defaultModel the sample body must not carry a model field.
    const body = getDefaultBody(azureBaseConfig);
    expect(body).not.toContain('"model"');
  });

  it('returns body with model when Azure config has defaultModel', () => {
    // With defaultModel set (e.g. APIM/proxy endpoints) the sample body embeds it.
    const body = getDefaultBody({ ...azureBaseConfig, defaultModel: 'gpt-4o' });
    expect(body).toContain('"model": "gpt-4o"');
  });
});
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,40 @@ describe('Azure Open AI Utils', () => {
const sanitizedBodyString = sanitizeRequest('https://randostring.ai', bodyString);
expect(sanitizedBodyString).toEqual(bodyString);
});

it('injects defaultModel into body when body has no model and defaultModel is provided', () => {
  // Request body deliberately omits "model" so the helper has to fill it in.
  const requestBody = JSON.stringify({
    messages: [{ role: 'user', content: 'This is a test' }],
  });
  for (const endpoint of [chatUrl, completionUrl, completionExtensionsUrl]) {
    const result = JSON.parse(sanitizeRequest(endpoint, requestBody, 'gpt-4o'));
    expect(result.model).toEqual('gpt-4o');
  }
});

it('preserves existing model in body when defaultModel is provided', () => {
  // A model explicitly set in the request must win over the connector default.
  const requestBody = JSON.stringify({
    model: 'gpt-4-turbo',
    messages: [{ role: 'user', content: 'This is a test' }],
  });
  for (const endpoint of [chatUrl, completionUrl, completionExtensionsUrl]) {
    const result = JSON.parse(sanitizeRequest(endpoint, requestBody, 'gpt-4o'));
    expect(result.model).toEqual('gpt-4-turbo');
  }
});

it('does not inject model when defaultModel is not provided', () => {
  // No defaultModel argument: the body must pass through without a model field.
  const requestBody = JSON.stringify({
    messages: [{ role: 'user', content: 'This is a test' }],
  });
  for (const endpoint of [chatUrl, completionUrl, completionExtensionsUrl]) {
    const result = JSON.parse(sanitizeRequest(endpoint, requestBody));
    expect(result.model).toBeUndefined();
  }
});
});

describe('getRequestWithStreamOption', () => {
Expand Down Expand Up @@ -180,6 +214,50 @@ describe('Azure Open AI Utils', () => {
);
expect(sanitizedBodyString).toEqual(bodyString);
});

it('injects defaultModel into body when body has no model and defaultModel is provided', () => {
  // Body omits "model"; the stream-option helper should add the default.
  const requestBody = JSON.stringify({
    messages: [{ role: 'user', content: 'This is a test' }],
  });
  for (const endpoint of [chatUrl, completionUrl, completionExtensionsUrl]) {
    const result = JSON.parse(
      getRequestWithStreamOption(endpoint, requestBody, false, 'gpt-4o')
    );
    expect(result.model).toEqual('gpt-4o');
  }
});

it('preserves existing model in body when defaultModel is provided', () => {
  // An explicit model in the body must not be overwritten by the default.
  const requestBody = JSON.stringify({
    model: 'gpt-4-turbo',
    messages: [{ role: 'user', content: 'This is a test' }],
  });
  for (const endpoint of [chatUrl, completionUrl, completionExtensionsUrl]) {
    const result = JSON.parse(
      getRequestWithStreamOption(endpoint, requestBody, false, 'gpt-4o')
    );
    expect(result.model).toEqual('gpt-4-turbo');
  }
});

it('does not inject model when defaultModel is not provided', () => {
  // Without a defaultModel argument no model field should appear in the output.
  const requestBody = JSON.stringify({
    messages: [{ role: 'user', content: 'This is a test' }],
  });
  for (const endpoint of [chatUrl, completionUrl, completionExtensionsUrl]) {
    const result = JSON.parse(getRequestWithStreamOption(endpoint, requestBody, false));
    expect(result.model).toBeUndefined();
  }
});
});

describe('transformApiUrlToRegex', () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,31 @@ const APIS_ALLOWING_STREAMING = new Set<string>([
*
* The stream parameter is only accepted in the Chat API, the Completion API
* and the Completions Extensions API
*
* When defaultModel is provided and the body does not already specify a model,
* it is injected into the body. This is required for Azure API Management (APIM)
* or proxy endpoints that do not infer the model from the deployment URL.
*/
export const sanitizeRequest = (url: string, body: string): string => {
return getRequestWithStreamOption(url, body, false);
// Thin wrapper: sanitizing is the non-streaming case of the stream-option
// transform; defaultModel is forwarded unchanged (used by APIM/proxy setups).
export const sanitizeRequest = (url: string, body: string, defaultModel?: string): string =>
  getRequestWithStreamOption(url, body, false, defaultModel);

/**
* Intercepts the Azure Open AI request body to set the stream parameter
*
* The stream parameter is only accepted in the Chat API, the Completion API
* and the Completions Extensions API
*
* When defaultModel is provided and the body does not already specify a model,
* it is injected into the body. This is required for Azure API Management (APIM)
* or proxy endpoints that do not infer the model from the deployment URL.
*/
export const getRequestWithStreamOption = (url: string, body: string, stream: boolean): string => {
export const getRequestWithStreamOption = (
url: string,
body: string,
stream: boolean,
defaultModel?: string
): string => {
if (
!Array.from(APIS_ALLOWING_STREAMING)
.map((apiUrl: string) => transformApiUrlToRegex(apiUrl))
Expand All @@ -53,6 +66,9 @@ export const getRequestWithStreamOption = (url: string, body: string, stream: bo
include_usage: true,
};
}
if (defaultModel && !jsonBody.model) {
jsonBody.model = defaultModel;
}
}

return JSON.stringify(jsonBody);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ describe('Utils', () => {

// NOTE(review): this span is taken from a rendered diff — the two
// toHaveBeenCalledWith assertions below are the pre- and post-change
// expectations for the same mock call; only the three-argument form
// (url, body, undefined) matches the updated sanitizeRequest signature.
// Confirm against the merged file before relying on this text.
it('calls azure_openai_utils sanitizeRequest when provider is AzureAi', () => {
sanitizeRequest(OpenAiProviderType.AzureAi, azureAiUrl, bodyString);
expect(mockAzureAiSanitizeRequest).toHaveBeenCalledWith(azureAiUrl, bodyString);
expect(mockAzureAiSanitizeRequest).toHaveBeenCalledWith(azureAiUrl, bodyString, undefined);
expect(mockOpenAiSanitizeRequest).not.toHaveBeenCalled();
expect(mockOtherOpenAiSanitizeRequest).not.toHaveBeenCalled();
});
Expand Down Expand Up @@ -130,7 +130,8 @@ describe('Utils', () => {
expect(mockAzureAiGetRequestWithStreamOption).toHaveBeenCalledWith(
azureAiUrl,
bodyString,
true
true,
undefined
);
expect(mockOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled();
expect(mockOtherOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export const sanitizeRequest = (
case OpenAiProviderType.OpenAi:
return openAiSanitizeRequest(url, body, defaultModel!);
case OpenAiProviderType.AzureAi:
return azureAiSanitizeRequest(url, body);
return azureAiSanitizeRequest(url, body, defaultModel);
case OpenAiProviderType.Other:
return otherOpenAiSanitizeRequest(body);
default:
Expand All @@ -47,13 +47,6 @@ export function getRequestWithStreamOption(
defaultModel: string
): string;

export function getRequestWithStreamOption(
provider: OpenAiProviderType.AzureAi | OpenAiProviderType.Other,
url: string,
body: string,
stream: boolean
): string;

export function getRequestWithStreamOption(
provider: OpenAiProviderType,
url: string,
Expand All @@ -73,7 +66,7 @@ export function getRequestWithStreamOption(
case OpenAiProviderType.OpenAi:
return openAiGetRequestWithStreamOption(url, body, stream, defaultModel!);
case OpenAiProviderType.AzureAi:
return azureAiGetRequestWithStreamOption(url, body, stream);
return azureAiGetRequestWithStreamOption(url, body, stream, defaultModel);
case OpenAiProviderType.Other:
return otherOpenAiGetRequestWithStreamOption(body, stream, defaultModel);
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,47 @@ export default function genAiTest({ getService }: FtrProviderContext) {
expect(body.status).to.equal('ok');
expect(body.connector_id).to.equal(createdAction.id);
});

it('should return 200 when creating an Azure connector with defaultModel', async () => {
  // defaultModel is optional connector config; creation must accept and echo it.
  const response = await supertest
    .post('/api/actions/connector')
    .set('kbn-xsrf', 'foo')
    .send({
      name,
      connector_type_id: connectorTypeId,
      config: {
        apiProvider: 'Azure OpenAI',
        apiUrl: config.apiUrl,
        defaultModel: 'gpt-4o',
      },
      secrets,
    })
    .expect(200);
  const createdAction = response.body;

  expect(createdAction).to.have.property('id');
  expect(createdAction.config.apiProvider).to.equal('Azure OpenAI');
  expect(createdAction.config.defaultModel).to.equal('gpt-4o');
});

it('should return 200 when creating an Azure connector without defaultModel', async () => {
  // Omitting defaultModel must still succeed and must not add the property.
  const response = await supertest
    .post('/api/actions/connector')
    .set('kbn-xsrf', 'foo')
    .send({
      name,
      connector_type_id: connectorTypeId,
      config: {
        apiProvider: 'Azure OpenAI',
        apiUrl: config.apiUrl,
      },
      secrets,
    })
    .expect(200);
  const createdAction = response.body;

  expect(createdAction).to.have.property('id');
  expect(createdAction.config.apiProvider).to.equal('Azure OpenAI');
  expect(createdAction.config).to.not.have.property('defaultModel');
});
});

describe('executor', () => {
Expand Down
Loading