 
 package org.elasticsearch.xpack.inference.action;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -25,20 +28,22 @@
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
+import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
+import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;
 
-import java.util.Set;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
 
 public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> {
+    private static final Logger log = LogManager.getLogger(TransportInferenceAction.class);
     private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
     private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";
 
-    private static final Set<Class<? extends InferenceService>> supportsStreaming = Set.of();
-
     private final ModelRegistry modelRegistry;
     private final InferenceServiceRegistry serviceRegistry;
     private final InferenceStats inferenceStats;
@@ -62,17 +67,22 @@ public TransportInferenceAction(
 
     @Override
     protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
+        var timer = InferenceTimer.start();
 
-        ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> {
+        var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
             var service = serviceRegistry.getService(unparsedModel.service());
             if (service.isEmpty()) {
-                listener.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
+                var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
+                recordMetrics(unparsedModel, timer, e);
+                listener.onFailure(e);
                 return;
             }
 
             if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
                 // not the wildcard task type and not the model task type
-                listener.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()));
+                var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
+                recordMetrics(unparsedModel, timer, e);
+                listener.onFailure(e);
                 return;
             }
 
@@ -83,20 +93,69 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
                 unparsedModel.settings(),
                 unparsedModel.secrets()
             );
-            inferOnService(model, request, service.get(), delegate);
+            inferOnServiceWithMetrics(model, request, service.get(), timer, listener);
+        }, e -> {
+            try {
+                inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e));
+            } catch (Exception metricsException) {
+                log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics");
+            }
+            listener.onFailure(e);
         });
 
         modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
     }
 
-    private void inferOnService(
+    private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
+        try {
+            inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
+        } catch (Exception e) {
+            log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics");
+        }
+    }
+
+    private void inferOnServiceWithMetrics(
         Model model,
         InferenceAction.Request request,
         InferenceService service,
+        InferenceTimer timer,
         ActionListener<InferenceAction.Response> listener
+    ) {
+        inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
+        inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> {
+            if (request.isStreaming()) {
+                var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
+                inferenceResults.publisher().subscribe(taskProcessor);
+
+                var instrumentedStream = new PublisherWithMetrics(timer, model);
+                taskProcessor.subscribe(instrumentedStream);
+
+                listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream));
+            } else {
+                recordMetrics(model, timer, null);
+                listener.onResponse(new InferenceAction.Response(inferenceResults));
+            }
+        }, e -> {
+            recordMetrics(model, timer, e);
+            listener.onFailure(e);
+        }));
+    }
+
+    private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
+        try {
+            inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
+        } catch (Exception e) {
+            log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
+        }
+    }
+
+    private void inferOnService(
+        Model model,
+        InferenceAction.Request request,
+        InferenceService service,
+        ActionListener<InferenceServiceResults> listener
     ) {
         if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
-            inferenceStats.incrementRequestCount(model);
             service.infer(
                 model,
                 request.getQuery(),
@@ -105,7 +164,7 @@ private void inferOnService(
                 request.getTaskSettings(),
                 request.getInputType(),
                 request.getInferenceTimeout(),
-                createListener(request, listener)
+                listener
             );
         } else {
             listener.onFailure(unsupportedStreamingTaskException(request, service));
@@ -133,20 +192,6 @@ private ElasticsearchStatusException unsupportedStreamingTaskException(Inference
         }
     }
 
-    private ActionListener<InferenceServiceResults> createListener(
-        InferenceAction.Request request,
-        ActionListener<InferenceAction.Response> listener
-    ) {
-        if (request.isStreaming()) {
-            return listener.delegateFailureAndWrap((l, inferenceResults) -> {
-                var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
-                inferenceResults.publisher().subscribe(taskProcessor);
-                l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor));
-            });
-        }
-        return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults)));
-    }
-
 
     private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
         return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
     }
@@ -160,4 +205,37 @@ private static ElasticsearchStatusException incompatibleTaskTypeException(TaskTy
         );
     }
 
+    private class PublisherWithMetrics extends DelegatingProcessor<ChunkedToXContent, ChunkedToXContent> {
+        private final InferenceTimer timer;
+        private final Model model;
+
+        private PublisherWithMetrics(InferenceTimer timer, Model model) {
+            this.timer = timer;
+            this.model = model;
+        }
+
+        @Override
+        protected void next(ChunkedToXContent item) {
+            downstream().onNext(item);
+        }
+
+        @Override
+        public void onError(Throwable throwable) {
+            recordMetrics(model, timer, throwable);
+            super.onError(throwable);
+        }
+
+        @Override
+        protected void onCancel() {
+            recordMetrics(model, timer, null);
+            super.onCancel();
+        }
+
+        @Override
+        public void onComplete() {
+            recordMetrics(model, timer, null);
+            super.onComplete();
+        }
+    }
+
 }