
Commit 7eaf380

[8.x] [ML] Inference duration and error metrics (#115876) (#118700)
* [ML] Inference duration and error metrics (#115876)

  Add `es.inference.requests.time` metric around the `infer` API.

  As recommended by the OTel spec, errors are determined by the presence or absence of the `error.type` attribute in the metric. "error.type" will be the HTTP status code (as a string) if it is available; otherwise it will be the name of the exception (e.g. NullPointerException).

  Additional notes:
  - ApmInferenceStats is merged into InferenceStats. Originally we planned to have multiple implementations, but now we're only using APM.
  - Request count is now always recorded, even when there are failures loading the endpoint configuration.
  - Added a hook in streaming for cancel messages, so we can close the metrics when a user cancels the stream.

  (cherry picked from commit 26870ef)

* fixing switch with class issue

---------

Co-authored-by: Pat Whelan <[email protected]>
1 parent d2960b4 commit 7eaf380
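The error scheme in the commit message maps onto the new `InferenceStats.responseAttributes` helpers shown in the InferenceStats.java diff below. A minimal sketch of the attributes that end up on the duration metric, with illustrative throwables:

    // success, no throwable                   -> {status_code=200}
    // ElasticsearchStatusException (HTTP 404) -> {status_code=404, error.type="404"}
    // any other exception                     -> {error.type="NullPointerException"}
    Map<String, Object> attrs = InferenceStats.responseAttributes(new NullPointerException());
    assert "NullPointerException".equals(attrs.get("error.type"));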

11 files changed, 828 additions and 151 deletions

docs/changelog/115876.yaml

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+pr: 115876
+summary: Inference duration and error metrics
+area: Machine Learning
+type: enhancement
+issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 1 addition & 2 deletions
@@ -105,7 +105,6 @@
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
 import org.elasticsearch.xpack.inference.services.mistral.MistralService;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
-import org.elasticsearch.xpack.inference.telemetry.ApmInferenceStats;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
 
 import java.util.ArrayList;
@@ -240,7 +239,7 @@ public Collection<?> createComponents(PluginServices services) {
         shardBulkInferenceActionFilter.set(actionFilter);
 
         var meterRegistry = services.telemetryProvider().getMeterRegistry();
-        var stats = new PluginComponentBinding<>(InferenceStats.class, ApmInferenceStats.create(meterRegistry));
+        var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
 
         return List.of(modelRegistry, registry, httpClientManager, stats);
     }

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java

Lines changed: 102 additions & 24 deletions
@@ -7,13 +7,16 @@
 
 package org.elasticsearch.xpack.inference.action;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.common.logging.DeprecationLogger;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -26,20 +29,22 @@
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
+import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
+import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;
 
-import java.util.Set;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
 
 public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> {
+    private static final Logger log = LogManager.getLogger(TransportInferenceAction.class);
     private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
     private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";
 
-    private static final Set<Class<? extends InferenceService>> supportsStreaming = Set.of();
-
     private final ModelRegistry modelRegistry;
     private final InferenceServiceRegistry serviceRegistry;
     private final InferenceStats inferenceStats;
@@ -64,17 +69,22 @@ public TransportInferenceAction(
 
     @Override
     protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
+        var timer = InferenceTimer.start();
 
-        ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> {
+        var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
            var service = serviceRegistry.getService(unparsedModel.service());
            if (service.isEmpty()) {
-                listener.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
+                var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
+                recordMetrics(unparsedModel, timer, e);
+                listener.onFailure(e);
                return;
            }
 
            if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
                // not the wildcard task type and not the model task type
-                listener.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()));
+                var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
+                recordMetrics(unparsedModel, timer, e);
+                listener.onFailure(e);
                return;
            }
 
@@ -85,20 +95,69 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe
                unparsedModel.settings(),
                unparsedModel.secrets()
            );
-            inferOnService(model, request, service.get(), delegate);
+            inferOnServiceWithMetrics(model, request, service.get(), timer, listener);
+        }, e -> {
+            try {
+                inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e));
+            } catch (Exception metricsException) {
+                log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics");
+            }
+            listener.onFailure(e);
        });
 
        modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
    }
 
-    private void inferOnService(
+    private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
+        try {
+            inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
+        } catch (Exception e) {
+            log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics");
+        }
+    }
+
+    private void inferOnServiceWithMetrics(
        Model model,
        InferenceAction.Request request,
        InferenceService service,
+        InferenceTimer timer,
        ActionListener<InferenceAction.Response> listener
+    ) {
+        inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
+        inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> {
+            if (request.isStreaming()) {
+                var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
+                inferenceResults.publisher().subscribe(taskProcessor);
+
+                var instrumentedStream = new PublisherWithMetrics(timer, model);
+                taskProcessor.subscribe(instrumentedStream);
+
+                listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream));
+            } else {
+                recordMetrics(model, timer, null);
+                listener.onResponse(new InferenceAction.Response(inferenceResults));
+            }
+        }, e -> {
+            recordMetrics(model, timer, e);
+            listener.onFailure(e);
+        }));
+    }
+
+    private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
+        try {
+            inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
+        } catch (Exception e) {
+            log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
+        }
+    }
+
+    private void inferOnService(
+        Model model,
+        InferenceAction.Request request,
+        InferenceService service,
+        ActionListener<InferenceServiceResults> listener
    ) {
        if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
-            inferenceStats.incrementRequestCount(model);
            service.infer(
                model,
                request.getQuery(),
@@ -107,7 +166,7 @@ private void inferOnService(
                request.getTaskSettings(),
                request.getInputType(),
                request.getInferenceTimeout(),
-                createListener(request, listener)
+                listener
            );
        } else {
            listener.onFailure(unsupportedStreamingTaskException(request, service));
@@ -135,20 +194,6 @@ private ElasticsearchStatusException unsupportedStreamingTaskException(Inference
        }
    }
 
-    private ActionListener<InferenceServiceResults> createListener(
-        InferenceAction.Request request,
-        ActionListener<InferenceAction.Response> listener
-    ) {
-        if (request.isStreaming()) {
-            return listener.delegateFailureAndWrap((l, inferenceResults) -> {
-                var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
-                inferenceResults.publisher().subscribe(taskProcessor);
-                l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor));
-            });
-        }
-        return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults)));
-    }
-
    private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
        return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
    }
@@ -162,4 +207,37 @@ private static ElasticsearchStatusException incompatibleTaskTypeException(TaskTy
        );
    }
 
+    private class PublisherWithMetrics extends DelegatingProcessor<ChunkedToXContent, ChunkedToXContent> {
+        private final InferenceTimer timer;
+        private final Model model;
+
+        private PublisherWithMetrics(InferenceTimer timer, Model model) {
+            this.timer = timer;
+            this.model = model;
+        }
+
+        @Override
+        protected void next(ChunkedToXContent item) {
+            downstream().onNext(item);
+        }
+
+        @Override
+        public void onError(Throwable throwable) {
+            recordMetrics(model, timer, throwable);
+            super.onError(throwable);
+        }
+
+        @Override
+        protected void onCancel() {
+            recordMetrics(model, timer, null);
+            super.onCancel();
+        }
+
+        @Override
+        public void onComplete() {
+            recordMetrics(model, timer, null);
+            super.onComplete();
+        }
+    }
+
 }

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java

Lines changed: 3 additions & 0 deletions
@@ -61,11 +61,14 @@ public void request(long n) {
            public void cancel() {
                if (isClosed.compareAndSet(false, true) && upstream != null) {
                    upstream.cancel();
+                    onCancel();
                }
            }
        };
    }
 
+    protected void onCancel() {}
+
    @Override
    public void onSubscribe(Flow.Subscription subscription) {
        if (upstream != null) {
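The new `onCancel()` hook defaults to a no-op; `PublisherWithMetrics` in TransportInferenceAction above overrides it to close out the duration metric. A minimal sketch of a subclass using the hook (the class and its cleanup helper are hypothetical, for illustration only):

    // Hypothetical pass-through processor that reacts to downstream cancellation.
    class CancelAwareProcessor<T> extends DelegatingProcessor<T, T> {
        @Override
        protected void next(T item) {
            downstream().onNext(item); // forward each item unchanged
        }

        @Override
        protected void onCancel() {
            // runs once, right after upstream.cancel(), when the consumer cancels
            releaseResources();
        }

        private void releaseResources() {} // hypothetical cleanup
    }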

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStats.java

Lines changed: 0 additions & 49 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java

Lines changed: 81 additions & 7 deletions
@@ -7,15 +7,89 @@
 
 package org.elasticsearch.xpack.inference.telemetry;
 
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.UnparsedModel;
+import org.elasticsearch.telemetry.metric.LongCounter;
+import org.elasticsearch.telemetry.metric.LongHistogram;
+import org.elasticsearch.telemetry.metric.MeterRegistry;
 
-public interface InferenceStats {
+import java.util.Map;
+import java.util.Objects;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
-    /**
-     * Increment the counter for a particular value in a thread safe manner.
-     * @param model the model to increment request count for
-     */
-    void incrementRequestCount(Model model);
+import static java.util.Map.entry;
+import static java.util.stream.Stream.concat;
 
-    InferenceStats NOOP = model -> {};
+public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) {
+
+    public InferenceStats {
+        Objects.requireNonNull(requestCount);
+        Objects.requireNonNull(inferenceDuration);
+    }
+
+    public static InferenceStats create(MeterRegistry meterRegistry) {
+        return new InferenceStats(
+            meterRegistry.registerLongCounter(
+                "es.inference.requests.count.total",
+                "Inference API request counts for a particular service, task type, model ID",
+                "operations"
+            ),
+            meterRegistry.registerLongHistogram(
+                "es.inference.requests.time",
+                "Inference API request counts for a particular service, task type, model ID",
+                "ms"
+            )
+        );
+    }
+
+    public static Map<String, Object> modelAttributes(Model model) {
+        return toMap(modelAttributeEntries(model));
+    }
+
+    private static Stream<Map.Entry<String, Object>> modelAttributeEntries(Model model) {
+        var stream = Stream.<Map.Entry<String, Object>>builder()
+            .add(entry("service", model.getConfigurations().getService()))
+            .add(entry("task_type", model.getTaskType().toString()));
+        if (model.getServiceSettings().modelId() != null) {
+            stream.add(entry("model_id", model.getServiceSettings().modelId()));
+        }
+        return stream.build();
+    }
+
+    private static Map<String, Object> toMap(Stream<Map.Entry<String, Object>> stream) {
+        return stream.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+    }
+
+    public static Map<String, Object> responseAttributes(Model model, @Nullable Throwable t) {
+        return toMap(concat(modelAttributeEntries(model), errorAttributes(t)));
+    }
+
+    public static Map<String, Object> responseAttributes(UnparsedModel model, @Nullable Throwable t) {
+        var unknownModelAttributes = Stream.<Map.Entry<String, Object>>builder()
+            .add(entry("service", model.service()))
+            .add(entry("task_type", model.taskType().toString()))
+            .build();
+
+        return toMap(concat(unknownModelAttributes, errorAttributes(t)));
+    }
+
+    public static Map<String, Object> responseAttributes(@Nullable Throwable t) {
+        return toMap(errorAttributes(t));
+    }
+
+    private static Stream<Map.Entry<String, Object>> errorAttributes(@Nullable Throwable t) {
+        if (t == null) {
+            return Stream.of(entry("status_code", 200));
+        } else if (t instanceof ElasticsearchStatusException ese) {
+            return Stream.<Map.Entry<String, Object>>builder()
+                .add(entry("status_code", ese.status().getStatus()))
+                .add(entry("error.type", String.valueOf(ese.status().getStatus())))
+                .build();
+        } else {
+            return Stream.of(entry("error.type", t.getClass().getSimpleName()));
+        }
+    }
 }
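A short usage sketch of the new record, mirroring how TransportInferenceAction uses it above (assumes `meterRegistry`, `model`, and an `InferenceTimer` named `timer` are in scope):

    var stats = InferenceStats.create(meterRegistry);

    // one count per request, tagged with service, task_type, and model_id
    stats.requestCount().incrementBy(1, InferenceStats.modelAttributes(model));

    // elapsed time in ms, tagged with model attributes plus status_code / error.type
    stats.inferenceDuration().record(timer.elapsedMillis(), InferenceStats.responseAttributes(model, null));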
