Commit 26870ef
[ML] Inference duration and error metrics (elastic#115876)
Add an `es.inference.requests.time` metric around the `infer` API. As recommended by the OTel spec, errors are determined by the presence or absence of the `error.type` attribute on the metric: `error.type` is the HTTP status code (as a string) when one is available, otherwise the name of the exception (e.g. `NullPointerException`).

Additional notes:
- `ApmInferenceStats` is merged into `InferenceStats`. Originally we planned to have multiple implementations, but now we're only using APM.
- The request count is now always recorded, even when loading the endpoint configuration fails.
- Added a hook in streaming for cancel messages, so we can close the metrics when a user cancels the stream.
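For illustration, the attribute sets produced by the new `InferenceStats` helpers (see the `errorAttributes` switch in the diff below) come out roughly as follows. This is a sketch only: the service, task type, and model values are hypothetical examples, not output captured from the commit.

```java
import java.util.Map;

// Hypothetical attribute maps showing the error.type convention this commit
// adopts; "openai", "completion", and "gpt-4o" are example values only.
class ErrorTypeConventionExample {
    // Success: no error.type attribute, so the datapoint counts as a success.
    static final Map<String, Object> SUCCESS = Map.of(
        "service", "openai", "task_type", "completion", "model_id", "gpt-4o",
        "status_code", 200
    );

    // HTTP-level failure: error.type is the status code rendered as a string.
    static final Map<String, Object> HTTP_FAILURE = Map.of(
        "service", "openai", "task_type", "completion", "model_id", "gpt-4o",
        "status_code", 404, "error.type", "404"
    );

    // Non-HTTP failure: error.type is the exception's simple class name.
    static final Map<String, Object> UNEXPECTED_FAILURE = Map.of(
        "service", "openai", "task_type", "completion", "model_id", "gpt-4o",
        "error.type", "NullPointerException"
    );
}
```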
1 parent 38c7ddd commit 26870ef

File tree: 11 files changed (+826, −151 lines)

docs/changelog/115876.yaml

Lines changed: 5 additions & 0 deletions
```diff
@@ -0,0 +1,5 @@
+pr: 115876
+summary: Inference duration and error metrics
+area: Machine Learning
+type: enhancement
+issues: []
```

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 1 addition & 2 deletions
```diff
@@ -101,7 +101,6 @@
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
 import org.elasticsearch.xpack.inference.services.mistral.MistralService;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
-import org.elasticsearch.xpack.inference.telemetry.ApmInferenceStats;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
 
 import java.util.ArrayList;
@@ -239,7 +238,7 @@ public Collection<?> createComponents(PluginServices services) {
         shardBulkInferenceActionFilter.set(actionFilter);
 
         var meterRegistry = services.telemetryProvider().getMeterRegistry();
-        var stats = new PluginComponentBinding<>(InferenceStats.class, ApmInferenceStats.create(meterRegistry));
+        var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
 
         return List.of(modelRegistry, registry, httpClientManager, stats);
     }
```

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java

Lines changed: 102 additions & 24 deletions
```diff
@@ -7,12 +7,15 @@
 
 package org.elasticsearch.xpack.inference.action;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -25,20 +28,22 @@
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
+import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
+import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;
 
-import java.util.Set;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
 
 public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> {
+    private static final Logger log = LogManager.getLogger(TransportInferenceAction.class);
     private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
     private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";
 
-    private static final Set<Class<? extends InferenceService>> supportsStreaming = Set.of();
-
     private final ModelRegistry modelRegistry;
     private final InferenceServiceRegistry serviceRegistry;
     private final InferenceStats inferenceStats;
@@ -62,17 +67,22 @@ public TransportInferenceAction(
 
     @Override
     protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
+        var timer = InferenceTimer.start();
 
-        ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> {
+        var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
             var service = serviceRegistry.getService(unparsedModel.service());
             if (service.isEmpty()) {
-                listener.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
+                var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
+                recordMetrics(unparsedModel, timer, e);
+                listener.onFailure(e);
                 return;
             }
 
             if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
                 // not the wildcard task type and not the model task type
-                listener.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()));
+                var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
+                recordMetrics(unparsedModel, timer, e);
+                listener.onFailure(e);
                 return;
             }
 
@@ -83,20 +93,69 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe
                 unparsedModel.settings(),
                 unparsedModel.secrets()
             );
-            inferOnService(model, request, service.get(), delegate);
+            inferOnServiceWithMetrics(model, request, service.get(), timer, listener);
+        }, e -> {
+            try {
+                inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e));
+            } catch (Exception metricsException) {
+                log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics");
+            }
+            listener.onFailure(e);
         });
 
         modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
     }
 
-    private void inferOnService(
+    private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
+        try {
+            inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
+        } catch (Exception e) {
+            log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics");
+        }
+    }
+
+    private void inferOnServiceWithMetrics(
         Model model,
         InferenceAction.Request request,
         InferenceService service,
+        InferenceTimer timer,
         ActionListener<InferenceAction.Response> listener
+    ) {
+        inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
+        inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> {
+            if (request.isStreaming()) {
+                var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
+                inferenceResults.publisher().subscribe(taskProcessor);
+
+                var instrumentedStream = new PublisherWithMetrics(timer, model);
+                taskProcessor.subscribe(instrumentedStream);
+
+                listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream));
+            } else {
+                recordMetrics(model, timer, null);
+                listener.onResponse(new InferenceAction.Response(inferenceResults));
+            }
+        }, e -> {
+            recordMetrics(model, timer, e);
+            listener.onFailure(e);
+        }));
+    }
+
+    private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
+        try {
+            inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
+        } catch (Exception e) {
+            log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
+        }
+    }
+
+    private void inferOnService(
+        Model model,
+        InferenceAction.Request request,
+        InferenceService service,
+        ActionListener<InferenceServiceResults> listener
     ) {
         if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
-            inferenceStats.incrementRequestCount(model);
             service.infer(
                 model,
                 request.getQuery(),
@@ -105,7 +164,7 @@ private void inferOnService(
                 request.getTaskSettings(),
                 request.getInputType(),
                 request.getInferenceTimeout(),
-                createListener(request, listener)
+                listener
            );
        } else {
            listener.onFailure(unsupportedStreamingTaskException(request, service));
@@ -133,20 +192,6 @@ private ElasticsearchStatusException unsupportedStreamingTaskException(Inference
        }
    }
 
-    private ActionListener<InferenceServiceResults> createListener(
-        InferenceAction.Request request,
-        ActionListener<InferenceAction.Response> listener
-    ) {
-        if (request.isStreaming()) {
-            return listener.delegateFailureAndWrap((l, inferenceResults) -> {
-                var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
-                inferenceResults.publisher().subscribe(taskProcessor);
-                l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor));
-            });
-        }
-        return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults)));
-    }
-
     private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
         return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
     }
@@ -160,4 +205,37 @@ private static ElasticsearchStatusException incompatibleTaskTypeException(TaskTy
         );
     }
 
+    private class PublisherWithMetrics extends DelegatingProcessor<ChunkedToXContent, ChunkedToXContent> {
+        private final InferenceTimer timer;
+        private final Model model;
+
+        private PublisherWithMetrics(InferenceTimer timer, Model model) {
+            this.timer = timer;
+            this.model = model;
+        }
+
+        @Override
+        protected void next(ChunkedToXContent item) {
+            downstream().onNext(item);
+        }
+
+        @Override
+        public void onError(Throwable throwable) {
+            recordMetrics(model, timer, throwable);
+            super.onError(throwable);
+        }
+
+        @Override
+        protected void onCancel() {
+            recordMetrics(model, timer, null);
+            super.onCancel();
+        }
+
+        @Override
+        public void onComplete() {
+            recordMetrics(model, timer, null);
+            super.onComplete();
+        }
+    }
+
 }
```
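`InferenceTimer` is used throughout this file but its implementation is not part of this diff. Below is a minimal sketch of what such a timer could look like, assuming only the `start()`/`elapsedMillis()` surface the action relies on; the real class may differ.

```java
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;

// Sketch only: mirrors the start()/elapsedMillis() calls seen in
// TransportInferenceAction; not the actual InferenceTimer source.
record InferenceTimerSketch(Instant startTime, Clock clock) {
    static InferenceTimerSketch start() {
        var clock = Clock.systemUTC();
        return new InferenceTimerSketch(Instant.now(clock), clock);
    }

    long elapsedMillis() {
        // Duration between the recorded start and "now" on the same clock.
        return Duration.between(startTime, Instant.now(clock)).toMillis();
    }
}
```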

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java

Lines changed: 3 additions & 0 deletions
```diff
@@ -61,11 +61,14 @@ public void request(long n) {
             public void cancel() {
                 if (isClosed.compareAndSet(false, true) && upstream != null) {
                     upstream.cancel();
+                    onCancel();
                 }
             }
         };
     }
 
+    protected void onCancel() {}
+
     @Override
     public void onSubscribe(Flow.Subscription subscription) {
         if (upstream != null) {
```
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStats.java

Lines changed: 0 additions & 49 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java

Lines changed: 79 additions & 7 deletions
```diff
@@ -7,15 +7,87 @@
 
 package org.elasticsearch.xpack.inference.telemetry;
 
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.UnparsedModel;
+import org.elasticsearch.telemetry.metric.LongCounter;
+import org.elasticsearch.telemetry.metric.LongHistogram;
+import org.elasticsearch.telemetry.metric.MeterRegistry;
 
-public interface InferenceStats {
+import java.util.Map;
+import java.util.Objects;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
-    /**
-     * Increment the counter for a particular value in a thread safe manner.
-     * @param model the model to increment request count for
-     */
-    void incrementRequestCount(Model model);
+import static java.util.Map.entry;
+import static java.util.stream.Stream.concat;
 
-    InferenceStats NOOP = model -> {};
+public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) {
+
+    public InferenceStats {
+        Objects.requireNonNull(requestCount);
+        Objects.requireNonNull(inferenceDuration);
+    }
+
+    public static InferenceStats create(MeterRegistry meterRegistry) {
+        return new InferenceStats(
+            meterRegistry.registerLongCounter(
+                "es.inference.requests.count.total",
+                "Inference API request counts for a particular service, task type, model ID",
+                "operations"
+            ),
+            meterRegistry.registerLongHistogram(
+                "es.inference.requests.time",
+                "Inference API request counts for a particular service, task type, model ID",
+                "ms"
+            )
+        );
+    }
+
+    public static Map<String, Object> modelAttributes(Model model) {
+        return toMap(modelAttributeEntries(model));
+    }
+
+    private static Stream<Map.Entry<String, Object>> modelAttributeEntries(Model model) {
+        var stream = Stream.<Map.Entry<String, Object>>builder()
+            .add(entry("service", model.getConfigurations().getService()))
+            .add(entry("task_type", model.getTaskType().toString()));
+        if (model.getServiceSettings().modelId() != null) {
+            stream.add(entry("model_id", model.getServiceSettings().modelId()));
+        }
+        return stream.build();
+    }
+
+    private static Map<String, Object> toMap(Stream<Map.Entry<String, Object>> stream) {
+        return stream.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+    }
+
+    public static Map<String, Object> responseAttributes(Model model, @Nullable Throwable t) {
+        return toMap(concat(modelAttributeEntries(model), errorAttributes(t)));
+    }
+
+    public static Map<String, Object> responseAttributes(UnparsedModel model, @Nullable Throwable t) {
+        var unknownModelAttributes = Stream.<Map.Entry<String, Object>>builder()
+            .add(entry("service", model.service()))
+            .add(entry("task_type", model.taskType().toString()))
+            .build();
+
+        return toMap(concat(unknownModelAttributes, errorAttributes(t)));
+    }
+
+    public static Map<String, Object> responseAttributes(@Nullable Throwable t) {
+        return toMap(errorAttributes(t));
+    }
+
+    private static Stream<Map.Entry<String, Object>> errorAttributes(@Nullable Throwable t) {
+        return switch (t) {
+            case null -> Stream.of(entry("status_code", 200));
+            case ElasticsearchStatusException ese -> Stream.<Map.Entry<String, Object>>builder()
+                .add(entry("status_code", ese.status().getStatus()))
+                .add(entry("error.type", String.valueOf(ese.status().getStatus())))
+                .build();
+            default -> Stream.of(entry("error.type", t.getClass().getSimpleName()));
+        };
+    }
 }
```
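Putting the pieces together, the recording pattern the action now follows looks roughly like this; `stats`, `model`, and `doInference` are stand-ins for illustration, not names from the commit.

```java
import java.util.function.Supplier;

// Sketch of the metrics pattern in TransportInferenceAction: count every
// request up front, then record duration with success or error attributes.
static InferenceServiceResults recordAroundInference(
    InferenceStats stats,
    Model model,
    Supplier<InferenceServiceResults> doInference // hypothetical inference call
) {
    stats.requestCount().incrementBy(1, InferenceStats.modelAttributes(model)); // counted even on failure
    var timer = InferenceTimer.start();
    try {
        var results = doInference.get();
        stats.inferenceDuration().record(timer.elapsedMillis(), InferenceStats.responseAttributes(model, null));
        return results;
    } catch (Exception e) {
        stats.inferenceDuration().record(timer.elapsedMillis(), InferenceStats.responseAttributes(model, e));
        throw e;
    }
}
```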
