diff --git a/x-pack/platform/plugins/shared/stack_connectors/public/connector_types/openai/constants.tsx b/x-pack/platform/plugins/shared/stack_connectors/public/connector_types/openai/constants.tsx index 2e42c28031713..a2bae6d87bb75 100644 --- a/x-pack/platform/plugins/shared/stack_connectors/public/connector_types/openai/constants.tsx +++ b/x-pack/platform/plugins/shared/stack_connectors/public/connector_types/openai/constants.tsx @@ -52,8 +52,8 @@ export const getDefaultBody = (config?: Config) => { return config.defaultModel ? DEFAULT_BODY_OTHER(config.defaultModel) : DEFAULT_BODY; } if (config?.apiProvider === OpenAiProviderType.AzureAi) { - // update sample data if AzureAi - return DEFAULT_BODY_AZURE; + // update sample data if AzureAi; include model when defaultModel is configured (e.g. for APIM endpoints) + return config.defaultModel ? DEFAULT_BODY_OTHER(config.defaultModel) : DEFAULT_BODY_AZURE; } // default to OpenAiProviderType.OpenAi sample data return DEFAULT_BODY; @@ -155,6 +155,18 @@ export const azureAiConfig: ConfigFieldSchema[] = [ /> ), }, + { + id: 'defaultModel', + label: i18n.DEFAULT_MODEL_LABEL, + isRequired: false, + labelAppend: OptionalFieldLabel, + helpText: ( + + ), + }, contextWindowLengthField, temperatureField, ]; diff --git a/x-pack/platform/plugins/shared/stack_connectors/public/connector_types/openai/params.test.tsx b/x-pack/platform/plugins/shared/stack_connectors/public/connector_types/openai/params.test.tsx index 1b3d32d0b8518..ee5e75f462cf0 100644 --- a/x-pack/platform/plugins/shared/stack_connectors/public/connector_types/openai/params.test.tsx +++ b/x-pack/platform/plugins/shared/stack_connectors/public/connector_types/openai/params.test.tsx @@ -132,3 +132,22 @@ describe('Gen AI Params Fields renders', () => { ); }); }); + +describe('getDefaultBody', () => { + it('returns Azure body without model when Azure config has no defaultModel', () => { + const body = getDefaultBody({ + apiProvider: OpenAiProviderType.AzureAi, + apiUrl: 'https://my-resource.openai.azure.com', + }); + expect(body).not.toContain('"model"'); + }); + + it('returns body with model when Azure config has defaultModel', () => { + const body = getDefaultBody({ + apiProvider: OpenAiProviderType.AzureAi, + apiUrl: 'https://my-resource.openai.azure.com', + defaultModel: 'gpt-4o', + }); + expect(body).toContain('"model": "gpt-4o"'); + }); +}); diff --git a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.test.ts b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.test.ts index 07feee5cc517b..642c813d38d46 100644 --- a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.test.ts +++ b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.test.ts @@ -87,6 +87,40 @@ describe('Azure Open AI Utils', () => { const sanitizedBodyString = sanitizeRequest('https://randostring.ai', bodyString); expect(sanitizedBodyString).toEqual(bodyString); }); + + it('injects defaultModel into body when body has no model and defaultModel is provided', () => { + const body = { + messages: [{ role: 'user', content: 'This is a test' }], + }; + [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => { + const sanitizedBodyString = sanitizeRequest(url, JSON.stringify(body), 'gpt-4o'); + const parsed = JSON.parse(sanitizedBodyString); + expect(parsed.model).toEqual('gpt-4o'); + }); + }); + + it('preserves existing model in body when defaultModel is provided', () => { + const body = { + model: 'gpt-4-turbo', + messages: [{ role: 'user', content: 'This is a test' }], + }; + [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => { + const sanitizedBodyString = sanitizeRequest(url, JSON.stringify(body), 'gpt-4o'); + const parsed = JSON.parse(sanitizedBodyString); + expect(parsed.model).toEqual('gpt-4-turbo'); + }); + }); + + it('does not inject model when defaultModel is not provided', () => { + const body = { + messages: [{ role: 'user', content: 'This is a test' }], + }; + [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => { + const sanitizedBodyString = sanitizeRequest(url, JSON.stringify(body)); + const parsed = JSON.parse(sanitizedBodyString); + expect(parsed.model).toBeUndefined(); + }); + }); }); describe('getRequestWithStreamOption', () => { @@ -180,6 +214,50 @@ describe('Azure Open AI Utils', () => { ); expect(sanitizedBodyString).toEqual(bodyString); }); + + it('injects defaultModel into body when body has no model and defaultModel is provided', () => { + const body = { + messages: [{ role: 'user', content: 'This is a test' }], + }; + [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => { + const sanitizedBodyString = getRequestWithStreamOption( + url, + JSON.stringify(body), + false, + 'gpt-4o' + ); + const parsed = JSON.parse(sanitizedBodyString); + expect(parsed.model).toEqual('gpt-4o'); + }); + }); + + it('preserves existing model in body when defaultModel is provided', () => { + const body = { + model: 'gpt-4-turbo', + messages: [{ role: 'user', content: 'This is a test' }], + }; + [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => { + const sanitizedBodyString = getRequestWithStreamOption( + url, + JSON.stringify(body), + false, + 'gpt-4o' + ); + const parsed = JSON.parse(sanitizedBodyString); + expect(parsed.model).toEqual('gpt-4-turbo'); + }); + }); + + it('does not inject model when defaultModel is not provided', () => { + const body = { + messages: [{ role: 'user', content: 'This is a test' }], + }; + [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => { + const sanitizedBodyString = getRequestWithStreamOption(url, JSON.stringify(body), false); + const parsed = JSON.parse(sanitizedBodyString); + expect(parsed.model).toBeUndefined(); + }); + }); }); describe('transformApiUrlToRegex', () => { diff --git a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.ts b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.ts index df49dade74ec6..15eeacc839ced 100644 --- a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.ts +++ b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/azure_openai_utils.ts @@ -24,9 +24,13 @@ const APIS_ALLOWING_STREAMING = new Set([ * * The stream parameter is only accepted in the Chat API, the Completion API * and the Completions Extensions API + * + * When defaultModel is provided and the body does not already specify a model, + * it is injected into the body. This is required for Azure API Management (APIM) + * or proxy endpoints that do not infer the model from the deployment URL. */ -export const sanitizeRequest = (url: string, body: string): string => { - return getRequestWithStreamOption(url, body, false); +export const sanitizeRequest = (url: string, body: string, defaultModel?: string): string => { + return getRequestWithStreamOption(url, body, false, defaultModel); }; /** @@ -34,8 +38,17 @@ export const sanitizeRequest = (url: string, body: string): string => { * * The stream parameter is only accepted in the Chat API, the Completion API * and the Completions Extensions API + * + * When defaultModel is provided and the body does not already specify a model, + * it is injected into the body. This is required for Azure API Management (APIM) + * or proxy endpoints that do not infer the model from the deployment URL. */ -export const getRequestWithStreamOption = (url: string, body: string, stream: boolean): string => { +export const getRequestWithStreamOption = ( + url: string, + body: string, + stream: boolean, + defaultModel?: string +): string => { if ( !Array.from(APIS_ALLOWING_STREAMING) .map((apiUrl: string) => transformApiUrlToRegex(apiUrl)) @@ -53,6 +66,9 @@ export const getRequestWithStreamOption = (url: string, body: string, stream: bo include_usage: true, }; } + if (defaultModel && !jsonBody.model) { + jsonBody.model = defaultModel; + } } return JSON.stringify(jsonBody); diff --git a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/utils.test.ts b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/utils.test.ts index 1cf6c117cf6b2..cdc653b45bc97 100644 --- a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/utils.test.ts +++ b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/utils.test.ts @@ -65,7 +65,7 @@ describe('Utils', () => { it('calls azure_openai_utils sanitizeRequest when provider is AzureAi', () => { sanitizeRequest(OpenAiProviderType.AzureAi, azureAiUrl, bodyString); - expect(mockAzureAiSanitizeRequest).toHaveBeenCalledWith(azureAiUrl, bodyString); + expect(mockAzureAiSanitizeRequest).toHaveBeenCalledWith(azureAiUrl, bodyString, undefined); expect(mockOpenAiSanitizeRequest).not.toHaveBeenCalled(); expect(mockOtherOpenAiSanitizeRequest).not.toHaveBeenCalled(); }); @@ -130,7 +130,8 @@ describe('Utils', () => { expect(mockAzureAiGetRequestWithStreamOption).toHaveBeenCalledWith( azureAiUrl, bodyString, - true + true, + undefined ); expect(mockOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled(); expect(mockOtherOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled(); diff --git a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/utils.ts b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/utils.ts index c0e74a86a6356..909cfec2d6c49 100644 --- a/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/utils.ts +++ b/x-pack/platform/plugins/shared/stack_connectors/server/connector_types/openai/lib/utils.ts @@ -31,7 +31,7 @@ export const sanitizeRequest = ( case OpenAiProviderType.OpenAi: return openAiSanitizeRequest(url, body, defaultModel!); case OpenAiProviderType.AzureAi: - return azureAiSanitizeRequest(url, body); + return azureAiSanitizeRequest(url, body, defaultModel); case OpenAiProviderType.Other: return otherOpenAiSanitizeRequest(body); default: @@ -47,13 +47,6 @@ export function getRequestWithStreamOption( defaultModel: string ): string; -export function getRequestWithStreamOption( - provider: OpenAiProviderType.AzureAi | OpenAiProviderType.Other, - url: string, - body: string, - stream: boolean -): string; - export function getRequestWithStreamOption( provider: OpenAiProviderType, url: string, @@ -73,7 +66,7 @@ export function getRequestWithStreamOption( case OpenAiProviderType.OpenAi: return openAiGetRequestWithStreamOption(url, body, stream, defaultModel!); case OpenAiProviderType.AzureAi: - return azureAiGetRequestWithStreamOption(url, body, stream); + return azureAiGetRequestWithStreamOption(url, body, stream, defaultModel); case OpenAiProviderType.Other: return otherOpenAiGetRequestWithStreamOption(body, stream, defaultModel); default: diff --git a/x-pack/platform/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/openai.ts b/x-pack/platform/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/openai.ts index 9fe9dae10b754..50657c3ce547e 100644 --- a/x-pack/platform/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/openai.ts +++ b/x-pack/platform/test/alerting_api_integration/security_and_spaces/group2/tests/actions/connector_types/openai.ts @@ -383,6 +383,47 @@ export default function genAiTest({ getService }: FtrProviderContext) { expect(body.status).to.equal('ok'); expect(body.connector_id).to.equal(createdAction.id); }); + + it('should return 200 when creating an Azure connector with defaultModel', async () => { + const { body: createdAction } = await supertest + .post('/api/actions/connector') + .set('kbn-xsrf', 'foo') + .send({ + name, + connector_type_id: connectorTypeId, + config: { + apiProvider: 'Azure OpenAI', + apiUrl: config.apiUrl, + defaultModel: 'gpt-4o', + }, + secrets, + }) + .expect(200); + + expect(createdAction).to.have.property('id'); + expect(createdAction.config.apiProvider).to.equal('Azure OpenAI'); + expect(createdAction.config.defaultModel).to.equal('gpt-4o'); + }); + + it('should return 200 when creating an Azure connector without defaultModel', async () => { + const { body: createdAction } = await supertest + .post('/api/actions/connector') + .set('kbn-xsrf', 'foo') + .send({ + name, + connector_type_id: connectorTypeId, + config: { + apiProvider: 'Azure OpenAI', + apiUrl: config.apiUrl, + }, + secrets, + }) + .expect(200); + + expect(createdAction).to.have.property('id'); + expect(createdAction.config.apiProvider).to.equal('Azure OpenAI'); + expect(createdAction.config).to.not.have.property('defaultModel'); + }); }); describe('executor', () => {