diff --git a/docs/changelog/142969.yaml b/docs/changelog/142969.yaml new file mode 100644 index 0000000000000..6fa1958b62488 --- /dev/null +++ b/docs/changelog/142969.yaml @@ -0,0 +1,5 @@ +area: Inference +issues: [] +pr: 142969 +summary: "[Inference API] Add custom headers for Azure OpenAI Service" +type: enhancement diff --git a/server/src/main/resources/transport/definitions/referable/inference_azure_openai_task_settings_headers.csv b/server/src/main/resources/transport/definitions/referable/inference_azure_openai_task_settings_headers.csv new file mode 100644 index 0000000000000..c988fdc9fe987 --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/inference_azure_openai_task_settings_headers.csv @@ -0,0 +1 @@ +9304000 diff --git a/server/src/main/resources/transport/upper_bounds/9.4.csv b/server/src/main/resources/transport/upper_bounds/9.4.csv index c60446d500473..3d08c5fa7867e 100644 --- a/server/src/main/resources/transport/upper_bounds/9.4.csv +++ b/server/src/main/resources/transport/upper_bounds/9.4.csv @@ -1 +1 @@ -query_dsl_boxplot_exponential_histogram_support,9303000 +inference_azure_openai_task_settings_headers,9304000 diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UpdateInferenceModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UpdateInferenceModelAction.java index 56ee6d0306a40..dcbdfbfb2ed02 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UpdateInferenceModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UpdateInferenceModelAction.java @@ -131,8 +131,8 @@ public Settings getContentAsSettings() { if (unvalidatedMap.containsKey(SERVICE_SETTINGS)) { if (unvalidatedMap.get(SERVICE_SETTINGS) instanceof Map tempMap) { for (Map.Entry entry : (tempMap).entrySet()) { - if (entry.getKey() instanceof String key && entry.getValue() instanceof Object value) { - serviceSettings.put(key, value); + if (entry.getKey() instanceof String key) { + serviceSettings.put(key, entry.getValue()); } else { throw new ElasticsearchStatusException( "Failed to parse update request [{}]", @@ -154,8 +154,8 @@ public Settings getContentAsSettings() { if (unvalidatedMap.containsKey(TASK_SETTINGS)) { if (unvalidatedMap.get(TASK_SETTINGS) instanceof Map tempMap) { for (Map.Entry entry : (tempMap).entrySet()) { - if (entry.getKey() instanceof String key && entry.getValue() instanceof Object value) { - taskSettings.put(key, value); + if (entry.getKey() instanceof String key) { + taskSettings.put(key, entry.getValue()); } else { throw new ElasticsearchStatusException( "Failed to parse update request [{}]", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/parser/Headers.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/parser/Headers.java new file mode 100644 index 0000000000000..dd9a2798003ac --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/parser/Headers.java @@ -0,0 +1,133 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.common.parser; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues; + +public record Headers(StatefulValue> mapValue) implements ToXContentFragment, Writeable { + + // public for testing + public static final String HEADERS_FIELD = "headers"; + // public for testing + public static final Headers UNDEFINED_INSTANCE = new Headers(StatefulValue.undefined()); + public static final Headers NULL_INSTANCE = new Headers(StatefulValue.nullInstance()); + + /** + * Sentinel passed by the parser when the headers field is present with value null. + */ + public static final Object PARSER_NULL_SENTINEL = new HashMap<>(); + + private static final ParseField HEADERS = new ParseField(HEADERS_FIELD); + + public static void initParser(ConstructingObjectParser parser) { + parser.declareObjectOrNull(optionalConstructorArg(), (p, c) -> { + var parsedMap = p.map(); + if (parsedMap == null || parsedMap == PARSER_NULL_SENTINEL) { + return parsedMap; + } + + var validationException = new ValidationException(); + + return doValidation(parsedMap, validationException); + }, PARSER_NULL_SENTINEL, HEADERS); + } + + private static Map doValidation(Map map, ValidationException validationException) { + removeNullValues(map); + + var stringHeaders = validateMapStringValues(map, HEADERS.getPreferredName(), validationException, false, Map.of()); + + validationException.throwIfValidationErrorsExist(); + + return stringHeaders; + } + + @SuppressWarnings("unchecked") + public static Headers create(Object arg, String path) { + // We will get null here if the headers field was not present in the json + if (arg == null) { + return UNDEFINED_INSTANCE; + } + + if (arg == PARSER_NULL_SENTINEL) { + return NULL_INSTANCE; + } + + var validationException = new ValidationException(); + + if (arg instanceof Map == false) { + validationException.addValidationError(ObjectParserUtils.invalidTypeErrorMsg(HEADERS_FIELD, path, arg, "Map")); + throw validationException; + } + + // It's not likely that this create method would be called with invalid values since they should be validated during parsing but + // we'll do it just in case this method is used elsewhere + var stringsMap = doValidation((Map) arg, validationException); + + if (stringsMap.isEmpty()) { + // If a user specifies "headers": {} we'll assume they don't want any headers. If this in the context of an update API, + // this is the same as if they did "headers": null which means to remove all existing headers. + return NULL_INSTANCE; + } + + return new Headers(StatefulValue.of(stringsMap)); + } + + public Headers { + Objects.requireNonNull(mapValue); + } + + public Headers(StreamInput in) throws IOException { + this(StatefulValue.read(in, input -> input.readImmutableMap(StreamInput::readString, StreamInput::readString))); + } + + public boolean isEmpty() { + return mapValue.isPresent() == false || mapValue.get().isEmpty(); + } + + public boolean isPresent() { + return mapValue.isPresent(); + } + + public boolean isNull() { + return mapValue.isNull(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (isEmpty() == false) { + builder.field(HEADERS.getPreferredName(), mapValue.get()); + } + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + StatefulValue.write( + out, + mapValue, + (streamOutput, v) -> streamOutput.writeMap(v, StreamOutput::writeString, StreamOutput::writeString) + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/parser/StatefulValue.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/parser/StatefulValue.java new file mode 100644 index 0000000000000..60479cac33973 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/parser/StatefulValue.java @@ -0,0 +1,126 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.common.parser; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.NoSuchElementException; +import java.util.Objects; + +/** + * This class holds a value of type {@param T} that can be in one of three states: undefined, null, or defined with a non-null value. + * It provides methods to check the state and retrieve the value if present. + *

+ * Undefined means that the value is not defined aka it was absent in the input + * Null means that the value is defined but explicitly set to null + * Present means that the value is defined and not null + * @param the type of the value + */ +public final class StatefulValue { + + static final NoSuchElementException NO_VALUE_PRESENT = new NoSuchElementException("No value present"); + + private static final StatefulValue UNDEFINED_INSTANCE = new StatefulValue<>(null, false); + private static final StatefulValue NULL_INSTANCE = new StatefulValue<>(null, true); + + public static StatefulValue undefined() { + @SuppressWarnings("unchecked") + var absent = (StatefulValue) UNDEFINED_INSTANCE; + return absent; + } + + public static StatefulValue nullInstance() { + @SuppressWarnings("unchecked") + var nullInstance = (StatefulValue) NULL_INSTANCE; + return nullInstance; + } + + public static StatefulValue of(T value) { + return new StatefulValue<>(Objects.requireNonNull(value), true); + } + + public static StatefulValue read(StreamInput in, Writeable.Reader reader) throws IOException { + var isDefined = in.readBoolean(); + if (isDefined == false) { + return undefined(); + } + + var isNull = in.readBoolean(); + if (isNull) { + return nullInstance(); + } + + var value = reader.read(in); + return of(value); + } + + public static void write(StreamOutput out, StatefulValue statefulValue, Writeable.Writer writer) throws IOException { + out.writeBoolean(statefulValue.isDefined); + if (statefulValue.isDefined) { + out.writeBoolean(statefulValue.isNull()); + if (statefulValue.isPresent()) { + writer.write(out, statefulValue.value); + } + } + } + + private final T value; + private final boolean isDefined; + + private StatefulValue(T value, boolean isDefined) { + this.value = value; + this.isDefined = isDefined; + } + + /** + * Returns true if the value is not defined, meaning it is absent. + */ + public boolean isUndefined() { + return isDefined == false; + } + + /** + * Returns true if the value is defined and explicitly set to null. + */ + public boolean isNull() { + return isDefined && value == null; + } + + /** + * Returns true if the value is defined and not null. + */ + public boolean isPresent() { + return isDefined && value != null; + } + + public T get() { + if (isPresent() == false) { + throw NO_VALUE_PRESENT; + } + return value; + } + + public T orElse(T other) { + return isPresent() ? value : other; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + StatefulValue statefulValue = (StatefulValue) o; + return Objects.equals(value, statefulValue.value) && isDefined == statefulValue.isDefined; + } + + @Override + public int hashCode() { + return Objects.hash(value, isDefined); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index f5325d39fe8ca..8105f16ef7b2e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -35,6 +35,7 @@ import java.util.EnumSet; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.function.Function; @@ -641,7 +642,7 @@ public static void validateMapValues( settingName, entry.getKey(), entry.getValue(), - entry.getValue(), + getTypeAsString(entry.getValue()), String.join(", ", validTypesAsStrings) ); } @@ -657,6 +658,29 @@ public static void validateMapValues( } } + private static String getTypeAsString(@Nullable Object value) { + if (value == null) { + return "null"; + } + + var simpleName = value.getClass().getSimpleName(); + var lowerCaseSimpleName = simpleName.toLowerCase(Locale.ROOT); + + // maps may be represented as HashMap, LinkedHashMap, Map1, etc. Lists may be ArrayList, LinkedList, etc. + // Sets may be HashSet, LinkedHashSet, etc. We want to simplify these to Map, List, and Set in the error messages. + if (lowerCaseSimpleName.contains("map")) { + return "Map"; + } else if (lowerCaseSimpleName.contains("list")) { + return "Array"; + } else if (lowerCaseSimpleName.contains("set")) { + return "Set"; + } else if (lowerCaseSimpleName.contains("array")) { + return "Array"; + } else { + return simpleName; + } + } + public static Map convertMapStringsToSecureString( Map map, String settingName, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsRequestTaskSettings.java index 8c9fd22a7cdf7..4697ebdcdc292 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsRequestTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsRequestTaskSettings.java @@ -10,7 +10,6 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsRequestTaskSettings; import java.util.Map; @@ -31,7 +30,7 @@ public record AzureAiStudioEmbeddingsRequestTaskSettings(@Nullable String user) * does not throw an error. * * @param map the settings received from a request - * @return a {@link AzureOpenAiEmbeddingsRequestTaskSettings} + * @return a {@link AzureAiStudioEmbeddingsRequestTaskSettings} */ public static AzureAiStudioEmbeddingsRequestTaskSettings fromMap(Map map) { if (map.isEmpty()) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java index c42f4d1650428..95e5284b4647b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java @@ -96,6 +96,11 @@ public AzureOpenAiRateLimitServiceSettings rateLimitServiceSettings() { return rateLimitServiceSettings; } + @Override + public AzureOpenAiSecretSettings getSecretSettings() { + return (AzureOpenAiSecretSettings) super.getSecretSettings(); + } + @Override public RateLimitSettings rateLimitSettings() { return rateLimitServiceSettings.rateLimitSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 4d06b6ff2fcb5..19ae92ae653c8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -141,7 +141,9 @@ public void parseRequestConfig( throwIfNotEmptyMap(config, NAME); throwIfNotEmptyMap(serviceSettingsMap, NAME); - throwIfNotEmptyMap(taskSettingsMap, NAME); + // The new approach is to leverage a ConstructingObjectParser to parse the task settings, this does not mutate the original map + // so we don't need to check if it's empty after parsing. The ConstructingObjectParser will throw an exception if there are any + // unrecognized fields in the task settings parsedModelListener.onResponse(model); } catch (Exception e) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiTaskSettings.java new file mode 100644 index 0000000000000..22057f89969ec --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiTaskSettings.java @@ -0,0 +1,280 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureopenai; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.inference.InferenceUtils; +import org.elasticsearch.xpack.inference.common.parser.Headers; +import org.elasticsearch.xpack.inference.common.parser.StatefulValue; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.xpack.inference.common.parser.Headers.UNDEFINED_INSTANCE; + +/** + * Base class for Azure OpenAI task settings. Holds optional user and optional + * custom HTTP headers via {@link Headers}. + */ +public abstract class AzureOpenAiTaskSettings> implements TaskSettings { + + protected static final TransportVersion INFERENCE_AZURE_OPENAI_TASK_SETTINGS_HEADERS = TransportVersion.fromName( + "inference_azure_openai_task_settings_headers" + ); + + // Default for testing + protected record CommonSettings(StatefulValue user, Headers headers) { + public CommonSettings { + Objects.requireNonNull(user); + Objects.requireNonNull(headers); + } + + public boolean isEmpty() { + // user is empty if it is not present or if it is empty + // (although the parser should prevent an empty string by throwing a validation exception) + return user.orElse("").isEmpty() && headers().isEmpty(); + } + } + + /** + * Sentinel for parser: when "user" field is present with value null. + */ + private static final Object USER_PARSER_NULL_SENTINEL = new Object(); + + private static final ConstructingObjectParser STORAGE_PARSER = createParser(true); + private static final ConstructingObjectParser REQUEST_PARSER = createParser(false); + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser constructingObjectParser = new ConstructingObjectParser<>( + "azure_openai_task_settings_parser", + ignoreUnknownFields, + args -> createSettings(args[0], args[1]) + ); + + constructingObjectParser.declareField( + optionalConstructorArg(), + p -> p.currentToken() == XContentParser.Token.VALUE_NULL ? USER_PARSER_NULL_SENTINEL : p.text(), + new ParseField(AzureOpenAiServiceFields.USER), + ObjectParser.ValueType.STRING_OR_NULL + ); + + Headers.initParser(constructingObjectParser); + + return constructingObjectParser; + } + + private static CommonSettings createSettings(Object userArg, Object headersArg) { + StatefulValue user; + if (userArg == null) { + user = StatefulValue.undefined(); + } else if (userArg == USER_PARSER_NULL_SENTINEL) { + user = StatefulValue.nullInstance(); + } else { + user = StatefulValue.of((String) userArg); + } + + Headers headers = headersArg instanceof Headers + ? (Headers) headersArg + : Headers.create(headersArg, ModelConfigurations.TASK_SETTINGS); + return new CommonSettings(user, headers); + } + + protected abstract static class Factory { + private T emptyInstance; + + protected abstract T create(CommonSettings commonSettings); + + protected abstract T createEmptyInstance(); + + public T emptySettings() { + // Ideally we'd be able to pass the empty instance in via the Factory constructor, but since the empty instance relies on the + // factory to be created, we have to lazily create it here. The empty instance will call the AzureOpenAiTaskSettings + // constructor with the factory. If we don't do it this way we end up getting an NPE in the constructor because the factory + // hasn't finished initialization yet. + if (emptyInstance == null) { + emptyInstance = createEmptyInstance(); + } + return emptyInstance; + } + } + + protected static > T parseSettingsFromMap( + Map map, + ConfigurationParseContext configurationParseContext, + Factory factory + ) { + if (map.isEmpty()) { + return factory.emptySettings(); + } + + try { + try ( + var xContent = XContentBuilder.builder(JsonXContent.jsonXContent).map(map); + var parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, Strings.toString(xContent)) + ) { + CommonSettings parsed; + + if (configurationParseContext == ConfigurationParseContext.REQUEST) { + parsed = REQUEST_PARSER.parse(parser, null); + validateParsedRequest(parsed); + } else { + parsed = STORAGE_PARSER.parse(parser, null); + } + + return factory.create(parsed); + } + } catch (IOException e) { + throw new IllegalArgumentException("Failed to parse Azure OpenAI task settings", e); + } + } + + private static void validateParsedRequest(CommonSettings parsed) { + if (parsed.user().isPresent() && parsed.user().get().isEmpty()) { + var validationException = new ValidationException(); + validationException.addValidationError( + InferenceUtils.mustBeNonEmptyString(AzureOpenAiServiceFields.USER, ModelConfigurations.TASK_SETTINGS) + ); + throw validationException; + } + } + + private final CommonSettings taskSettings; + private final Factory factory; + + protected AzureOpenAiTaskSettings(@Nullable String user, @Nullable Headers headers, Factory factory) { + this(createSettings(user, headers), factory); + } + + protected AzureOpenAiTaskSettings(CommonSettings taskSettings, Factory factory) { + this.taskSettings = Objects.requireNonNull(taskSettings); + this.factory = Objects.requireNonNull(factory); + } + + protected AzureOpenAiTaskSettings(StreamInput in, Factory factory) throws IOException { + this(readTaskSettingsFromStream(in), factory); + } + + private static CommonSettings readTaskSettingsFromStream(StreamInput in) throws IOException { + if (in.getTransportVersion().supports(INFERENCE_AZURE_OPENAI_TASK_SETTINGS_HEADERS)) { + var user = StatefulValue.read(in, StreamInput::readString); + return new CommonSettings(user, new Headers(in)); + } else { + var user = StatefulValue.undefined(); + var userString = in.readOptionalString(); + if (Strings.isNullOrEmpty(userString) == false) { + user = StatefulValue.of(userString); + } + + return new CommonSettings(user, UNDEFINED_INSTANCE); + } + } + + public StatefulValue user() { + return taskSettings.user(); + } + + public Headers headers() { + return taskSettings.headers(); + } + + @Override + public boolean isEmpty() { + return taskSettings.isEmpty(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + var user = taskSettings.user(); + + if (user.isPresent() && user.get().isEmpty() == false) { + builder.field(AzureOpenAiServiceFields.USER, user.get()); + } + + taskSettings.headers().toXContent(builder, params); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AzureOpenAiTaskSettings that = (AzureOpenAiTaskSettings) o; + return Objects.equals(taskSettings, that.taskSettings); + } + + @Override + public int hashCode() { + return Objects.hash(taskSettings); + } + + @Override + public T updatedTaskSettings(Map newSettings) { + var updated = parseSettingsFromMap(new HashMap<>(newSettings), ConfigurationParseContext.REQUEST, factory); + + var userToUse = taskSettings.user(); + if (updated.user().isPresent()) { + userToUse = updated.user(); + } else if (updated.user().isNull()) { + userToUse = StatefulValue.undefined(); + } + + var headersToUse = taskSettings.headers(); + if (updated.headers().isPresent()) { + headersToUse = updated.headers(); + } else if (updated.headers().isNull()) { + headersToUse = Headers.UNDEFINED_INSTANCE; + } + + if (userToUse.isUndefined() && headersToUse.mapValue().isUndefined()) { + return factory.emptySettings(); + } + + return factory.create(new CommonSettings(userToUse, headersToUse)); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + assert false : "should never be called when supportsVersion is used"; + return INFERENCE_AZURE_OPENAI_TASK_SETTINGS_HEADERS; + } + + @Override + public boolean supportsVersion(TransportVersion version) { + return INFERENCE_AZURE_OPENAI_TASK_SETTINGS_HEADERS.supports(version); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (out.getTransportVersion().supports(INFERENCE_AZURE_OPENAI_TASK_SETTINGS_HEADERS)) { + StatefulValue.write(out, taskSettings.user(), StreamOutput::writeString); + taskSettings.headers().writeTo(out); + } else { + out.writeOptionalString(user().orElse(null)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java index 25136f01dd809..fb2144191a9f1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModel.java @@ -29,8 +29,7 @@ public static AzureOpenAiCompletionModel of(AzureOpenAiCompletionModel model, Ma return model; } - var requestTaskSettings = AzureOpenAiCompletionRequestTaskSettings.fromMap(taskSettings); - return new AzureOpenAiCompletionModel(model, AzureOpenAiCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + return new AzureOpenAiCompletionModel(model, model.getTaskSettings().updatedTaskSettings(taskSettings)); } public AzureOpenAiCompletionModel( @@ -47,7 +46,7 @@ public AzureOpenAiCompletionModel( taskType, service, AzureOpenAiCompletionServiceSettings.fromMap(serviceSettings, context), - AzureOpenAiCompletionTaskSettings.fromMap(taskSettings), + AzureOpenAiCompletionTaskSettings.fromMap(taskSettings, context), AzureOpenAiSecretSettings.fromMap(secrets) ); } @@ -91,11 +90,6 @@ public AzureOpenAiCompletionTaskSettings getTaskSettings() { return (AzureOpenAiCompletionTaskSettings) super.getTaskSettings(); } - @Override - public AzureOpenAiSecretSettings getSecretSettings() { - return (AzureOpenAiSecretSettings) super.getSecretSettings(); - } - @Override public ExecutableAction accept(AzureOpenAiActionVisitor creator, Map taskSettings) { return creator.create(this, taskSettings); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionRequestTaskSettings.java deleted file mode 100644 index 5dd42bb1b911f..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionRequestTaskSettings.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.azureopenai.completion; - -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.ModelConfigurations; - -import java.util.Map; - -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; -import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.USER; - -public record AzureOpenAiCompletionRequestTaskSettings(@Nullable String user) { - - public static final AzureOpenAiCompletionRequestTaskSettings EMPTY_SETTINGS = new AzureOpenAiCompletionRequestTaskSettings(null); - - public static AzureOpenAiCompletionRequestTaskSettings fromMap(Map map) { - if (map.isEmpty()) { - return AzureOpenAiCompletionRequestTaskSettings.EMPTY_SETTINGS; - } - - ValidationException validationException = new ValidationException(); - - String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException); - - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - return new AzureOpenAiCompletionRequestTaskSettings(user); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettings.java index 4f48f3b1ff21f..9567cbaca8b13 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettings.java @@ -7,112 +7,56 @@ package org.elasticsearch.xpack.inference.services.azureopenai.completion; -import org.elasticsearch.TransportVersion; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.TaskSettings; -import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.common.parser.Headers; +import org.elasticsearch.xpack.inference.common.parser.StatefulValue; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiTaskSettings; import java.io.IOException; -import java.util.HashMap; import java.util.Map; -import java.util.Objects; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; - -public class AzureOpenAiCompletionTaskSettings implements TaskSettings { +public class AzureOpenAiCompletionTaskSettings extends AzureOpenAiTaskSettings { public static final String NAME = "azure_openai_completion_task_settings"; - public static final String USER = "user"; - - public static AzureOpenAiCompletionTaskSettings fromMap(Map map) { - ValidationException validationException = new ValidationException(); - - String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException); - - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; + private static final AzureOpenAiTaskSettings.Factory FACTORY = new Factory<>() { + @Override + public AzureOpenAiCompletionTaskSettings create(CommonSettings commonSettings) { + return new AzureOpenAiCompletionTaskSettings(commonSettings); } - return new AzureOpenAiCompletionTaskSettings(user); - } - - private final String user; + @Override + protected AzureOpenAiCompletionTaskSettings createEmptyInstance() { + return new AzureOpenAiCompletionTaskSettings(); + } + }; - public static AzureOpenAiCompletionTaskSettings of( - AzureOpenAiCompletionTaskSettings originalSettings, - AzureOpenAiCompletionRequestTaskSettings requestSettings - ) { - var userToUse = requestSettings.user() == null ? originalSettings.user : requestSettings.user(); - return new AzureOpenAiCompletionTaskSettings(userToUse); - } + public static final AzureOpenAiCompletionTaskSettings EMPTY = FACTORY.emptySettings(); - public AzureOpenAiCompletionTaskSettings(@Nullable String user) { - this.user = user; + public static AzureOpenAiCompletionTaskSettings fromMap(Map map, ConfigurationParseContext context) { + return AzureOpenAiTaskSettings.parseSettingsFromMap(map, context, FACTORY); } - public AzureOpenAiCompletionTaskSettings(StreamInput in) throws IOException { - this.user = in.readOptionalString(); + private AzureOpenAiCompletionTaskSettings() { + super(null, null, FACTORY); } - @Override - public boolean isEmpty() { - return user == null; + private AzureOpenAiCompletionTaskSettings(CommonSettings commonSettings) { + super(commonSettings, FACTORY); } - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - { - if (user != null) { - builder.field(USER, user); - } - } - builder.endObject(); - return builder; + // Default for testing + AzureOpenAiCompletionTaskSettings(StatefulValue user, Headers headers) { + this(new CommonSettings(user, headers)); } - public String user() { - return user; + public AzureOpenAiCompletionTaskSettings(StreamInput in) throws IOException { + super(in, FACTORY); } @Override public String getWriteableName() { return NAME; } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersion.minimumCompatible(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalString(user); - } - - @Override - public boolean equals(Object object) { - if (this == object) return true; - if (object == null || getClass() != object.getClass()) return false; - AzureOpenAiCompletionTaskSettings that = (AzureOpenAiCompletionTaskSettings) object; - return Objects.equals(user, that.user); - } - - @Override - public int hashCode() { - return Objects.hash(user); - } - - @Override - public TaskSettings updatedTaskSettings(Map newSettings) { - AzureOpenAiCompletionRequestTaskSettings updatedSettings = AzureOpenAiCompletionRequestTaskSettings.fromMap( - new HashMap<>(newSettings) - ); - return of(this, updatedSettings); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModel.java index ab87d1a8bb222..0454dd84daa04 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModel.java @@ -30,8 +30,7 @@ public static AzureOpenAiEmbeddingsModel of(AzureOpenAiEmbeddingsModel model, Ma return model; } - var requestTaskSettings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(taskSettings); - return new AzureOpenAiEmbeddingsModel(model, AzureOpenAiEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + return new AzureOpenAiEmbeddingsModel(model, model.getTaskSettings().updatedTaskSettings(taskSettings)); } public AzureOpenAiEmbeddingsModel( @@ -49,7 +48,7 @@ public AzureOpenAiEmbeddingsModel( taskType, service, AzureOpenAiEmbeddingsServiceSettings.fromMap(serviceSettings, context), - AzureOpenAiEmbeddingsTaskSettings.fromMap(taskSettings), + AzureOpenAiEmbeddingsTaskSettings.fromMap(taskSettings, context), chunkingSettings, AzureOpenAiSecretSettings.fromMap(secrets) ); @@ -98,11 +97,6 @@ public AzureOpenAiEmbeddingsTaskSettings getTaskSettings() { return (AzureOpenAiEmbeddingsTaskSettings) super.getTaskSettings(); } - @Override - public AzureOpenAiSecretSettings getSecretSettings() { - return (AzureOpenAiSecretSettings) super.getSecretSettings(); - } - @Override public ExecutableAction accept(AzureOpenAiActionVisitor creator, Map taskSettings) { return creator.create(this, taskSettings); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettings.java deleted file mode 100644 index ffb8c844ac89f..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettings.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; - -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.ModelConfigurations; - -import java.util.Map; - -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; -import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.USER; - -/** - * This class handles extracting Azure OpenAI task settings from a request. The difference between this class and - * {@link AzureOpenAiEmbeddingsTaskSettings} is that this class considers all fields as optional. It will not throw an error if a field - * is missing. This allows overriding persistent task settings. - * @param user a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse - */ -public record AzureOpenAiEmbeddingsRequestTaskSettings(@Nullable String user) { - - public static final AzureOpenAiEmbeddingsRequestTaskSettings EMPTY_SETTINGS = new AzureOpenAiEmbeddingsRequestTaskSettings(null); - - /** - * Extracts the task settings from a map. All settings are considered optional and the absence of a setting - * does not throw an error. - * - * @param map the settings received from a request - * @return a {@link AzureOpenAiEmbeddingsRequestTaskSettings} - */ - public static AzureOpenAiEmbeddingsRequestTaskSettings fromMap(Map map) { - if (map.isEmpty()) { - return AzureOpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS; - } - - ValidationException validationException = new ValidationException(); - - String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException); - - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - return new AzureOpenAiEmbeddingsRequestTaskSettings(user); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java index b9e2113ad7171..fa9dfa6be4ba9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java @@ -7,122 +7,62 @@ package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; -import org.elasticsearch.TransportVersion; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.TaskSettings; -import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.common.parser.Headers; +import org.elasticsearch.xpack.inference.common.parser.StatefulValue; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiTaskSettings; import java.io.IOException; -import java.util.HashMap; import java.util.Map; -import java.util.Objects; - -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; -import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.USER; /** - * Defines the task settings for the openai service. + * Defines the task settings for the Azure OpenAI embeddings service. *

- * User is an optional unique identifier representing the end-user, which can help OpenAI to monitor and detect abuse - * see the openai docs for more details + * User is an optional unique identifier representing the end-user, which can help OpenAI to monitor and detect abuse. + * Headers are optional custom HTTP headers to send with the request. */ -public class AzureOpenAiEmbeddingsTaskSettings implements TaskSettings { +public class AzureOpenAiEmbeddingsTaskSettings extends AzureOpenAiTaskSettings { public static final String NAME = "azure_openai_embeddings_task_settings"; - public static AzureOpenAiEmbeddingsTaskSettings fromMap(Map map) { - ValidationException validationException = new ValidationException(); + private static final AzureOpenAiTaskSettings.Factory FACTORY = new Factory<>() { + @Override + public AzureOpenAiEmbeddingsTaskSettings create(CommonSettings commonSettings) { + return new AzureOpenAiEmbeddingsTaskSettings(commonSettings); + } - String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException); - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; + @Override + protected AzureOpenAiEmbeddingsTaskSettings createEmptyInstance() { + return new AzureOpenAiEmbeddingsTaskSettings(); } + }; - return new AzureOpenAiEmbeddingsTaskSettings(user); - } + public static final AzureOpenAiEmbeddingsTaskSettings EMPTY = FACTORY.emptySettings(); - /** - * Creates a new {@link AzureOpenAiEmbeddingsTaskSettings} object by overriding the values in originalSettings with the ones - * passed in via requestSettings if the fields are not null. - * - * @param originalSettings the original {@link AzureOpenAiEmbeddingsTaskSettings} from the inference entity configuration from storage - * @param requestSettings the {@link AzureOpenAiEmbeddingsTaskSettings} from the request - * @return a new {@link AzureOpenAiEmbeddingsTaskSettings} - */ - public static AzureOpenAiEmbeddingsTaskSettings of( - AzureOpenAiEmbeddingsTaskSettings originalSettings, - AzureOpenAiEmbeddingsRequestTaskSettings requestSettings - ) { - var userToUse = requestSettings.user() == null ? originalSettings.user : requestSettings.user(); - return new AzureOpenAiEmbeddingsTaskSettings(userToUse); + public static AzureOpenAiEmbeddingsTaskSettings fromMap(Map map, ConfigurationParseContext context) { + return AzureOpenAiTaskSettings.parseSettingsFromMap(map, context, FACTORY); } - private final String user; - - public AzureOpenAiEmbeddingsTaskSettings(@Nullable String user) { - this.user = user; + private AzureOpenAiEmbeddingsTaskSettings() { + super(null, null, FACTORY); } - public AzureOpenAiEmbeddingsTaskSettings(StreamInput in) throws IOException { - this.user = in.readOptionalString(); + private AzureOpenAiEmbeddingsTaskSettings(CommonSettings commonSettings) { + super(commonSettings, FACTORY); } - @Override - public boolean isEmpty() { - return user == null || user.isEmpty(); + // Default for testing + AzureOpenAiEmbeddingsTaskSettings(StatefulValue user, Headers headers) { + this(new CommonSettings(user, headers)); } - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - if (user != null) { - builder.field(USER, user); - } - builder.endObject(); - return builder; - } - - public String user() { - return user; + public AzureOpenAiEmbeddingsTaskSettings(StreamInput in) throws IOException { + super(in, FACTORY); } @Override public String getWriteableName() { return NAME; } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersion.minimumCompatible(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalString(user); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - AzureOpenAiEmbeddingsTaskSettings that = (AzureOpenAiEmbeddingsTaskSettings) o; - return Objects.equals(user, that.user); - } - - @Override - public int hashCode() { - return Objects.hash(user); - } - - @Override - public TaskSettings updatedTaskSettings(Map newSettings) { - AzureOpenAiEmbeddingsRequestTaskSettings requestSettings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap( - new HashMap<>(newSettings) - ); - return of(this, requestSettings); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiChatCompletionRequest.java index 49b3a0c0e7d61..b14d55351951e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiChatCompletionRequest.java @@ -7,66 +7,28 @@ package org.elasticsearch.xpack.inference.services.azureopenai.request; -import org.apache.http.client.methods.HttpPost; -import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; -import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel; -import java.net.URI; -import java.nio.charset.StandardCharsets; import java.util.Objects; -public class AzureOpenAiChatCompletionRequest implements AzureOpenAiRequest { +public class AzureOpenAiChatCompletionRequest extends AzureOpenAiRequest { private final UnifiedChatInput chatInput; - private final AzureOpenAiCompletionModel model; - public AzureOpenAiChatCompletionRequest(UnifiedChatInput chatInput, AzureOpenAiCompletionModel model) { + super(Objects.requireNonNull(model), model.getTaskSettings(), createRequestEntity(Objects.requireNonNull(chatInput), model)); this.chatInput = chatInput; - this.model = Objects.requireNonNull(model); } - @Override - public HttpRequest createHttpRequest() { - var httpPost = new HttpPost(getURI()); - var requestEntity = Strings.toString(new AzureOpenAiChatCompletionRequestEntity(chatInput, model.getTaskSettings().user())); - - ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); - httpPost.setEntity(byteEntity); - - AzureOpenAiRequest.decorateWithAuthHeader(httpPost, model.getSecretSettings()); - - return new HttpRequest(httpPost, getInferenceEntityId()); - } - - @Override - public URI getURI() { - return model.getUri(); - } - - @Override - public String getInferenceEntityId() { - return model.getInferenceEntityId(); + private static String createRequestEntity(UnifiedChatInput chatInput, AzureOpenAiCompletionModel model) { + var user = model.getTaskSettings().user().orElse(null); + return Strings.toString(new AzureOpenAiChatCompletionRequestEntity(chatInput, user)); } @Override public boolean isStreaming() { return chatInput.stream(); } - - @Override - public Request truncate() { - // No truncation for Azure OpenAI completion - return this; - } - - @Override - public boolean[] getTruncationInfo() { - // No truncation for Azure OpenAI completion - return null; - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiCompletionRequest.java index 254993f6d9ef9..fa2b45d1de052 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiCompletionRequest.java @@ -7,72 +7,28 @@ package org.elasticsearch.xpack.inference.services.azureopenai.request; -import org.apache.http.client.methods.HttpPost; -import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; -import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModel; -import java.net.URI; -import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Objects; -public class AzureOpenAiCompletionRequest implements AzureOpenAiRequest { - - private final List input; - - private final URI uri; - - private final AzureOpenAiCompletionModel model; +public class AzureOpenAiCompletionRequest extends AzureOpenAiRequest { private final boolean stream; public AzureOpenAiCompletionRequest(List input, AzureOpenAiCompletionModel model, boolean stream) { - this.input = input; - this.model = Objects.requireNonNull(model); - this.uri = model.getUri(); + super(Objects.requireNonNull(model), model.getTaskSettings(), createRequestEntity(Objects.requireNonNull(input), model, stream)); this.stream = stream; } - @Override - public HttpRequest createHttpRequest() { - var httpPost = new HttpPost(uri); - var requestEntity = Strings.toString(new AzureOpenAiCompletionRequestEntity(input, model.getTaskSettings().user(), isStreaming())); - - ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); - httpPost.setEntity(byteEntity); - - AzureOpenAiRequest.decorateWithAuthHeader(httpPost, model.getSecretSettings()); - - return new HttpRequest(httpPost, getInferenceEntityId()); - } - - @Override - public URI getURI() { - return this.uri; - } - - @Override - public String getInferenceEntityId() { - return model.getInferenceEntityId(); + private static String createRequestEntity(List input, AzureOpenAiCompletionModel model, boolean stream) { + var user = model.getTaskSettings().user().orElse(null); + return Strings.toString(new AzureOpenAiCompletionRequestEntity(input, user, stream)); } @Override public boolean isStreaming() { return stream; } - - @Override - public Request truncate() { - // No truncation for Azure OpenAI completion - return this; - } - - @Override - public boolean[] getTruncationInfo() { - // No truncation for Azure OpenAI completion - return null; - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiEmbeddingsRequest.java index 291b4791bb8c1..930c00f89e0df 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiEmbeddingsRequest.java @@ -7,26 +7,19 @@ package org.elasticsearch.xpack.inference.services.azureopenai.request; -import org.apache.http.client.methods.HttpPost; -import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; import org.elasticsearch.inference.InputType; import org.elasticsearch.xpack.inference.common.Truncator; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; -import java.net.URI; -import java.nio.charset.StandardCharsets; import java.util.Objects; -public class AzureOpenAiEmbeddingsRequest implements AzureOpenAiRequest { +public class AzureOpenAiEmbeddingsRequest extends AzureOpenAiRequest { private final Truncator truncator; private final Truncator.TruncationResult truncationResult; private final InputType inputType; - private final URI uri; - private final AzureOpenAiEmbeddingsModel model; public AzureOpenAiEmbeddingsRequest( Truncator truncator, @@ -34,48 +27,27 @@ public AzureOpenAiEmbeddingsRequest( InputType inputType, AzureOpenAiEmbeddingsModel model ) { + super(Objects.requireNonNull(model), model.getTaskSettings(), createRequestEntity(input, inputType, model)); this.truncator = Objects.requireNonNull(truncator); this.truncationResult = Objects.requireNonNull(input); this.inputType = inputType; - this.model = Objects.requireNonNull(model); - this.uri = model.getUri(); } - public HttpRequest createHttpRequest() { - HttpPost httpPost = new HttpPost(uri); - - String requestEntity = Strings.toString( + private static String createRequestEntity(Truncator.TruncationResult input, InputType inputType, AzureOpenAiEmbeddingsModel model) { + return Strings.toString( new AzureOpenAiEmbeddingsRequestEntity( - truncationResult.input(), + input.input(), inputType, - model.getTaskSettings().user(), + model.getTaskSettings().user().orElse(null), model.getServiceSettings().dimensions(), model.getServiceSettings().dimensionsSetByUser() ) ); - - ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); - httpPost.setEntity(byteEntity); - - AzureOpenAiRequest.decorateWithAuthHeader(httpPost, model.getSecretSettings()); - - return new HttpRequest(httpPost, getInferenceEntityId()); - } - - @Override - public URI getURI() { - return this.uri; - } - - @Override - public String getInferenceEntityId() { - return model.getInferenceEntityId(); } @Override public Request truncate() { var truncatedInput = truncator.truncate(truncationResult.input()); - return new AzureOpenAiEmbeddingsRequest(truncator, truncatedInput, inputType, model); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiEmbeddingsRequestEntity.java index 2f75a6da6964b..45aac6e48a299 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiEmbeddingsRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiEmbeddingsRequestEntity.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.azureopenai.request; +import org.elasticsearch.common.Strings; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InputType; import org.elasticsearch.xcontent.ToXContentObject; @@ -38,7 +39,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field(INPUT_FIELD, input); - if (user != null) { + if (Strings.isNullOrEmpty(user) == false) { builder.field(USER_FIELD, user); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiRequest.java index 65da3def83d81..bf402337d7a2e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/request/AzureOpenAiRequest.java @@ -9,23 +9,61 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; import org.apache.http.message.BasicHeader; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiModel; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiTaskSettings; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.API_KEY; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.ENTRA_ID; import static org.elasticsearch.xpack.inference.services.azureopenai.request.AzureOpenAiUtils.API_KEY_HEADER; -public interface AzureOpenAiRequest extends Request { +public abstract class AzureOpenAiRequest implements Request { - String MISSING_AUTHENTICATION_ERROR_MESSAGE = + public static final String MISSING_AUTHENTICATION_ERROR_MESSAGE = "The request does not have any authentication methods set. One of [%s] or [%s] is required."; + protected final M model; + private final AzureOpenAiTaskSettings taskSettings; + private final String requestEntity; + + protected AzureOpenAiRequest(M model, AzureOpenAiTaskSettings taskSettings, String requestEntity) { + this.model = Objects.requireNonNull(model); + this.taskSettings = Objects.requireNonNull(taskSettings); + this.requestEntity = Objects.requireNonNull(requestEntity); + } + + @Override + public HttpRequest createHttpRequest() { + var httpPost = new HttpPost(getURI()); + + ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + + decorateWithAuthHeader(httpPost, model.getSecretSettings()); + + var headers = taskSettings.headers(); + if (headers.mapValue().isPresent()) { + for (var entry : headers.mapValue().get().entrySet()) { + httpPost.setHeader(entry.getKey(), entry.getValue()); + } + } + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + // Default for testing static void decorateWithAuthHeader(HttpPost httpPost, AzureOpenAiSecretSettings secretSettings) { httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); @@ -43,4 +81,26 @@ static void decorateWithAuthHeader(HttpPost httpPost, AzureOpenAiSecretSettings throw validationException; } } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public URI getURI() { + return model.getUri(); + } + + @Override + public Request truncate() { + // Default implementation: no truncation. Subclasses may override to apply truncation if needed. + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // Default implementation: no truncation was applied, so no truncation info is available. + return null; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/parser/HeadersTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/parser/HeadersTests.java new file mode 100644 index 0000000000000..fb3d5541f6144 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/parser/HeadersTests.java @@ -0,0 +1,267 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.common.parser; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class HeadersTests extends AbstractBWCWireSerializationTestCase { + + public static Headers createRandom() { + return randomFrom( + Headers.UNDEFINED_INSTANCE, + Headers.NULL_INSTANCE, + new Headers(StatefulValue.of(Map.of(randomAlphaOfLength(15), randomAlphaOfLength(15)))) + ); + } + + public static Headers createRandomNonNull() { + return randomFrom( + Headers.UNDEFINED_INSTANCE, + new Headers(StatefulValue.of(Map.of(randomAlphaOfLength(15), randomAlphaOfLength(15)))) + ); + } + + @Override + protected Writeable.Reader instanceReader() { + return Headers::new; + } + + @Override + protected Headers createTestInstance() { + return createRandom(); + } + + @Override + protected Headers mutateInstance(Headers instance) throws IOException { + return doMutateInstance(instance); + } + + public static Headers doMutateInstance(Headers instance) { + var statefulValue = instance.mapValue(); + if (statefulValue.isPresent()) { + var newHeaders = new HashMap<>(statefulValue.get()); + newHeaders.put(randomAlphaOfLength(15), randomAlphaOfLength(15)); + var withNewKey = new Headers(StatefulValue.of(newHeaders)); + return randomFrom(withNewKey, Headers.NULL_INSTANCE, Headers.UNDEFINED_INSTANCE); + } + if (statefulValue.isNull()) { + var withValue = new Headers(StatefulValue.of(Map.of(randomAlphaOfLength(15), randomAlphaOfLength(15)))); + return randomFrom(withValue, Headers.UNDEFINED_INSTANCE); + } + var withValue = new Headers(StatefulValue.of(Map.of(randomAlphaOfLength(15), randomAlphaOfLength(15)))); + return randomFrom(withValue, Headers.NULL_INSTANCE); + } + + @Override + protected Headers mutateInstanceForVersion(Headers instance, TransportVersion version) { + return instance; + } + + private static String toXContentString(Headers headers) throws IOException { + var builder = XContentFactory.contentBuilder(XContentType.JSON); + builder.startObject(); + headers.toXContent(builder, null); + builder.endObject(); + return Strings.toString(builder); + } + + public void testConstructor_WhenNull_ThrowsNullPointerException() { + expectThrows(NullPointerException.class, () -> new Headers((StatefulValue>) null)); + } + + public void testIsPresent_WhenAbsent() { + assertFalse(Headers.UNDEFINED_INSTANCE.isPresent()); + } + + public void testIsPresent_WhenNull() { + assertFalse(Headers.NULL_INSTANCE.isPresent()); + } + + public void testState_WhenWithValue() { + var headers = new Headers(StatefulValue.of(Map.of("k", "v"))); + assertTrue(headers.isPresent()); + assertFalse(headers.isNull()); + assertFalse(headers.isEmpty()); + } + + public void testIsNull_WhenAbsent() { + assertFalse(Headers.UNDEFINED_INSTANCE.isNull()); + } + + public void testIsNull_WhenNull() { + assertTrue(Headers.NULL_INSTANCE.isNull()); + } + + public void testIsNull_WhenWithValue() { + assertFalse(new Headers(StatefulValue.of(Map.of("k", "v"))).isNull()); + } + + public void testIsEmpty_WhenAbsent() { + assertTrue(Headers.UNDEFINED_INSTANCE.isEmpty()); + } + + public void testIsEmpty_WhenNull() { + assertTrue(Headers.NULL_INSTANCE.isEmpty()); + } + + public void testIsEmpty_WhenPresentWithEmptyMap() { + var headers = new Headers(StatefulValue.of(Map.of())); + assertTrue(headers.isEmpty()); + assertTrue(headers.isPresent()); + } + + public void testIsEmpty_WhenPresentWithEntries() { + assertFalse(new Headers(StatefulValue.of(Map.of("k", "v"))).isEmpty()); + } + + public void testToXContent_WhenEmptyMap() throws IOException { + var headers = new Headers(StatefulValue.of(Map.of())); + assertThat(toXContentString(headers), is(XContentHelper.stripWhitespace(""" + {} + """))); + } + + public void testToXContent_WhenWithEntries() throws IOException { + var headerMap = Map.of("key", "value"); + var headers = new Headers(StatefulValue.of(headerMap)); + assertThat(toXContentString(headers), is(XContentHelper.stripWhitespace(""" + { + "headers": { + "key": "value" + } + } + """))); + } + + public void testParse_WithHeaders() throws IOException { + var json = """ + { + "headers": { + "key": "value" + } + } + """; + parseJson(json, parsed -> { + assertTrue(parsed.mapValue().isPresent()); + assertThat(parsed.mapValue().get(), is(Map.of("key", "value"))); + }); + } + + public void testParse_WhenHeadersMissing_ReturnsNullHeaders() throws IOException { + var json = """ + { + } + """; + parseJson(json, parsed -> assertThat(parsed, sameInstance(Headers.UNDEFINED_INSTANCE))); + } + + public void testParse_WhenHeadersEmptyMap() throws IOException { + var json = """ + { + "headers": {} + } + """; + parseJson(json, parsed -> assertThat(parsed, sameInstance(Headers.NULL_INSTANCE))); + } + + public void testParse_WhenHeadersIsSetToNull() throws IOException { + var json = """ + { + "headers": null + } + """; + parseJson(json, parsed -> assertThat(parsed, sameInstance(Headers.NULL_INSTANCE))); + } + + public void testParse_ThrowsWhenValueNotString() { + var json = """ + { + "headers": { + "key": 1 + } + } + """; + var exception = expectThrows(XContentParseException.class, () -> parseJson(json, parsed -> {})); + assertThat(exception.getMessage(), containsString("[headers_parser] failed to parse field [headers]")); + assertThat( + exception.getCause().getMessage(), + containsString( + "Map field [headers] has an entry that is not valid, [key => 1]. Value type of [Integer] is not one of [String].;" + ) + ); + } + + public void testParse_ThrowsWhenValueIsAnObject() { + var json = """ + { + "headers": { + "key": {} + } + } + """; + var exception = expectThrows(XContentParseException.class, () -> parseJson(json, parsed -> {})); + assertThat(exception.getMessage(), containsString("[headers_parser] failed to parse field [headers]")); + assertThat( + exception.getCause().getMessage(), + containsString("Map field [headers] has an entry that is not valid, [key => {}]. Value type of [Map] is not one of [String].;") + ); + } + + public void testParse_Roundtrip() throws IOException { + // The reason we don't allow null here is that when a Headers::NULL_INSTANCE is serialized to xContent + // it is not written (aka would look like this {}) instead of it being written {"headers": null}. + // This is because it's only used for the update API to indicate that the existing headers should be removed. + var original = createRandomNonNull(); + var json = toXContentString(original); + parseJson(json, parsedHeaders -> assertThat(parsedHeaders, is(original))); + } + + public void testParse_RoundtripNull() throws IOException { + // When a null headers is serialized to xContent, it is not written at all + // (aka would look like this {}) instead of it being written {"headers": null}. This is because it's only used for + // the update API to indicate that the existing headers should be removed. + var json = toXContentString(Headers.NULL_INSTANCE); + parseJson(json, parsedHeaders -> assertThat(parsedHeaders, is(Headers.UNDEFINED_INSTANCE))); + } + + private static void parseJson(String jsonInput, Consumer assertCallback) throws IOException { + ConstructingObjectParser constructingObjectParser = new ConstructingObjectParser<>( + "headers_parser", + false, + args -> Headers.create(args[0], "root") + ); + Headers.initParser(constructingObjectParser); + + try ( + XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, jsonInput) + ) { + parser.nextToken(); + var parsed = constructingObjectParser.parse(parser, null); + assertCallback.accept(parsed); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/parser/StatefulValueTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/parser/StatefulValueTests.java new file mode 100644 index 0000000000000..12e5fe482eea2 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/parser/StatefulValueTests.java @@ -0,0 +1,166 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.common.parser; + +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.NoSuchElementException; + +import static org.elasticsearch.xpack.inference.common.parser.StatefulValue.NO_VALUE_PRESENT; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class StatefulValueTests extends ESTestCase { + + private static final String VALUE = "value"; + + public void testUndefined_ReturnsSingleton() { + assertThat(StatefulValue.undefined(), sameInstance(StatefulValue.undefined())); + } + + public void testNullInstance_ReturnsSingleton() { + assertThat(StatefulValue.nullInstance(), sameInstance(StatefulValue.nullInstance())); + } + + public void testOf_ThrowsWhenNull() { + expectThrows(NullPointerException.class, () -> StatefulValue.of((String) null)); + } + + public void testOf_ReturnsNewInstanceWithValue() { + var value = randomAlphaOfLength(10); + var statefulValue = StatefulValue.of(value); + assertTrue(statefulValue.isPresent()); + assertThat(statefulValue.get(), is(value)); + } + + public void testIsUndefined_WhenUndefined() { + assertTrue(StatefulValue.undefined().isUndefined()); + } + + public void testIsUndefined_WhenNull() { + assertFalse(StatefulValue.nullInstance().isUndefined()); + } + + public void testIsUndefined_WhenWithValue() { + assertFalse(StatefulValue.of(VALUE).isUndefined()); + } + + public void testIsNull_WhenUndefined() { + assertFalse(StatefulValue.undefined().isNull()); + } + + public void testIsNull_WhenNull() { + assertTrue(StatefulValue.nullInstance().isNull()); + } + + public void testIsNull_WhenWithValue() { + assertFalse(StatefulValue.of(VALUE).isNull()); + } + + public void testIsPresent_WhenUndefined() { + assertFalse(StatefulValue.undefined().isPresent()); + } + + public void testIsPresent_WhenNull() { + assertFalse(StatefulValue.nullInstance().isPresent()); + } + + public void testIsPresent_WhenWithValue() { + assertTrue(StatefulValue.of(VALUE).isPresent()); + } + + public void testGet_ReturnsValueWhenPresent() { + var value = randomAlphaOfLength(10); + assertThat(StatefulValue.of(value).get(), is(value)); + } + + public void testGet_ThrowsWhenUndefined() { + var e = expectThrows(NoSuchElementException.class, () -> StatefulValue.undefined().get()); + assertThat(e.getMessage(), is(NO_VALUE_PRESENT.getMessage())); + } + + public void testGet_ThrowsWhenNull() { + var e = expectThrows(NoSuchElementException.class, () -> StatefulValue.nullInstance().get()); + assertThat(e.getMessage(), is(NO_VALUE_PRESENT.getMessage())); + } + + public void testOrElse_ReturnsValueWhenPresent() { + var value = randomAlphaOfLength(10); + var other = randomAlphaOfLength(10); + assertThat(StatefulValue.of(value).orElse(other), is(value)); + } + + public void testOrElse_ReturnsOtherWhenUndefined() { + var other = randomAlphaOfLength(10); + assertThat(StatefulValue.undefined().orElse(other), is(other)); + } + + public void testOrElse_ReturnsOtherWhenNull() { + var other = randomAlphaOfLength(10); + assertThat(StatefulValue.nullInstance().orElse(other), is(other)); + } + + public void testEquals_hashCode() { + { + var map = new HashMap<>(Map.of(randomAlphaOfLength(10), randomAlphaOfLength(10))); + var equalMap = new HashMap<>(map); + assertEquals(StatefulValue.of(map), StatefulValue.of(equalMap)); + assertEquals(StatefulValue.of(map).hashCode(), StatefulValue.of(equalMap).hashCode()); + } + { + assertEquals(StatefulValue.undefined().hashCode(), StatefulValue.undefined().hashCode()); + assertEquals(StatefulValue.nullInstance().hashCode(), StatefulValue.nullInstance().hashCode()); + } + } + + public void testEquals_WhenPresentWithDifferentValues() { + assertNotEquals(StatefulValue.of("a"), StatefulValue.of("b")); + } + + public void testEquals_WhenDifferentStates() { + assertNotEquals(StatefulValue.undefined(), StatefulValue.nullInstance()); + assertNotEquals(StatefulValue.undefined(), StatefulValue.of(VALUE)); + assertNotEquals(StatefulValue.nullInstance(), StatefulValue.of(VALUE)); + } + + public void testSerializationRoundtrip_WhenUndefined() throws IOException { + var original = StatefulValue.undefined(); + var copy = roundtrip(original); + assertThat(copy, sameInstance(original)); + } + + public void testSerializationRoundtrip_WhenNull() throws IOException { + var original = StatefulValue.nullInstance(); + var copy = roundtrip(original); + assertThat(copy, sameInstance(original)); + } + + public void testSerializationRoundtrip_WhenPresent() throws IOException { + var value = randomAlphaOfLength(10); + var original = StatefulValue.of(value); + var copy = roundtrip(original); + assertThat(copy, is(original)); + assertTrue(copy.isPresent()); + assertThat(copy.get(), is(value)); + } + + private static StatefulValue roundtrip(StatefulValue original) throws IOException { + try (BytesStreamOutput out = new BytesStreamOutput()) { + StatefulValue.write(out, original, StreamOutput::writeString); + try (StreamInput in = out.bytes().streamInput()) { + return StatefulValue.read(in, StreamInput::readString); + } + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java index 21af8f7199cc9..30f45216b86e0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java @@ -898,7 +898,7 @@ public void testValidateMapValues_ThrowsException_WhenMapContainsInvalidTypes() exception.getMessage(), is( "Validation Failed: 1: Map field [setting] has an entry that is not valid, " - + "[num_key => 1]. Value type of [1] is not one of [String].;" + + "[num_key => 1]. Value type of [Integer] is not one of [String].;" ) ); } @@ -958,7 +958,7 @@ public void testValidateMapStringValues_ThrowsException_WhenMapContainsInvalidTy exception.getMessage(), is( "Validation Failed: 1: Map field [setting] has an entry that is not valid, " - + "[num_key => 1]. Value type of [1] is not one of [String].;" + + "[num_key => 1]. Value type of [Integer] is not one of [String].;" ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index ee7b66ecf7ce9..050088ae4987c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -44,6 +44,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParseException; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; @@ -90,10 +91,10 @@ import static org.elasticsearch.xpack.inference.services.SenderServiceTests.createMockSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettingsTests.getAzureOpenAiSecretSettingsMap; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiTaskSettingsTests.createRequestTaskSettingsMap; import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createChatCompletionModel; import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettingsTests.getPersistentAzureOpenAiServiceSettingsMap; import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettingsTests.getRequestAzureOpenAiServiceSettingsMap; -import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettingsTests.getAzureOpenAiRequestTaskSettingsMap; import static org.elasticsearch.xpack.inference.services.azureopenai.request.AzureOpenAiUtils.API_KEY_HEADER; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; @@ -147,7 +148,7 @@ public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModel() throws IOExc assertThat(embeddingsModel.getServiceSettings().deploymentId(), is(DEPLOYMENT_ID_VALUE)); assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); }, exception -> fail("Unexpected exception: " + exception)); service.parseRequestConfig( @@ -155,7 +156,7 @@ public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModel() throws IOExc TaskType.TEXT_EMBEDDING, getRequestConfigMap( getRequestAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, null, null), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ), modelVerificationListener @@ -173,7 +174,7 @@ public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSet assertThat(embeddingsModel.getServiceSettings().deploymentId(), is(DEPLOYMENT_ID_VALUE)); assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); }, exception -> fail("Unexpected exception: " + exception)); @@ -182,7 +183,7 @@ public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSet TaskType.TEXT_EMBEDDING, getRequestConfigMap( getRequestAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, null, null), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), createRandomChunkingSettingsMap(), getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ), @@ -201,7 +202,7 @@ public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSet assertThat(embeddingsModel.getServiceSettings().deploymentId(), is(DEPLOYMENT_ID_VALUE)); assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); }, exception -> fail("Unexpected exception: " + exception)); @@ -210,7 +211,7 @@ public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingSet TaskType.TEXT_EMBEDDING, getRequestConfigMap( getRequestAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, null, null), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ), modelVerificationListener @@ -233,7 +234,7 @@ public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOExcepti TaskType.SPARSE_EMBEDDING, getRequestConfigMap( getRequestAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, null, null), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ), modelVerificationListener @@ -245,7 +246,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I try (var service = createAzureOpenAiService()) { var config = getRequestConfigMap( getRequestAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, null, null), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ); config.put("extra_key", "value"); @@ -278,7 +279,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa var config = getRequestConfigMap( serviceSettings, - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ); @@ -295,7 +296,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { try (var service = createAzureOpenAiService()) { - var taskSettingsMap = getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE); + var taskSettingsMap = createRequestTaskSettingsMap(ROLE_VALUE); taskSettingsMap.put("extra_key", "value"); var config = getRequestConfigMap( @@ -307,8 +308,8 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() ActionListener modelVerificationListener = ActionListener.wrap((model) -> { fail("Expected exception, but got model: " + model); }, e -> { - assertThat(e, instanceOf(ElasticsearchStatusException.class)); - assertThat(e.getMessage(), is("Configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service")); + assertThat(e, instanceOf(XContentParseException.class)); + assertThat(e.getMessage(), containsString("unknown field [extra_key]")); }); service.parseRequestConfig(INFERENCE_ENTITY_ID_VALUE, TaskType.TEXT_EMBEDDING, config, modelVerificationListener); @@ -322,7 +323,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var config = getRequestConfigMap( getRequestAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, null, null), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), secretSettingsMap ); @@ -347,7 +348,7 @@ public void testParseRequestConfig_MovesModel() throws IOException { assertThat(embeddingsModel.getServiceSettings().deploymentId(), is(DEPLOYMENT_ID_VALUE)); assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); }, exception -> fail("Unexpected exception: " + exception)); service.parseRequestConfig( @@ -355,7 +356,7 @@ public void testParseRequestConfig_MovesModel() throws IOException { TaskType.TEXT_EMBEDDING, getRequestConfigMap( getRequestAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, null, null), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ), modelVerificationListener @@ -367,7 +368,7 @@ public void testParsePersistedConfig_WithSecrets_CreatesAnAzureOpenAiEmbeddingsM try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, 100, 512), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ); @@ -389,7 +390,7 @@ public void testParsePersistedConfig_WithSecrets_CreatesAnAzureOpenAiEmbeddingsM assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); } } @@ -398,7 +399,7 @@ public void testParsePersistedConfig_WithSecrets_CreatesAnOpenAiEmbeddingsModelW try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, 100, 512), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), createRandomChunkingSettingsMap(), getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ); @@ -421,7 +422,7 @@ public void testParsePersistedConfig_WithSecrets_CreatesAnOpenAiEmbeddingsModelW assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); } @@ -431,7 +432,7 @@ public void testParsePersistedConfig_WithSecrets_CreatesAnOpenAiEmbeddingsModelW try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, 100, 512), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ); @@ -453,7 +454,7 @@ public void testParsePersistedConfig_WithSecrets_CreatesAnOpenAiEmbeddingsModelW assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); } @@ -463,7 +464,7 @@ public void testParsePersistedConfig_WithSecrets_ThrowsErrorTryingToParseInvalid try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, null, null), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ); @@ -495,7 +496,7 @@ public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExist try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, 100, 512), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ); persistedConfig.config().put("extra_key", "value"); @@ -518,7 +519,7 @@ public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExist assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); } } @@ -530,7 +531,7 @@ public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExist var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, 100, 512), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), secretSettingsMap ); @@ -552,7 +553,7 @@ public void testParsePersistedConfig_WithSecrets_DoesNotThrowWhenAnExtraKeyExist assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); } } @@ -561,7 +562,7 @@ public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInS try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, 100, 512), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ); persistedConfig.secrets().put("extra_key", "value"); @@ -584,7 +585,7 @@ public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInS assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); } } @@ -602,7 +603,7 @@ public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInS var persistedConfig = getPersistedConfigMap( serviceSettingsMap, - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), getAzureOpenAiSecretSettingsMap(API_KEY_VALUE, null) ); @@ -624,14 +625,14 @@ public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInS assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); } } public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { try (var service = createAzureOpenAiService()) { - var taskSettingsMap = getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE); + var taskSettingsMap = createRequestTaskSettingsMap(ROLE_VALUE); taskSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( @@ -658,7 +659,7 @@ public void testParsePersistedConfig_WithSecrets_NotThrowWhenAnExtraKeyExistsInT assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(API_KEY_VALUE)); } } @@ -667,7 +668,7 @@ public void testParsePersistedConfig_CreatesAnAzureOpenAiEmbeddingsModel() throw try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, null, null), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE) + createRequestTaskSettingsMap(ROLE_VALUE) ); var model = service.parsePersistedConfig( @@ -686,7 +687,7 @@ public void testParsePersistedConfig_CreatesAnAzureOpenAiEmbeddingsModel() throw assertThat(embeddingsModel.getServiceSettings().resourceName(), is(RESOURCE_NAME_VALUE)); assertThat(embeddingsModel.getServiceSettings().deploymentId(), is(DEPLOYMENT_ID_VALUE)); assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); assertNull(embeddingsModel.getSecretSettings()); } } @@ -695,7 +696,7 @@ public void testParsePersistedConfig_CreatesAnAzureOpenAiEmbeddingsModelWhenChun try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, null, null), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE), + createRequestTaskSettingsMap(ROLE_VALUE), createRandomChunkingSettingsMap() ); @@ -715,7 +716,7 @@ public void testParsePersistedConfig_CreatesAnAzureOpenAiEmbeddingsModelWhenChun assertThat(embeddingsModel.getServiceSettings().resourceName(), is(RESOURCE_NAME_VALUE)); assertThat(embeddingsModel.getServiceSettings().deploymentId(), is(DEPLOYMENT_ID_VALUE)); assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertNull(embeddingsModel.getSecretSettings()); } @@ -725,7 +726,7 @@ public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingS try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, null, null), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE) + createRequestTaskSettingsMap(ROLE_VALUE) ); var model = service.parsePersistedConfig( @@ -744,7 +745,7 @@ public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWhenChunkingS assertThat(embeddingsModel.getServiceSettings().resourceName(), is(RESOURCE_NAME_VALUE)); assertThat(embeddingsModel.getServiceSettings().deploymentId(), is(DEPLOYMENT_ID_VALUE)); assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); assertNull(embeddingsModel.getSecretSettings()); } @@ -754,7 +755,7 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, null, null), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE) + createRequestTaskSettingsMap(ROLE_VALUE) ); var thrownException = expectThrows( @@ -785,7 +786,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() try (var service = createAzureOpenAiService()) { var persistedConfig = getPersistedConfigMap( getPersistentAzureOpenAiServiceSettingsMap(RESOURCE_NAME_VALUE, DEPLOYMENT_ID_VALUE, API_VERSION_VALUE, null, null), - getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE) + createRequestTaskSettingsMap(ROLE_VALUE) ); persistedConfig.config().put("extra_key", "value"); @@ -805,7 +806,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() assertThat(embeddingsModel.getServiceSettings().resourceName(), is(RESOURCE_NAME_VALUE)); assertThat(embeddingsModel.getServiceSettings().deploymentId(), is(DEPLOYMENT_ID_VALUE)); assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); assertNull(embeddingsModel.getSecretSettings()); } } @@ -821,7 +822,7 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettin ); serviceSettingsMap.put("extra_key", "value"); - var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE)); + var persistedConfig = getPersistedConfigMap(serviceSettingsMap, createRequestTaskSettingsMap(ROLE_VALUE)); var model = service.parsePersistedConfig( new UnparsedModel( @@ -839,14 +840,14 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettin assertThat(embeddingsModel.getServiceSettings().resourceName(), is(RESOURCE_NAME_VALUE)); assertThat(embeddingsModel.getServiceSettings().deploymentId(), is(DEPLOYMENT_ID_VALUE)); assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); assertNull(embeddingsModel.getSecretSettings()); } } public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { try (var service = createAzureOpenAiService()) { - var taskSettingsMap = getAzureOpenAiRequestTaskSettingsMap(ROLE_VALUE); + var taskSettingsMap = createRequestTaskSettingsMap(ROLE_VALUE); taskSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( @@ -870,7 +871,7 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( assertThat(embeddingsModel.getServiceSettings().resourceName(), is(RESOURCE_NAME_VALUE)); assertThat(embeddingsModel.getServiceSettings().deploymentId(), is(DEPLOYMENT_ID_VALUE)); assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(API_VERSION_VALUE)); - assertThat(embeddingsModel.getTaskSettings().user(), is(ROLE_VALUE)); + assertThat(embeddingsModel.getTaskSettings().user().get(), is(ROLE_VALUE)); assertNull(embeddingsModel.getSecretSettings()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiTaskSettingsTests.java new file mode 100644 index 0000000000000..397c4d4c52e09 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiTaskSettingsTests.java @@ -0,0 +1,415 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureopenai; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.common.parser.Headers; +import org.elasticsearch.xpack.inference.common.parser.HeadersTests; +import org.elasticsearch.xpack.inference.common.parser.StatefulValue; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiTaskSettings.INFERENCE_AZURE_OPENAI_TASK_SETTINGS_HEADERS; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public abstract class AzureOpenAiTaskSettingsTests> extends AbstractBWCWireSerializationTestCase { + + private static final String USER = "user"; + private static final StatefulValue STATEFUL_USER = StatefulValue.of(USER); + private static final Map HEADERS_MAP = Map.of("key", "value"); + private static final Headers HEADERS = new Headers(StatefulValue.of(HEADERS_MAP)); + + public T createRandom() { + StatefulValue user = randomFrom( + StatefulValue.undefined(), + StatefulValue.nullInstance(), + StatefulValue.of(randomAlphaOfLength(15)) + ); + var headers = HeadersTests.createRandom(); + return create(user, headers); + } + + public void testIsEmpty() { + var bothNull = create(StatefulValue.nullInstance(), Headers.NULL_INSTANCE); + assertTrue(bothNull.isEmpty()); + + var nullUserEmptyHeaders = create(StatefulValue.nullInstance(), new Headers(StatefulValue.of(Map.of()))); + assertTrue(nullUserEmptyHeaders.isEmpty()); + + var nullHeaders = create(STATEFUL_USER, Headers.NULL_INSTANCE); + assertFalse(nullHeaders.isEmpty()); + + var nullUser = create(StatefulValue.nullInstance(), HEADERS); + assertFalse(nullUser.isEmpty()); + + var neitherNull = create(STATEFUL_USER, HEADERS); + assertFalse(neitherNull.isEmpty()); + + var emptyUserString = create(StatefulValue.of(""), Headers.NULL_INSTANCE); + assertTrue(emptyUserString.isEmpty()); + + var headersNull = create(StatefulValue.nullInstance(), Headers.NULL_INSTANCE); + assertTrue(headersNull.isEmpty()); + + var headersUndefined = create(StatefulValue.nullInstance(), Headers.UNDEFINED_INSTANCE); + assertTrue(headersUndefined.isEmpty()); + } + + public void testUpdatedTaskSettings() { + var initialSettings = createRandom(); + var newSettings = createRandom(); + + Map newSettingsMap = new HashMap<>(); + + if (newSettings.user().isUndefined() == false) { + newSettingsMap.put(AzureOpenAiServiceFields.USER, newSettings.user().orElse(null)); + } + + if (newSettings.headers().mapValue().isUndefined() == false) { + newSettingsMap.put(Headers.HEADERS_FIELD, newSettings.headers().mapValue().orElse(null)); + } + + var updatedSettings = initialSettings.updatedTaskSettings(Collections.unmodifiableMap(newSettingsMap)); + + if (newSettings.user().isPresent()) { + assertEquals(newSettings.user(), updatedSettings.user()); + } else if (newSettings.user().isNull()) { + // When the new settings has a null user, we want to remove the existing user, so the updated settings should now + // have the user as undefined + assertEquals(StatefulValue.undefined(), updatedSettings.user()); + } else { + // If the new settings did not have user, the updated settings should keep the existing user + assertEquals(initialSettings.user(), updatedSettings.user()); + } + + if (newSettings.headers().isPresent()) { + assertEquals(newSettings.headers(), updatedSettings.headers()); + } else if (newSettings.headers().isNull()) { + // When the new settings has a null headers field, we want to remove the existing headers, so the updated settings should now + // have the headers as undefined + assertEquals(Headers.UNDEFINED_INSTANCE, updatedSettings.headers()); + } else { + // If the new settings did not have the headers field, the updated settings should keep the existing headers + assertEquals(initialSettings.headers(), updatedSettings.headers()); + } + } + + public void testUpdatedTaskSettings_ApplyingEmptyHeaders() { + var initialSettings = create(STATEFUL_USER, Headers.NULL_INSTANCE); + Map newSettingsMap = Map.of(Headers.HEADERS_FIELD, Map.of()); + + var updatedSettings = initialSettings.updatedTaskSettings(newSettingsMap); + assertThat(updatedSettings, is(create(STATEFUL_USER, Headers.UNDEFINED_INSTANCE))); + + var initialSettingsDefinedHeaders = create(STATEFUL_USER, HEADERS); + // This will remove the headers because using "headers": {} in the update counts as the user wanting to remove all existing headers + updatedSettings = initialSettingsDefinedHeaders.updatedTaskSettings(newSettingsMap); + assertThat(updatedSettings, is(create(STATEFUL_USER, Headers.UNDEFINED_INSTANCE))); + } + + public void testUpdateTaskSettings_EmptyInstance() { + var initialSettings = create(STATEFUL_USER, HEADERS); + var newSettingsMap = new HashMap(); + newSettingsMap.put(AzureOpenAiServiceFields.USER, null); + newSettingsMap.put(Headers.HEADERS_FIELD, null); + + var updatedSettings = initialSettings.updatedTaskSettings(newSettingsMap); + assertThat(updatedSettings, sameInstance(emptySettings())); + } + + public void testFromMap_WithUserAndHeaders() { + assertThat( + createFromMap( + new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, USER, Headers.HEADERS_FIELD, HEADERS_MAP)), + ConfigurationParseContext.REQUEST + ), + is(create(STATEFUL_USER, HEADERS)) + ); + } + + public void testFromMap_UserIsEmptyString() { + var thrownException = expectThrows( + ValidationException.class, + () -> createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "")), ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + is(Strings.format("Validation Failed: 1: [task_settings] Invalid value empty string. [user] must be a non-empty string;")) + ); + } + + public void testFromMap_UserIsEmptyString_DoesNotThrowForPersistentContext() { + var settings = createFromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "")), ConfigurationParseContext.PERSISTENT); + assertTrue(settings.user().isPresent() && settings.user().get().isEmpty()); + } + + public void testFromMap_isEmpty() { + { + var emptyMap = createFromMap(new HashMap<>(Map.of()), randomContext()); + assertTrue(emptyMap.isEmpty()); + } + { + var emptyUserUndefinedHeaders = createFromMap( + new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "")), + ConfigurationParseContext.PERSISTENT + ); + assertTrue(emptyUserUndefinedHeaders.isEmpty()); + } + { + var emptyUserEmptyHeaders = createFromMap( + new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "", Headers.HEADERS_FIELD, Map.of())), + ConfigurationParseContext.PERSISTENT + ); + assertTrue(emptyUserEmptyHeaders.isEmpty()); + } + { + var emptyUserNullHeadersMap = new HashMap(); + emptyUserNullHeadersMap.put(AzureOpenAiServiceFields.USER, ""); + emptyUserNullHeadersMap.put(Headers.HEADERS_FIELD, null); + var emptyUserNullHeaders = createFromMap(emptyUserNullHeadersMap, ConfigurationParseContext.PERSISTENT); + assertTrue(emptyUserNullHeaders.isEmpty()); + } + { + var undefinedUserEmptyHeaders = createFromMap(new HashMap<>(Map.of(Headers.HEADERS_FIELD, Map.of())), randomContext()); + assertTrue(undefinedUserEmptyHeaders.isEmpty()); + } + { + var nullUserNullHeadersMap = new HashMap(); + nullUserNullHeadersMap.put(AzureOpenAiServiceFields.USER, null); + nullUserNullHeadersMap.put(Headers.HEADERS_FIELD, null); + var emptyUserNullHeaders = createFromMap(nullUserNullHeadersMap, randomContext()); + assertTrue(emptyUserNullHeaders.isEmpty()); + } + } + + private static ConfigurationParseContext randomContext() { + return randomBoolean() ? ConfigurationParseContext.REQUEST : ConfigurationParseContext.PERSISTENT; + } + + public void testFromMap_MissingUser_DoesNotThrowException() { + var taskSettings = createFromMap(new HashMap<>(Map.of()), ConfigurationParseContext.REQUEST); + assertTrue(taskSettings.user().isUndefined()); + } + + public void testFromMap_ReturnsEmptySettings_WhenTheMapDoesNotContainTheFields() { + // The HashMap is missing the headers key + var settings = createFromMap(new HashMap<>(HEADERS_MAP), ConfigurationParseContext.PERSISTENT); + assertTrue(settings.user().isUndefined()); + assertThat(settings.headers(), sameInstance(Headers.UNDEFINED_INSTANCE)); + } + + public void testFromMap_ParsesCorrectly_WhenUserIsMissing() { + var settings = createFromMap( + new HashMap<>(Map.of(Headers.HEADERS_FIELD, new HashMap<>(HEADERS_MAP))), + ConfigurationParseContext.REQUEST + ); + + assertTrue(settings.user().isUndefined()); + assertThat(settings.headers(), is(HEADERS)); + } + + public void testFromMap_ParsesCorrectly_WhenHeadersIsMissing() { + var settings = createFromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, USER)), ConfigurationParseContext.REQUEST); + + assertTrue(settings.user().isPresent()); + assertThat(settings.user().get(), is(USER)); + assertThat(settings.headers(), is(Headers.UNDEFINED_INSTANCE)); + } + + public void testFromMap_ParsesCorrectly_WhenHeadersIsEmptyMap() { + var settings = createFromMap( + new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, USER, Headers.HEADERS_FIELD, Map.of())), + ConfigurationParseContext.REQUEST + ); + + assertTrue(settings.user().isPresent()); + assertThat(settings.user().get(), is(USER)); + assertTrue(settings.headers().isEmpty()); + } + + public void testFromMap_ParsesCorrectly_WhenHeadersMapOfNulls() { + var headersMap = new HashMap(); + headersMap.put("key1", null); + headersMap.put("key2", null); + var settings = createFromMap( + new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, USER, Headers.HEADERS_FIELD, headersMap)), + ConfigurationParseContext.REQUEST + ); + + assertTrue(settings.user().isPresent()); + assertThat(settings.user().get(), is(USER)); + assertTrue(settings.headers().isEmpty()); + } + + public void testFromMap_ThrowsException_WhenHeadersContainsAnInteger() { + var exception = expectThrows( + XContentParseException.class, + () -> createFromMap( + new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, USER, Headers.HEADERS_FIELD, new HashMap<>(Map.of("key", 1)))), + ConfigurationParseContext.REQUEST + ) + ); + + assertThat(exception.getMessage(), containsString("failed to parse field [headers]")); + assertThat( + exception.getCause().getMessage(), + containsString( + "Map field [headers] has an entry that is not valid, [key => 1]. Value type of [Integer] is not one of [String].;" + ) + ); + } + + public void testFromMap_ThrowsException_WhenUserIsAnInteger() { + var exception = expectThrows( + XContentParseException.class, + () -> createFromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, 1)), ConfigurationParseContext.REQUEST) + ); + + assertThat( + exception.getMessage(), + containsString("[azure_openai_task_settings_parser] user doesn't support values of type: VALUE_NUMBER") + ); + } + + public void testFromMap_WithUser() { + assertThat( + create(STATEFUL_USER, Headers.UNDEFINED_INSTANCE), + is(createFromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, USER)), ConfigurationParseContext.PERSISTENT)) + ); + } + + public void testFromMap_WithRequestContext_ReturnsEmptySettings_WhenMapIsEmpty() { + var settings = createFromMap(new HashMap<>(Map.of()), ConfigurationParseContext.REQUEST); + assertTrue(settings.isEmpty()); + assertTrue(settings.user().isUndefined()); + assertThat(settings.headers(), sameInstance(Headers.UNDEFINED_INSTANCE)); + assertThat(settings, sameInstance(emptySettings())); + } + + public void testUpdatedTaskSettings_KeepsOriginalValues_WhenOverridesAreEmpty() { + var taskSettings = createFromMap( + new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, USER, Headers.HEADERS_FIELD, HEADERS_MAP)), + ConfigurationParseContext.PERSISTENT + ); + + var overriddenTaskSettings = taskSettings.updatedTaskSettings(Map.of()); + assertThat(overriddenTaskSettings, is(taskSettings)); + } + + public void testToXContent_RoundTrip() throws IOException { + // The reason we don't allow null here is that when a NULL_INSTANCE is serialized to xContent + // it is not written (aka would look like this {}) instead of it being written {"headers": null} or {"user": null}. + // This is because it's only used for the update API to indicate that the existing headers should be removed. + var user = randomFrom(StatefulValue.undefined(), StatefulValue.of(randomAlphaOfLength(15))); + var headers = HeadersTests.createRandomNonNull(); + var original = create(user, headers); + + String json; + try (XContentBuilder builder = XContentBuilder.builder(JsonXContent.jsonXContent)) { + original.toXContent(builder, ToXContent.EMPTY_PARAMS); + json = Strings.toString(builder); + } + var map = XContentHelper.convertToMap(JsonXContent.jsonXContent, json, false); + + var roundTrippedPersistentContext = createFromMap(map, ConfigurationParseContext.PERSISTENT); + assertThat(roundTrippedPersistentContext, is(original)); + + var roundTrippedRequestContext = createFromMap(map, ConfigurationParseContext.REQUEST); + assertThat(roundTrippedRequestContext, is(original)); + } + + public void testFromMap_ThrowsException_WhenMapContainsExtraFields_ForRequestContext() { + var exception = expectThrows( + IllegalArgumentException.class, + () -> createFromMap( + new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, USER, Headers.HEADERS_FIELD, Map.of(), "extra_field", "value")), + ConfigurationParseContext.REQUEST + ) + ); + + assertThat(exception.getMessage(), containsString("[azure_openai_task_settings_parser] unknown field [extra_field]")); + } + + public void testFromMap_DoesNotThrowException_WhenMapContainsExtraFields_ForPersistentContext() { + var settings = createFromMap( + new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, USER, Headers.HEADERS_FIELD, Map.of(), "extra_field", "value")), + ConfigurationParseContext.PERSISTENT + ); + + assertTrue(settings.user().isPresent()); + assertThat(settings.user().get(), is(USER)); + assertTrue(settings.headers().isEmpty()); + } + + public static Map createRequestTaskSettingsMap(@Nullable String user) { + var map = new HashMap(); + + if (user != null) { + map.put(AzureOpenAiServiceFields.USER, user); + } + + return map; + } + + @Override + protected T mutateInstanceForVersion(T instance, TransportVersion version) { + if (version.supports(INFERENCE_AZURE_OPENAI_TASK_SETTINGS_HEADERS)) { + return instance; + } + + var userForCreate = instance.user().isPresent() ? instance.user() : StatefulValue.undefined(); + + return create(userForCreate, Headers.UNDEFINED_INSTANCE); + } + + @Override + protected T mutateInstance(T instance) { + var setNull = randomBoolean(); + var fieldToMutate = randomIntBetween(0, 1); + + return switch (fieldToMutate) { + case 0 -> { + StatefulValue userForCreate; + + if (instance.user().isUndefined()) { + userForCreate = setNull ? StatefulValue.nullInstance() : StatefulValue.of(randomAlphaOfLength(15)); + } else if (instance.user().isNull()) { + userForCreate = randomBoolean() ? StatefulValue.undefined() : StatefulValue.of(randomAlphaOfLength(15)); + } else { + userForCreate = StatefulValue.of(instance.user() + "modified"); + } + yield create(userForCreate, instance.headers()); + } + case 1 -> create(instance.user(), HeadersTests.doMutateInstance(instance.headers())); + default -> throw new IllegalStateException("Unexpected value: " + fieldToMutate); + }; + } + + protected abstract T create(StatefulValue user, @Nullable Headers headers); + + protected abstract T createFromMap(Map map, ConfigurationParseContext context); + + protected abstract T emptySettings(); +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java index 23c773ab0d61b..20b6f1741ae46 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java @@ -50,9 +50,9 @@ import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiTaskSettingsTests.createRequestTaskSettingsMap; import static org.elasticsearch.xpack.inference.services.azureopenai.completion.AzureOpenAiCompletionModelTests.createCompletionModel; import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel; -import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsRequestTaskSettingsTests.createRequestTaskSettingsMap; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModelTests.java index be31721314c53..9c62d57f38cfc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionModelTests.java @@ -11,6 +11,8 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.common.parser.Headers; +import org.elasticsearch.xpack.inference.common.parser.StatefulValue; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields; @@ -165,13 +167,14 @@ private static AzureOpenAiCompletionModel createAzureOpenAiModelWithTaskType( ) { var secureApiKey = apiKey != null ? new SecureString(apiKey.toCharArray()) : null; var secureEntraId = entraId != null ? new SecureString(entraId.toCharArray()) : null; + var userToUse = user == null ? StatefulValue.undefined() : StatefulValue.of(user); return new AzureOpenAiCompletionModel( inferenceEntityId, taskType, "service", new AzureOpenAiCompletionServiceSettings(resourceName, deploymentId, apiVersion, null), - new AzureOpenAiCompletionTaskSettings(user), + new AzureOpenAiCompletionTaskSettings(userToUse, Headers.UNDEFINED_INSTANCE), new AzureOpenAiSecretSettings(secureApiKey, secureEntraId) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionRequestTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionRequestTaskSettingsTests.java deleted file mode 100644 index 51963c275a08a..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionRequestTaskSettingsTests.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.azureopenai.completion; - -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields; - -import java.util.HashMap; -import java.util.Map; - -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.is; - -public class AzureOpenAiCompletionRequestTaskSettingsTests extends ESTestCase { - - public void testFromMap_ReturnsEmptySettings_WhenMapIsEmpty() { - var settings = AzureOpenAiCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of())); - assertThat(settings, is(AzureOpenAiCompletionRequestTaskSettings.EMPTY_SETTINGS)); - } - - public void testFromMap_ReturnsEmptySettings_WhenMapDoesNotContainKnownFields() { - var settings = AzureOpenAiCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "model"))); - assertThat(settings, is(AzureOpenAiCompletionRequestTaskSettings.EMPTY_SETTINGS)); - } - - public void testFromMap_ReturnsUser() { - var settings = AzureOpenAiCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user"))); - assertThat(settings.user(), is("user")); - } - - public void testFromMap_WhenUserIsEmpty_ThrowsValidationException() { - var exception = expectThrows( - ValidationException.class, - () -> AzureOpenAiCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, ""))) - ); - - assertThat(exception.getMessage(), containsString("[user] must be a non-empty string")); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettingsTests.java index 42a5388b66d6b..79fadb90df2ab 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/completion/AzureOpenAiCompletionTaskSettingsTests.java @@ -7,109 +7,38 @@ package org.elasticsearch.xpack.inference.services.azureopenai.completion; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields; -import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettings; -import org.hamcrest.MatcherAssert; +import org.elasticsearch.xpack.inference.common.parser.Headers; +import org.elasticsearch.xpack.inference.common.parser.StatefulValue; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiTaskSettingsTests; -import java.io.IOException; -import java.util.HashMap; import java.util.Map; -import static org.hamcrest.Matchers.is; +public class AzureOpenAiCompletionTaskSettingsTests extends AzureOpenAiTaskSettingsTests { -public class AzureOpenAiCompletionTaskSettingsTests extends AbstractWireSerializingTestCase { - - public static AzureOpenAiCompletionTaskSettings createRandomWithUser() { - return new AzureOpenAiCompletionTaskSettings(randomAlphaOfLength(15)); - } - - public static AzureOpenAiCompletionTaskSettings createRandom() { - return new AzureOpenAiCompletionTaskSettings(randomAlphaOfLengthOrNull(15)); - } - - public void testIsEmpty() { - var randomSettings = createRandom(); - var stringRep = Strings.toString(randomSettings); - assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); - } - - public void testUpdatedTaskSettings() { - var initialSettings = createRandom(); - var newSettings = createRandom(); - AzureOpenAiCompletionTaskSettings updatedSettings = (AzureOpenAiCompletionTaskSettings) initialSettings.updatedTaskSettings( - newSettings.user() == null ? Map.of() : Map.of(AzureOpenAiServiceFields.USER, newSettings.user()) - ); - - assertEquals(newSettings.user() == null ? initialSettings.user() : newSettings.user(), updatedSettings.user()); - } - - public void testFromMap_WithUser() { - var user = "user"; - - assertThat( - new AzureOpenAiCompletionTaskSettings(user), - is(AzureOpenAiCompletionTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, user)))) - ); - } - - public void testFromMap_UserIsEmptyString() { - var thrownException = expectThrows( - ValidationException.class, - () -> AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, ""))) - ); - - MatcherAssert.assertThat( - thrownException.getMessage(), - is(Strings.format("Validation Failed: 1: [task_settings] Invalid value empty string. [user] must be a non-empty string;")) - ); - } - - public void testFromMap_MissingUser_DoesNotThrowException() { - var taskSettings = AzureOpenAiCompletionTaskSettings.fromMap(new HashMap<>(Map.of())); - assertNull(taskSettings.user()); - } - - public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() { - var taskSettings = AzureOpenAiCompletionTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user"))); - - var overriddenTaskSettings = AzureOpenAiCompletionTaskSettings.of( - taskSettings, - AzureOpenAiCompletionRequestTaskSettings.EMPTY_SETTINGS - ); - assertThat(overriddenTaskSettings, is(taskSettings)); + @Override + protected Writeable.Reader instanceReader() { + return AzureOpenAiCompletionTaskSettings::new; } - public void testOverrideWith_UsesOverriddenSettings() { - var user = "user"; - var userOverride = "user override"; - - var taskSettings = AzureOpenAiCompletionTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, user))); - - var requestTaskSettings = AzureOpenAiCompletionRequestTaskSettings.fromMap( - new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, userOverride)) - ); - - var overriddenTaskSettings = AzureOpenAiCompletionTaskSettings.of(taskSettings, requestTaskSettings); - assertThat(overriddenTaskSettings, is(new AzureOpenAiCompletionTaskSettings(userOverride))); + @Override + protected AzureOpenAiCompletionTaskSettings createTestInstance() { + return createRandom(); } @Override - protected Writeable.Reader instanceReader() { - return AzureOpenAiCompletionTaskSettings::new; + protected AzureOpenAiCompletionTaskSettings create(StatefulValue user, Headers headers) { + return new AzureOpenAiCompletionTaskSettings(user, headers); } @Override - protected AzureOpenAiCompletionTaskSettings createTestInstance() { - return createRandomWithUser(); + protected AzureOpenAiCompletionTaskSettings createFromMap(Map map, ConfigurationParseContext context) { + return AzureOpenAiCompletionTaskSettings.fromMap(map, context); } @Override - protected AzureOpenAiCompletionTaskSettings mutateInstance(AzureOpenAiCompletionTaskSettings instance) throws IOException { - String user = randomValueOtherThan(instance.user(), () -> randomAlphaOfLengthOrNull(15)); - return new AzureOpenAiCompletionTaskSettings(user); + protected AzureOpenAiCompletionTaskSettings emptySettings() { + return AzureOpenAiCompletionTaskSettings.EMPTY; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModelTests.java index 2f6760cb36e9f..332a65d8c122c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModelTests.java @@ -13,12 +13,14 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.common.parser.Headers; +import org.elasticsearch.xpack.inference.common.parser.StatefulValue; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings; import java.net.URISyntaxException; import java.util.Map; -import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettingsTests.getAzureOpenAiRequestTaskSettingsMap; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiTaskSettingsTests.createRequestTaskSettingsMap; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.sameInstance; @@ -26,7 +28,7 @@ public class AzureOpenAiEmbeddingsModelTests extends ESTestCase { public void testOverrideWith_OverridesUser() { var model = createModel("resource", "deployment", "apiversion", null, "api_key", null, "id"); - var requestTaskSettingsMap = getAzureOpenAiRequestTaskSettingsMap("user_override"); + var requestTaskSettingsMap = createRequestTaskSettingsMap("user_override"); var overriddenModel = AzureOpenAiEmbeddingsModel.of(model, requestTaskSettingsMap); @@ -108,12 +110,14 @@ public static AzureOpenAiEmbeddingsModel createModel( ) { var secureApiKey = apiKey != null ? new SecureString(apiKey.toCharArray()) : null; var secureEntraId = entraId != null ? new SecureString(entraId.toCharArray()) : null; + var userToUse = user == null ? StatefulValue.undefined() : StatefulValue.of(user); + return new AzureOpenAiEmbeddingsModel( inferenceEntityId, TaskType.TEXT_EMBEDDING, "service", new AzureOpenAiEmbeddingsServiceSettings(resourceName, deploymentId, apiVersion, null, false, null, null, null), - new AzureOpenAiEmbeddingsTaskSettings(user), + new AzureOpenAiEmbeddingsTaskSettings(userToUse, Headers.UNDEFINED_INSTANCE), chunkingSettings, new AzureOpenAiSecretSettings(secureApiKey, secureEntraId) ); @@ -130,12 +134,14 @@ public static AzureOpenAiEmbeddingsModel createModel( ) { var secureApiKey = apiKey != null ? new SecureString(apiKey.toCharArray()) : null; var secureEntraId = entraId != null ? new SecureString(entraId.toCharArray()) : null; + var userToUse = user == null ? StatefulValue.undefined() : StatefulValue.of(user); + return new AzureOpenAiEmbeddingsModel( inferenceEntityId, TaskType.TEXT_EMBEDDING, "service", new AzureOpenAiEmbeddingsServiceSettings(resourceName, deploymentId, apiVersion, null, false, null, null, null), - new AzureOpenAiEmbeddingsTaskSettings(user), + new AzureOpenAiEmbeddingsTaskSettings(userToUse, Headers.UNDEFINED_INSTANCE), null, new AzureOpenAiSecretSettings(secureApiKey, secureEntraId) ); @@ -156,6 +162,7 @@ public static AzureOpenAiEmbeddingsModel createModel( ) { var secureApiKey = apiKey != null ? new SecureString(apiKey.toCharArray()) : null; var secureEntraId = entraId != null ? new SecureString(entraId.toCharArray()) : null; + var userToUse = user == null ? StatefulValue.undefined() : StatefulValue.of(user); return new AzureOpenAiEmbeddingsModel( inferenceEntityId, @@ -171,7 +178,7 @@ public static AzureOpenAiEmbeddingsModel createModel( similarity, null ), - new AzureOpenAiEmbeddingsTaskSettings(user), + new AzureOpenAiEmbeddingsTaskSettings(userToUse, Headers.UNDEFINED_INSTANCE), null, new AzureOpenAiSecretSettings(secureApiKey, secureEntraId) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettingsTests.java deleted file mode 100644 index 0aef2a97ee0a1..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettingsTests.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; - -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields; - -import java.util.HashMap; -import java.util.Map; - -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.is; - -public class AzureOpenAiEmbeddingsRequestTaskSettingsTests extends ESTestCase { - public void testFromMap_ReturnsEmptySettings_WhenTheMapIsEmpty() { - var settings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of())); - assertThat(settings, is(AzureOpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS)); - } - - public void testFromMap_ReturnsEmptySettings_WhenTheMapDoesNotContainTheFields() { - var settings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "model"))); - assertNull(settings.user()); - } - - public void testFromMap_ReturnsUser() { - var settings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user"))); - assertThat(settings.user(), is("user")); - } - - public void testFromMap_WhenUserIsEmpty_ThrowsValidationException() { - var exception = expectThrows( - ValidationException.class, - () -> AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, ""))) - ); - - assertThat(exception.getMessage(), containsString("[user] must be a non-empty string")); - } - - public static Map createRequestTaskSettingsMap(@Nullable String user) { - var map = new HashMap(); - - if (user != null) { - map.put(OpenAiServiceFields.USER, user); - } - - return map; - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java index 5e7b40912511b..dddd6319351c7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java @@ -7,95 +7,15 @@ package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.hamcrest.MatcherAssert; +import org.elasticsearch.xpack.inference.common.parser.Headers; +import org.elasticsearch.xpack.inference.common.parser.StatefulValue; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiTaskSettingsTests; -import java.io.IOException; -import java.util.HashMap; import java.util.Map; -import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.USER; -import static org.hamcrest.Matchers.is; - -public class AzureOpenAiEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase { - - public static AzureOpenAiEmbeddingsTaskSettings createRandomWithUser() { - return new AzureOpenAiEmbeddingsTaskSettings(randomAlphaOfLength(15)); - } - - public void testIsEmpty() { - var randomSettings = createRandom(); - var stringRep = Strings.toString(randomSettings); - assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}")); - } - - /** - * The created settings can have the user set to null. - */ - public static AzureOpenAiEmbeddingsTaskSettings createRandom() { - return new AzureOpenAiEmbeddingsTaskSettings(randomAlphaOfLengthOrNull(15)); - } - - public void testUpdatedTaskSettings() { - var initialSettings = createRandom(); - var newSettings = createRandom(); - AzureOpenAiEmbeddingsTaskSettings updatedSettings = (AzureOpenAiEmbeddingsTaskSettings) initialSettings.updatedTaskSettings( - newSettings.user() == null ? Map.of() : Map.of(USER, newSettings.user()) - ); - - if (newSettings.user() == null) { - assertEquals(initialSettings.user(), updatedSettings.user()); - } else { - assertEquals(newSettings.user(), updatedSettings.user()); - } - } - - public void testFromMap_WithUser() { - assertEquals( - new AzureOpenAiEmbeddingsTaskSettings("user"), - AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(USER, "user"))) - ); - } - - public void testFromMap_UserIsEmptyString() { - var thrownException = expectThrows( - ValidationException.class, - () -> AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(USER, ""))) - ); - - MatcherAssert.assertThat( - thrownException.getMessage(), - is(Strings.format("Validation Failed: 1: [task_settings] Invalid value empty string. [user] must be a non-empty string;")) - ); - } - - public void testFromMap_MissingUser_DoesNotThrowException() { - var taskSettings = AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of())); - assertNull(taskSettings.user()); - } - - public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() { - var taskSettings = AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(USER, "user"))); - - var overriddenTaskSettings = AzureOpenAiEmbeddingsTaskSettings.of( - taskSettings, - AzureOpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS - ); - MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings)); - } - - public void testOverrideWith_UsesOverriddenSettings() { - var taskSettings = AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(USER, "user"))); - - var requestTaskSettings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of(USER, "user2"))); - - var overriddenTaskSettings = AzureOpenAiEmbeddingsTaskSettings.of(taskSettings, requestTaskSettings); - MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureOpenAiEmbeddingsTaskSettings("user2"))); - } +public class AzureOpenAiEmbeddingsTaskSettingsTests extends AzureOpenAiTaskSettingsTests { @Override protected Writeable.Reader instanceReader() { @@ -104,22 +24,21 @@ protected Writeable.Reader instanceReader() { @Override protected AzureOpenAiEmbeddingsTaskSettings createTestInstance() { - return createRandomWithUser(); + return createRandom(); } @Override - protected AzureOpenAiEmbeddingsTaskSettings mutateInstance(AzureOpenAiEmbeddingsTaskSettings instance) throws IOException { - String user = randomValueOtherThan(instance.user(), () -> randomAlphaOfLengthOrNull(15)); - return new AzureOpenAiEmbeddingsTaskSettings(user); + protected AzureOpenAiEmbeddingsTaskSettings create(StatefulValue user, Headers headers) { + return new AzureOpenAiEmbeddingsTaskSettings(user, headers); } - public static Map getAzureOpenAiRequestTaskSettingsMap(@Nullable String user) { - var map = new HashMap(); - - if (user != null) { - map.put(USER, user); - } + @Override + protected AzureOpenAiEmbeddingsTaskSettings createFromMap(Map map, ConfigurationParseContext context) { + return AzureOpenAiEmbeddingsTaskSettings.fromMap(map, context); + } - return map; + @Override + protected AzureOpenAiEmbeddingsTaskSettings emptySettings() { + return AzureOpenAiEmbeddingsTaskSettings.EMPTY; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java index 72a31e6a789f1..d53becee3613a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java @@ -537,7 +537,7 @@ public void testFromMap_ReturnsError_IfHeadersContainsNonStringValues() { exception.getMessage(), is( "Validation Failed: 1: Map field [headers] has an entry that is not valid, [key => 1]. " - + "Value type of [1] is not one of [String].;" + + "Value type of [Integer] is not one of [String].;" ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java index 471a66a642ec8..a499f81a564bc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomTaskSettingsTests.java @@ -70,7 +70,7 @@ public void testFromMap_Throws_IfValueIsInvalid() { exception.getMessage(), is( "Validation Failed: 1: Map field [parameters] has an entry that is not valid, [key => {another_key=value}]. " - + "Value type of [{another_key=value}] is not one of [Boolean, Double, Float, Integer, String].;" + + "Value type of [Map] is not one of [Boolean, Double, Float, Integer, String].;" ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiTaskSettingsTests.java index 5de993462b202..153a0c6196309 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiTaskSettingsTests.java @@ -196,7 +196,7 @@ public void testFromMap_ParsesCorrectly_WhenHeadersContainsAnInteger() { exception.getMessage(), is( "Validation Failed: 1: Map field [headers] has an entry that is not valid, " - + "[key => 1]. Value type of [1] is not one of [String].;" + + "[key => 1]. Value type of [Integer] is not one of [String].;" ) ); }