[Inference API] Add custom headers for Azure OpenAI Service #142969
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| area: Inference | ||
| issues: [] | ||
| pr: 142969 | ||
| summary: "[Inference API] Add custom headers for Azure OpenAI Service" | ||
| type: enhancement |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| 9304000 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1 @@ | ||
| -query_dsl_boxplot_exponential_histogram_support,9303000 | ||
| +inference_azure_openai_task_settings_headers,9304000 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.common.parser; | ||
|
|
||
| import org.elasticsearch.common.ValidationException; | ||
| import org.elasticsearch.common.io.stream.StreamInput; | ||
| import org.elasticsearch.common.io.stream.StreamOutput; | ||
| import org.elasticsearch.common.io.stream.Writeable; | ||
| import org.elasticsearch.xcontent.ConstructingObjectParser; | ||
| import org.elasticsearch.xcontent.ParseField; | ||
| import org.elasticsearch.xcontent.ToXContentFragment; | ||
| import org.elasticsearch.xcontent.XContentBuilder; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.HashMap; | ||
| import java.util.Map; | ||
| import java.util.Objects; | ||
|
|
||
| import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; | ||
| import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeNullValues; | ||
| import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues; | ||
|
|
||
| public record Headers(StatefulValue<Map<String, String>> mapValue) implements ToXContentFragment, Writeable { | ||
|
|
||
| // public for testing | ||
| public static final String HEADERS_FIELD = "headers"; | ||
| // public for testing | ||
| public static final Headers UNDEFINED_INSTANCE = new Headers(StatefulValue.undefined()); | ||
| public static final Headers NULL_INSTANCE = new Headers(StatefulValue.nullInstance()); | ||
|
|
||
| /** | ||
| * Sentinel passed by the parser when the headers field is present with value null. | ||
| */ | ||
| public static final Object PARSER_NULL_SENTINEL = new HashMap<>(); | ||
|
|
||
| private static final ParseField HEADERS = new ParseField(HEADERS_FIELD); | ||
|
|
||
| public static <Value, Context> void initParser(ConstructingObjectParser<Value, Context> parser) { | ||
| parser.declareObjectOrNull(optionalConstructorArg(), (p, c) -> { | ||
| var parsedMap = p.map(); | ||
| if (parsedMap == null || parsedMap == PARSER_NULL_SENTINEL) { | ||
| return parsedMap; | ||
| } | ||
|
|
||
| var validationException = new ValidationException(); | ||
|
|
||
| return doValidation(parsedMap, validationException); | ||
| }, PARSER_NULL_SENTINEL, HEADERS); | ||
| } | ||
|
|
||
| private static Map<String, String> doValidation(Map<String, Object> map, ValidationException validationException) { | ||
| removeNullValues(map); | ||
|
|
||
| var stringHeaders = validateMapStringValues(map, HEADERS.getPreferredName(), validationException, false, Map.of()); | ||
|
|
||
| validationException.throwIfValidationErrorsExist(); | ||
|
|
||
| return stringHeaders; | ||
| } | ||
|
|
||
| @SuppressWarnings("unchecked") | ||
| public static Headers create(Object arg, String path) { | ||
| // We will get null here if the headers field was not present in the json | ||
| if (arg == null) { | ||
| return UNDEFINED_INSTANCE; | ||
| } | ||
|
Comment on lines +69 to +75
Contributor
Is there an argument for returning EMPTY_INSTANCE here? I think it would allow us to make the headers field on AzureOpenAiTaskSettings not @Nullable, so we'd have to check .isEmpty() instead of null in a few places, but it would mean that we wouldn't end up potentially creating an AzureOpenAiTaskSettings with a null user and empty headers, which could happen at the moment. Alternately, we could check if
Contributor, Author
Hmm, now that I'm thinking about this more I wonder how a user would use the If we removed I think we can use I agree though, it'd be nice to have a single state with only empty headers. Let me know if you can think of a way to handle that. I suppose we could use an enum as well 🤔
Contributor, Author
Ok give this another look, I think it's in a better state 😅 |
||
|
|
||
| if (arg == PARSER_NULL_SENTINEL) { | ||
| return NULL_INSTANCE; | ||
| } | ||
|
|
||
| var validationException = new ValidationException(); | ||
|
|
||
| if (arg instanceof Map == false) { | ||
| validationException.addValidationError(ObjectParserUtils.invalidTypeErrorMsg(HEADERS_FIELD, path, arg, "Map")); | ||
| throw validationException; | ||
| } | ||
|
|
||
| // It's not likely that this create method would be called with invalid values since they should be validated during parsing but | ||
| // we'll do it just in case this method is used elsewhere | ||
| var stringsMap = doValidation((Map<String, Object>) arg, validationException); | ||
|
|
||
| if (stringsMap.isEmpty()) { | ||
| // If a user specifies "headers": {} we'll assume they don't want any headers. If this is in the context of an update API, | ||
| // this is the same as if they did "headers": null which means to remove all existing headers. | ||
| return NULL_INSTANCE; | ||
| } | ||
|
|
||
| return new Headers(StatefulValue.of(stringsMap)); | ||
| } | ||
|
|
||
| public Headers { | ||
| Objects.requireNonNull(mapValue); | ||
| } | ||
|
|
||
| public Headers(StreamInput in) throws IOException { | ||
| this(StatefulValue.read(in, input -> input.readImmutableMap(StreamInput::readString, StreamInput::readString))); | ||
| } | ||
|
|
||
| public boolean isEmpty() { | ||
| return mapValue.isPresent() == false || mapValue.get().isEmpty(); | ||
| } | ||
|
|
||
| public boolean isPresent() { | ||
| return mapValue.isPresent(); | ||
| } | ||
|
|
||
| public boolean isNull() { | ||
| return mapValue.isNull(); | ||
| } | ||
|
|
||
| @Override | ||
| public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
| if (isEmpty() == false) { | ||
| builder.field(HEADERS.getPreferredName(), mapValue.get()); | ||
| } | ||
| return builder; | ||
| } | ||
|
|
||
| @Override | ||
| public void writeTo(StreamOutput out) throws IOException { | ||
| StatefulValue.write( | ||
| out, | ||
| mapValue, | ||
| (streamOutput, v) -> streamOutput.writeMap(v, StreamOutput::writeString, StreamOutput::writeString) | ||
| ); | ||
| } | ||
| } | ||
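
The comments in `create` above distinguish three inputs: the field is absent, it is explicitly null (or an empty object), or it carries a non-empty map. The following is a minimal sketch of how those states could drive an update, not the PR's implementation. `HeadersUpdateSketch`, `State`, `classify`, `applyUpdate`, and `NULL_SENTINEL` are hypothetical names, and the "present replaces the stored map" rule is an assumption; the source only states that null or `{}` removes all existing headers.

```java
import java.util.Map;

// Hypothetical stand-in for the PR's Headers/StatefulValue pair; not the Elasticsearch classes.
final class HeadersUpdateSketch {

    // Distinct marker for `"headers": null`, mirroring the idea behind PARSER_NULL_SENTINEL.
    static final Object NULL_SENTINEL = new Object();

    enum State { UNDEFINED, NULL, PRESENT }

    // Java null means the field was absent; the sentinel or an empty map means
    // "remove the headers"; any other map is a present value.
    static State classify(Object arg) {
        if (arg == null) {
            return State.UNDEFINED;
        }
        if (arg == NULL_SENTINEL) {
            return State.NULL;
        }
        if (arg instanceof Map<?, ?> map) {
            return map.isEmpty() ? State.NULL : State.PRESENT;
        }
        throw new IllegalArgumentException("headers must be a map");
    }

    // Assumed update semantics: absent keeps the stored headers, null clears them,
    // and a present value replaces them wholesale.
    @SuppressWarnings("unchecked")
    static Map<String, String> applyUpdate(Map<String, String> existing, Object arg) {
        return switch (classify(arg)) {
            case UNDEFINED -> existing;
            case NULL -> Map.of();
            case PRESENT -> Map.copyOf((Map<String, String>) arg);
        };
    }
}
```

Keeping the absent and null cases separate is what lets an update request leave headers untouched simply by omitting the field.
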
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| package org.elasticsearch.xpack.inference.common.parser; | ||
|
|
||
| import org.elasticsearch.common.io.stream.StreamInput; | ||
| import org.elasticsearch.common.io.stream.StreamOutput; | ||
| import org.elasticsearch.common.io.stream.Writeable; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.NoSuchElementException; | ||
| import java.util.Objects; | ||
|
|
||
| /** | ||
| * This class holds a value of type {@code T} that can be in one of three states: undefined, null, or defined with a non-null value. | ||
| * It provides methods to check the state and retrieve the value if present. | ||
| * <p> | ||
| * Undefined means that the value is not defined aka it was absent in the input | ||
| * Null means that the value is defined but explicitly set to null | ||
| * Present means that the value is defined and not null | ||
| * @param <T> the type of the value | ||
| */ | ||
| public final class StatefulValue<T> { | ||
|
|
||
| static final NoSuchElementException NO_VALUE_PRESENT = new NoSuchElementException("No value present"); | ||
|
|
||
| private static final StatefulValue<?> UNDEFINED_INSTANCE = new StatefulValue<>(null, false); | ||
| private static final StatefulValue<?> NULL_INSTANCE = new StatefulValue<>(null, true); | ||
|
|
||
| public static <T> StatefulValue<T> undefined() { | ||
| @SuppressWarnings("unchecked") | ||
| var absent = (StatefulValue<T>) UNDEFINED_INSTANCE; | ||
| return absent; | ||
| } | ||
|
|
||
| public static <T> StatefulValue<T> nullInstance() { | ||
| @SuppressWarnings("unchecked") | ||
| var nullInstance = (StatefulValue<T>) NULL_INSTANCE; | ||
| return nullInstance; | ||
| } | ||
|
|
||
| public static <T> StatefulValue<T> of(T value) { | ||
| return new StatefulValue<>(Objects.requireNonNull(value), true); | ||
|
Contributor
Very minor nitpick, but if we want to more closely align with the behaviour of |
||
| } | ||
|
|
||
| public static <T> StatefulValue<T> read(StreamInput in, Writeable.Reader<T> reader) throws IOException { | ||
| var isDefined = in.readBoolean(); | ||
| if (isDefined == false) { | ||
| return undefined(); | ||
| } | ||
|
|
||
| var isNull = in.readBoolean(); | ||
| if (isNull) { | ||
| return nullInstance(); | ||
| } | ||
|
|
||
| var value = reader.read(in); | ||
| return of(value); | ||
| } | ||
|
|
||
| public static <T> void write(StreamOutput out, StatefulValue<T> statefulValue, Writeable.Writer<T> writer) throws IOException { | ||
| out.writeBoolean(statefulValue.isDefined); | ||
| if (statefulValue.isDefined) { | ||
| out.writeBoolean(statefulValue.isNull()); | ||
| if (statefulValue.isPresent()) { | ||
| writer.write(out, statefulValue.value); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| private final T value; | ||
| private final boolean isDefined; | ||
|
|
||
| private StatefulValue(T value, boolean isDefined) { | ||
| this.value = value; | ||
| this.isDefined = isDefined; | ||
| } | ||
|
|
||
| /** | ||
| * Returns true if the value is not defined, meaning it is absent. | ||
| */ | ||
| public boolean isUndefined() { | ||
| return isDefined == false; | ||
| } | ||
|
|
||
| /** | ||
| * Returns true if the value is defined and explicitly set to null. | ||
| */ | ||
| public boolean isNull() { | ||
| return isDefined && value == null; | ||
| } | ||
|
|
||
| /** | ||
| * Returns true if the value is defined and not null. | ||
| */ | ||
| public boolean isPresent() { | ||
| return isDefined && value != null; | ||
| } | ||
|
|
||
| public T get() { | ||
| if (isPresent() == false) { | ||
| throw NO_VALUE_PRESENT; | ||
| } | ||
| return value; | ||
| } | ||
|
|
||
| public T orElse(T other) { | ||
| return isPresent() ? value : other; | ||
| } | ||
|
|
||
| @Override | ||
| public boolean equals(Object o) { | ||
| if (o == null || getClass() != o.getClass()) return false; | ||
| StatefulValue<?> statefulValue = (StatefulValue<?>) o; | ||
| return Objects.equals(value, statefulValue.value) && isDefined == statefulValue.isDefined; | ||
| } | ||
|
|
||
| @Override | ||
| public int hashCode() { | ||
| return Objects.hash(value, isDefined); | ||
| } | ||
| } | ||
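
The `read` and `write` methods above encode the state as up to two booleans followed by the value: first "is defined", then (only if defined) "is null", then (only if present) the payload. Below is a rough, self-contained sketch of that layout. It uses `java.io` data streams as a stand-in for Elasticsearch's `StreamOutput` and `StreamInput`, fixes the value type to `String`, and uses the hypothetical names `TriStateWireSketch` and `TriState`.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Objects;

// Hypothetical sketch; java.io streams replace StreamOutput/StreamInput for illustration.
final class TriStateWireSketch {

    // Same two fields as StatefulValue: a nullable value plus an isDefined flag.
    record TriState(boolean isDefined, String value) {
        static final TriState UNDEFINED = new TriState(false, null);
        static final TriState NULL = new TriState(true, null);

        static TriState of(String value) {
            return new TriState(true, Objects.requireNonNull(value));
        }
    }

    static byte[] write(TriState state) {
        var bytes = new ByteArrayOutputStream();
        try (var out = new DataOutputStream(bytes)) {
            out.writeBoolean(state.isDefined());
            if (state.isDefined()) {
                // The second flag is only written when the value is defined.
                out.writeBoolean(state.value() == null);
                if (state.value() != null) {
                    out.writeUTF(state.value());
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    static TriState read(byte[] data) {
        try (var in = new DataInputStream(new ByteArrayInputStream(data))) {
            if (in.readBoolean() == false) {
                return TriState.UNDEFINED;
            }
            if (in.readBoolean()) {
                return TriState.NULL;
            }
            return TriState.of(in.readUTF());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

In this sketch an undefined value costs one byte and an explicit null costs two, and the reader consumes exactly the flags the writer emitted.
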
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -10,7 +10,6 @@ | ||
| import org.elasticsearch.common.ValidationException; | ||
| import org.elasticsearch.core.Nullable; | ||
| import org.elasticsearch.inference.ModelConfigurations; | ||
| -import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsRequestTaskSettings; | ||
|
|
||
| import java.util.Map; | ||
|
|
||
| @@ -31,7 +30,7 @@ public record AzureAiStudioEmbeddingsRequestTaskSettings(@Nullable String user) | ||
| * does not throw an error. | ||
| * | ||
| * @param map the settings received from a request | ||
| - * @return a {@link AzureOpenAiEmbeddingsRequestTaskSettings} | ||
| + * @return a {@link AzureAiStudioEmbeddingsRequestTaskSettings} | ||
|
Contributor, Author
This was referencing the wrong class |
||
| */ | ||
| public static AzureAiStudioEmbeddingsRequestTaskSettings fromMap(Map<String, Object> map) { | ||
| if (map.isEmpty()) { | ||
|
|
||
Since we always cast `arg` to `Map<String, Object>` if it's non-null, would it make sense to have the method just take a `Map<String, Object>` instead of `Object`?

Keeping it as an object here helps with encapsulation I think. That way only `Headers` knows the expected type. If we have `create()` take a Map, then the caller needs to do the cast and have the unchecked cast suppression. It also goes from `Map<String, Object>` to `Map<String, String>` after the validation check.

Fair enough. Would it be worth adding some error handling for if we get a `ClassCastException` here, then, or is it fine to just throw the exception and let somewhere higher up deal with it?

Good idea, I'll add an instanceof check.
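
The exchange above settles on accepting an `Object` and guarding the cast with an `instanceof` check. As a sketch of that pattern only: the helper below checks the argument type, drops null values, and rejects non-string entries. The PR does this through `ObjectParserUtils` and `ServiceUtils` with a `ValidationException`; `HeaderArgCheckSketch` and `toStringMap` are hypothetical names, and `IllegalArgumentException` is used here purely to keep the example self-contained.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper illustrating the instanceof-before-cast approach from the discussion.
final class HeaderArgCheckSketch {

    static Map<String, String> toStringMap(Object arg, String field) {
        // Reject anything that is not a map instead of letting a ClassCastException escape.
        if (!(arg instanceof Map<?, ?> raw)) {
            String actual = arg == null ? "null" : arg.getClass().getSimpleName();
            throw new IllegalArgumentException("field [" + field + "] must be a Map but was " + actual);
        }

        var result = new LinkedHashMap<String, String>();
        for (Map.Entry<?, ?> entry : raw.entrySet()) {
            Object key = entry.getKey();
            Object value = entry.getValue();
            if (value == null) {
                continue; // assumption: null values are dropped rather than rejected
            }
            if (key instanceof String k && value instanceof String v) {
                result.put(k, v);
            } else {
                throw new IllegalArgumentException("field [" + field + "] must map strings to strings");
            }
        }
        return result;
    }
}
```

Because the type check lives inside the helper, callers can pass whatever the parser produced without needing their own unchecked-cast suppression.
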