@@ -7,13 +7,16 @@
 
 package org.elasticsearch.xpack.inference.action;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.common.logging.DeprecationLogger;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -26,20 +29,22 @@
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
+import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
+import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;
 
-import java.util.Set;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
 
 public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> {
+    private static final Logger log = LogManager.getLogger(TransportInferenceAction.class);
     private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
     private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";
 
-    private static final Set<Class<? extends InferenceService>> supportsStreaming = Set.of();
-
     private final ModelRegistry modelRegistry;
     private final InferenceServiceRegistry serviceRegistry;
    private final InferenceStats inferenceStats;
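The telemetry changes below hinge on the new InferenceTimer. Only its start() and elapsedMillis() calls are visible in this diff; here is a minimal sketch of the implied contract, assuming a monotonic System.nanoTime() clock (an illustrative assumption, not confirmed by the change):

// Hypothetical sketch of the InferenceTimer contract implied by this diff.
// Only start() and elapsedMillis() appear in the change; the nanoTime-based
// clock is an assumption for illustration.
record InferenceTimerSketch(long startNanos) {
    static InferenceTimerSketch start() {
        return new InferenceTimerSketch(System.nanoTime()); // capture monotonic start point
    }

    long elapsedMillis() {
        return (System.nanoTime() - startNanos) / 1_000_000L; // elapsed duration in ms
    }
}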
@@ -64,17 +69,22 @@ public TransportInferenceAction(
 
     @Override
     protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
+        var timer = InferenceTimer.start();
 
-        ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> {
+        var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
             var service = serviceRegistry.getService(unparsedModel.service());
             if (service.isEmpty()) {
-                listener.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
+                var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
+                recordMetrics(unparsedModel, timer, e);
+                listener.onFailure(e);
                 return;
             }
 
             if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
                 // not the wildcard task type and not the model task type
-                listener.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()));
+                var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
+                recordMetrics(unparsedModel, timer, e);
+                listener.onFailure(e);
                 return;
             }
 
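Worth calling out: swapping listener.delegateFailureAndWrap(...) for ActionListener.wrap(...) is what makes failure-path metrics possible, because delegateFailureAndWrap forwards failures straight to the outer listener, while wrap takes an explicit failure handler that runs first. A sketch of the pattern (handler bodies are placeholders, not the production logic):

// Sketch: ActionListener.wrap exposes the failure path, so metrics can be
// recorded before the error is propagated to the original listener.
var getModelListener = ActionListener.wrap(
    (UnparsedModel unparsedModel) -> { /* resolve the service, then infer */ },
    e -> {
        // record duration/error attributes here, then delegate
        listener.onFailure(e);
    }
);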
@@ -85,20 +95,69 @@ protected void doExecute(Task task, InferenceAction.Request request, ActionListe
                 unparsedModel.settings(),
                 unparsedModel.secrets()
             );
-            inferOnService(model, request, service.get(), delegate);
+            inferOnServiceWithMetrics(model, request, service.get(), timer, listener);
+        }, e -> {
+            try {
+                inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e));
+            } catch (Exception metricsException) {
+                log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics");
+            }
+            listener.onFailure(e);
         });
 
         modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
     }
 
-    private void inferOnService(
+    private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
+        try {
+            inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
+        } catch (Exception e) {
+            log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics");
+        }
+    }
+
+    private void inferOnServiceWithMetrics(
         Model model,
         InferenceAction.Request request,
         InferenceService service,
+        InferenceTimer timer,
         ActionListener<InferenceAction.Response> listener
+    ) {
+        inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
+        inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> {
+            if (request.isStreaming()) {
+                var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
+                inferenceResults.publisher().subscribe(taskProcessor);
+
+                var instrumentedStream = new PublisherWithMetrics(timer, model);
+                taskProcessor.subscribe(instrumentedStream);
+
+                listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream));
+            } else {
+                recordMetrics(model, timer, null);
+                listener.onResponse(new InferenceAction.Response(inferenceResults));
+            }
+        }, e -> {
+            recordMetrics(model, timer, e);
+            listener.onFailure(e);
+        }));
+    }
+
+    private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
+        try {
+            inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
+        } catch (Exception e) {
+            log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
+        }
+    }
+
+    private void inferOnService(
+        Model model,
+        InferenceAction.Request request,
+        InferenceService service,
+        ActionListener<InferenceServiceResults> listener
     ) {
         if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
-            inferenceStats.incrementRequestCount(model);
             service.infer(
                 model,
                 request.getQuery(),
@@ -107,7 +166,7 @@ private void inferOnService(
                 request.getTaskSettings(),
                 request.getInputType(),
                 request.getInferenceTimeout(),
-                createListener(request, listener)
+                listener
             );
         } else {
             listener.onFailure(unsupportedStreamingTaskException(request, service));
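Note the per-request counter also moved: inferenceStats.incrementRequestCount(model) previously ran only on the branch that actually calls service.infer, whereas inferenceStats.requestCount().incrementBy(1, modelAttributes(model)) in inferOnServiceWithMetrics now counts every request, including ones rejected for unsupported streaming.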
@@ -135,20 +194,6 @@ private ElasticsearchStatusException unsupportedStreamingTaskException(Inference
         }
     }
 
-    private ActionListener<InferenceServiceResults> createListener(
-        InferenceAction.Request request,
-        ActionListener<InferenceAction.Response> listener
-    ) {
-        if (request.isStreaming()) {
-            return listener.delegateFailureAndWrap((l, inferenceResults) -> {
-                var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
-                inferenceResults.publisher().subscribe(taskProcessor);
-                l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor));
-            });
-        }
-        return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults)));
-    }
-
     private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
         return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
     }
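The deleted createListener is not lost behavior: its streaming wiring now lives in inferOnServiceWithMetrics, where the task processor is additionally subscribed to a PublisherWithMetrics (added below) so that completion, cancellation, and errors of the stream all record the inference duration.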
@@ -162,4 +207,37 @@ private static ElasticsearchStatusException incompatibleTaskTypeException(TaskTy
         );
     }
 
+    private class PublisherWithMetrics extends DelegatingProcessor<ChunkedToXContent, ChunkedToXContent> {
+        private final InferenceTimer timer;
+        private final Model model;
+
+        private PublisherWithMetrics(InferenceTimer timer, Model model) {
+            this.timer = timer;
+            this.model = model;
+        }
+
+        @Override
+        protected void next(ChunkedToXContent item) {
+            downstream().onNext(item);
+        }
+
+        @Override
+        public void onError(Throwable throwable) {
+            recordMetrics(model, timer, throwable);
+            super.onError(throwable);
+        }
+
+        @Override
+        protected void onCancel() {
+            recordMetrics(model, timer, null);
+            super.onCancel();
+        }
+
+        @Override
+        public void onComplete() {
+            recordMetrics(model, timer, null);
+            super.onComplete();
+        }
+    }
+
 }
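PublisherWithMetrics builds on DelegatingProcessor from org.elasticsearch.xpack.inference.common, which this diff does not show. Below is a minimal sketch of the contract the subclass appears to rely on, a pass-through java.util.concurrent.Flow.Processor that exposes downstream() and offers next/onCancel hooks; the real class very likely differs in backpressure, thread-safety, and error handling:

import java.util.concurrent.Flow;

// Hypothetical stand-in for DelegatingProcessor, reconstructed only from how
// PublisherWithMetrics uses it above. Not the actual Elasticsearch class.
abstract class DelegatingProcessorSketch<T, R> implements Flow.Processor<T, R> {
    private Flow.Subscriber<? super R> downstream;
    private Flow.Subscription upstream;

    // Subclasses forward (or transform) each item to downstream().
    protected abstract void next(T item);

    // Hook invoked when the downstream subscriber cancels.
    protected void onCancel() {}

    protected Flow.Subscriber<? super R> downstream() {
        return downstream;
    }

    @Override
    public void subscribe(Flow.Subscriber<? super R> subscriber) {
        this.downstream = subscriber;
        subscriber.onSubscribe(new Flow.Subscription() {
            @Override
            public void request(long n) {
                if (upstream != null) upstream.request(n); // relay demand upstream
            }

            @Override
            public void cancel() {
                onCancel(); // metrics hook fires before the cancel propagates
                if (upstream != null) upstream.cancel();
            }
        });
    }

    @Override
    public void onSubscribe(Flow.Subscription subscription) {
        this.upstream = subscription;
    }

    @Override
    public void onNext(T item) {
        next(item);
    }

    @Override
    public void onError(Throwable throwable) {
        if (downstream != null) downstream.onError(throwable);
    }

    @Override
    public void onComplete() {
        if (downstream != null) downstream.onComplete();
    }
}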