diff --git a/x-pack/platform/plugins/shared/stack_connectors/public/connector_types/openai/constants.tsx b/x-pack/platform/plugins/shared/stack_connectors/public/connector_types/openai/constants.tsx
index 2e42c28031713..a2bae6d87bb75 100644
--- a/x-pack/platform/plugins/shared/stack_connectors/public/connector_types/openai/constants.tsx
+++ b/x-pack/platform/plugins/shared/stack_connectors/public/connector_types/openai/constants.tsx
@@ -52,8 +52,8 @@ export const getDefaultBody = (config?: Config) => {
return config.defaultModel ? DEFAULT_BODY_OTHER(config.defaultModel) : DEFAULT_BODY;
}
if (config?.apiProvider === OpenAiProviderType.AzureAi) {
- // update sample data if AzureAi
- return DEFAULT_BODY_AZURE;
+ // update sample data if AzureAi; include model when defaultModel is configured (e.g. for APIM endpoints)
+ return config.defaultModel ? DEFAULT_BODY_OTHER(config.defaultModel) : DEFAULT_BODY_AZURE;
}
// default to OpenAiProviderType.OpenAi sample data
return DEFAULT_BODY;
@@ -155,6 +155,18 @@ export const azureAiConfig: ConfigFieldSchema[] = [
/>
),
},
+ {
+ id: 'defaultModel',
+ label: i18n.DEFAULT_MODEL_LABEL,
+ isRequired: false,
+ labelAppend: OptionalFieldLabel,
+    helpText: (
+      <FormattedMessage
+        defaultMessage="If your Azure OpenAI endpoint (for example, an APIM gateway) requires a model in the request body, specify it here."
+        id="xpack.stackConnectors.components.genAi.azureAiDefaultModelHelpText"
+      />
+    ),
+ },
contextWindowLengthField,
temperatureField,
];
diff --git a/x-pack/platform/plugins/shared/stack_connectors/public/connector_types/openai/params.test.tsx b/x-pack/platform/plugins/shared/stack_connectors/public/connector_types/openai/params.test.tsx
index 1b3d32d0b8518..ee5e75f462cf0 100644
--- a/x-pack/platform/plugins/shared/stack_connectors/public/connector_types/openai/params.test.tsx
+++ b/x-pack/platform/plugins/shared/stack_connectors/public/connector_types/openai/params.test.tsx
@@ -132,3 +132,22 @@ describe('Gen AI Params Fields renders', () => {
);
});
});
+
+describe('getDefaultBody', () => {
+ it('returns Azure body without model when Azure config has no defaultModel', () => {
+ const body = getDefaultBody({
+ apiProvider: OpenAiProviderType.AzureAi,
+ apiUrl: 'https://my-resource.openai.azure.com',
+ });
+ expect(body).not.toContain('"model"');
+ });
+
+ it('returns body with model when Azure config has defaultModel', () => {
+ const body = getDefaultBody({
+ apiProvider: OpenAiProviderType.AzureAi,
+ apiUrl: 'https://my-resource.openai.azure.com',
+ defaultModel: 'gpt-4o',
+ });
+ expect(body).toContain('"model": "gpt-4o"');
+ });
+});
diff --git a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.test.ts b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.test.ts
index 07feee5cc517b..642c813d38d46 100644
--- a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.test.ts
+++ b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.test.ts
@@ -87,6 +87,40 @@ describe('Azure Open AI Utils', () => {
const sanitizedBodyString = sanitizeRequest('https://randostring.ai', bodyString);
expect(sanitizedBodyString).toEqual(bodyString);
});
+
+ it('injects defaultModel into body when body has no model and defaultModel is provided', () => {
+ const body = {
+ messages: [{ role: 'user', content: 'This is a test' }],
+ };
+ [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => {
+ const sanitizedBodyString = sanitizeRequest(url, JSON.stringify(body), 'gpt-4o');
+ const parsed = JSON.parse(sanitizedBodyString);
+ expect(parsed.model).toEqual('gpt-4o');
+ });
+ });
+
+ it('preserves existing model in body when defaultModel is provided', () => {
+ const body = {
+ model: 'gpt-4-turbo',
+ messages: [{ role: 'user', content: 'This is a test' }],
+ };
+ [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => {
+ const sanitizedBodyString = sanitizeRequest(url, JSON.stringify(body), 'gpt-4o');
+ const parsed = JSON.parse(sanitizedBodyString);
+ expect(parsed.model).toEqual('gpt-4-turbo');
+ });
+ });
+
+ it('does not inject model when defaultModel is not provided', () => {
+ const body = {
+ messages: [{ role: 'user', content: 'This is a test' }],
+ };
+ [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => {
+ const sanitizedBodyString = sanitizeRequest(url, JSON.stringify(body));
+ const parsed = JSON.parse(sanitizedBodyString);
+ expect(parsed.model).toBeUndefined();
+ });
+ });
});
describe('getRequestWithStreamOption', () => {
@@ -180,6 +214,50 @@ describe('Azure Open AI Utils', () => {
);
expect(sanitizedBodyString).toEqual(bodyString);
});
+
+ it('injects defaultModel into body when body has no model and defaultModel is provided', () => {
+ const body = {
+ messages: [{ role: 'user', content: 'This is a test' }],
+ };
+ [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => {
+ const sanitizedBodyString = getRequestWithStreamOption(
+ url,
+ JSON.stringify(body),
+ false,
+ 'gpt-4o'
+ );
+ const parsed = JSON.parse(sanitizedBodyString);
+ expect(parsed.model).toEqual('gpt-4o');
+ });
+ });
+
+ it('preserves existing model in body when defaultModel is provided', () => {
+ const body = {
+ model: 'gpt-4-turbo',
+ messages: [{ role: 'user', content: 'This is a test' }],
+ };
+ [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => {
+ const sanitizedBodyString = getRequestWithStreamOption(
+ url,
+ JSON.stringify(body),
+ false,
+ 'gpt-4o'
+ );
+ const parsed = JSON.parse(sanitizedBodyString);
+ expect(parsed.model).toEqual('gpt-4-turbo');
+ });
+ });
+
+ it('does not inject model when defaultModel is not provided', () => {
+ const body = {
+ messages: [{ role: 'user', content: 'This is a test' }],
+ };
+ [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => {
+ const sanitizedBodyString = getRequestWithStreamOption(url, JSON.stringify(body), false);
+ const parsed = JSON.parse(sanitizedBodyString);
+ expect(parsed.model).toBeUndefined();
+ });
+ });
});
describe('transformApiUrlToRegex', () => {
diff --git a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.ts b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.ts
index df49dade74ec6..15eeacc839ced 100644
--- a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.ts
+++ b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.ts
@@ -24,9 +24,13 @@ const APIS_ALLOWING_STREAMING = new Set([
*
* The stream parameter is only accepted in the Chat API, the Completion API
* and the Completions Extensions API
+ *
+ * When defaultModel is provided and the body does not already specify a model,
+ * it is injected into the body. This is required for Azure API Management (APIM)
+ * or proxy endpoints that do not infer the model from the deployment URL.
*/
-export const sanitizeRequest = (url: string, body: string): string => {
- return getRequestWithStreamOption(url, body, false);
+export const sanitizeRequest = (url: string, body: string, defaultModel?: string): string => {
+ return getRequestWithStreamOption(url, body, false, defaultModel);
};
/**
@@ -34,8 +38,17 @@ export const sanitizeRequest = (url: string, body: string): string => {
*
* The stream parameter is only accepted in the Chat API, the Completion API
* and the Completions Extensions API
+ *
+ * When defaultModel is provided and the body does not already specify a model,
+ * it is injected into the body. This is required for Azure API Management (APIM)
+ * or proxy endpoints that do not infer the model from the deployment URL.
*/
-export const getRequestWithStreamOption = (url: string, body: string, stream: boolean): string => {
+export const getRequestWithStreamOption = (
+ url: string,
+ body: string,
+ stream: boolean,
+ defaultModel?: string
+): string => {
if (
!Array.from(APIS_ALLOWING_STREAMING)
.map((apiUrl: string) => transformApiUrlToRegex(apiUrl))
@@ -53,6 +66,9 @@ export const getRequestWithStreamOption = (url: string, body: string, stream: bo
include_usage: true,
};
}
+ if (defaultModel && !jsonBody.model) {
+ jsonBody.model = defaultModel;
+ }
}
return JSON.stringify(jsonBody);
diff --git a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/utils.test.ts b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/utils.test.ts
index 1cf6c117cf6b2..cdc653b45bc97 100644
--- a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/utils.test.ts
+++ b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/utils.test.ts
@@ -65,7 +65,7 @@ describe('Utils', () => {
it('calls azure_openai_utils sanitizeRequest when provider is AzureAi', () => {
sanitizeRequest(OpenAiProviderType.AzureAi, azureAiUrl, bodyString);
- expect(mockAzureAiSanitizeRequest).toHaveBeenCalledWith(azureAiUrl, bodyString);
+ expect(mockAzureAiSanitizeRequest).toHaveBeenCalledWith(azureAiUrl, bodyString, undefined);
expect(mockOpenAiSanitizeRequest).not.toHaveBeenCalled();
expect(mockOtherOpenAiSanitizeRequest).not.toHaveBeenCalled();
});
@@ -130,7 +130,8 @@ describe('Utils', () => {
expect(mockAzureAiGetRequestWithStreamOption).toHaveBeenCalledWith(
azureAiUrl,
bodyString,
- true
+ true,
+ undefined
);
expect(mockOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled();
expect(mockOtherOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled();
diff --git a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/utils.ts b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/utils.ts
index c0e74a86a6356..909cfec2d6c49 100644
--- a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/utils.ts
+++ b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/utils.ts
@@ -31,7 +31,7 @@ export const sanitizeRequest = (
case OpenAiProviderType.OpenAi:
return openAiSanitizeRequest(url, body, defaultModel!);
case OpenAiProviderType.AzureAi:
- return azureAiSanitizeRequest(url, body);
+ return azureAiSanitizeRequest(url, body, defaultModel);
case OpenAiProviderType.Other:
return otherOpenAiSanitizeRequest(body);
default:
@@ -47,13 +47,6 @@ export function getRequestWithStreamOption(
defaultModel: string
): string;
-export function getRequestWithStreamOption(
- provider: OpenAiProviderType.AzureAi | OpenAiProviderType.Other,
- url: string,
- body: string,
- stream: boolean
-): string;
-
export function getRequestWithStreamOption(
provider: OpenAiProviderType,
url: string,
@@ -73,7 +66,7 @@ export function getRequestWithStreamOption(
case OpenAiProviderType.OpenAi:
return openAiGetRequestWithStreamOption(url, body, stream, defaultModel!);
case OpenAiProviderType.AzureAi:
- return azureAiGetRequestWithStreamOption(url, body, stream);
+ return azureAiGetRequestWithStreamOption(url, body, stream, defaultModel);
case OpenAiProviderType.Other:
return otherOpenAiGetRequestWithStreamOption(body, stream, defaultModel);
default:
diff --git a/x-pack/platform/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/openai.ts b/x-pack/platform/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/openai.ts
index 9fe9dae10b754..50657c3ce547e 100644
--- a/x-pack/platform/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/openai.ts
+++ b/x-pack/platform/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/openai.ts
@@ -383,6 +383,47 @@ export default function genAiTest({ getService }: FtrProviderContext) {
expect(body.status).to.equal('ok');
expect(body.connector_id).to.equal(createdAction.id);
});
+
+ it('should return 200 when creating an Azure connector with defaultModel', async () => {
+ const { body: createdAction } = await supertest
+ .post('/api/actions/connector')
+ .set('kbn-xsrf', 'foo')
+ .send({
+ name,
+ connector_type_id: connectorTypeId,
+ config: {
+ apiProvider: 'Azure OpenAI',
+ apiUrl: config.apiUrl,
+ defaultModel: 'gpt-4o',
+ },
+ secrets,
+ })
+ .expect(200);
+
+ expect(createdAction).to.have.property('id');
+ expect(createdAction.config.apiProvider).to.equal('Azure OpenAI');
+ expect(createdAction.config.defaultModel).to.equal('gpt-4o');
+ });
+
+ it('should return 200 when creating an Azure connector without defaultModel', async () => {
+ const { body: createdAction } = await supertest
+ .post('/api/actions/connector')
+ .set('kbn-xsrf', 'foo')
+ .send({
+ name,
+ connector_type_id: connectorTypeId,
+ config: {
+ apiProvider: 'Azure OpenAI',
+ apiUrl: config.apiUrl,
+ },
+ secrets,
+ })
+ .expect(200);
+
+ expect(createdAction).to.have.property('id');
+ expect(createdAction.config.apiProvider).to.equal('Azure OpenAI');
+ expect(createdAction.config).to.not.have.property('defaultModel');
+ });
});
describe('executor', () => {