88import org .apache .logging .log4j .LogManager ;
99import org .apache .logging .log4j .Logger ;
1010import org .apache .logging .log4j .message .ParameterizedMessage ;
11- import org .elasticsearch .Version ;
12- import org .elasticsearch .action .ActionListener ;
13- import org .elasticsearch .action .LatchedActionListener ;
1411import org .elasticsearch .common .Nullable ;
15- import org .elasticsearch .common .xcontent .XContentHelper ;
16- import org .elasticsearch .common .xcontent .json .JsonXContent ;
17- import org .elasticsearch .license .License ;
1812import org .elasticsearch .xpack .core .ml .dataframe .DataFrameAnalyticsConfig ;
19- import org .elasticsearch .xpack .core .ml .dataframe .analyses .Classification ;
20- import org .elasticsearch .xpack .core .ml .dataframe .analyses .Regression ;
2113import org .elasticsearch .xpack .core .ml .dataframe .stats .classification .ClassificationStats ;
2214import org .elasticsearch .xpack .core .ml .dataframe .stats .common .MemoryUsage ;
2315import org .elasticsearch .xpack .core .ml .dataframe .stats .outlierdetection .OutlierDetectionStats ;
2416import org .elasticsearch .xpack .core .ml .dataframe .stats .regression .RegressionStats ;
25- import org .elasticsearch .xpack .core .ml .inference .TrainedModelConfig ;
26- import org .elasticsearch .xpack .core .ml .inference .TrainedModelDefinition ;
27- import org .elasticsearch .xpack .core .ml .inference .TrainedModelInput ;
28- import org .elasticsearch .xpack .core .ml .inference .trainedmodel .ClassificationConfig ;
29- import org .elasticsearch .xpack .core .ml .inference .trainedmodel .InferenceConfig ;
30- import org .elasticsearch .xpack .core .ml .inference .trainedmodel .PredictionFieldType ;
31- import org .elasticsearch .xpack .core .ml .inference .trainedmodel .RegressionConfig ;
32- import org .elasticsearch .xpack .core .ml .inference .trainedmodel .TargetType ;
3317import org .elasticsearch .xpack .core .ml .job .messages .Messages ;
3418import org .elasticsearch .xpack .core .ml .utils .ExceptionsHelper ;
3519import org .elasticsearch .xpack .core .ml .utils .PhaseProgress ;
36- import org .elasticsearch .xpack .core .security .user .XPackUser ;
3720import org .elasticsearch .xpack .ml .dataframe .process .results .AnalyticsResult ;
3821import org .elasticsearch .xpack .ml .dataframe .process .results .RowResults ;
22+ import org .elasticsearch .xpack .ml .dataframe .process .results .TrainedModelDefinitionChunk ;
3923import org .elasticsearch .xpack .ml .dataframe .stats .StatsHolder ;
4024import org .elasticsearch .xpack .ml .dataframe .stats .StatsPersister ;
4125import org .elasticsearch .xpack .ml .extractor .ExtractedField ;
42- import org .elasticsearch .xpack .ml .extractor . MultiField ;
26+ import org .elasticsearch .xpack .ml .inference . modelsize . ModelSizeInfo ;
4327import org .elasticsearch .xpack .ml .inference .persistence .TrainedModelProvider ;
4428import org .elasticsearch .xpack .ml .notifications .DataFrameAnalyticsAuditor ;
4529
46- import java .time .Instant ;
4730import java .util .Collections ;
4831import java .util .Iterator ;
4932import java .util .List ;
50- import java .util .Map ;
5133import java .util .Objects ;
52- import java .util .Optional ;
5334import java .util .concurrent .CountDownLatch ;
54- import java .util .concurrent .TimeUnit ;
55- import java .util .stream .Collectors ;
5635
57- import static java .util .stream .Collectors .toList ;
5836
5937public class AnalyticsResultProcessor {
6038
@@ -80,6 +58,7 @@ public class AnalyticsResultProcessor {
8058 private final StatsPersister statsPersister ;
8159 private final List <ExtractedField > fieldNames ;
8260 private final CountDownLatch completionLatch = new CountDownLatch (1 );
61+ private final ChunkedTrainedModelPersister chunkedTrainedModelPersister ;
8362 private volatile String failure ;
8463 private volatile boolean isCancelled ;
8564
@@ -93,6 +72,13 @@ public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRow
9372 this .auditor = Objects .requireNonNull (auditor );
9473 this .statsPersister = Objects .requireNonNull (statsPersister );
9574 this .fieldNames = Collections .unmodifiableList (Objects .requireNonNull (fieldNames ));
75+ this .chunkedTrainedModelPersister = new ChunkedTrainedModelPersister (
76+ trainedModelProvider ,
77+ analytics ,
78+ auditor ,
79+ this ::setAndReportFailure ,
80+ fieldNames
81+ );
9682 }
9783
9884 @ Nullable
@@ -171,9 +157,13 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo
171157 phaseProgress .getProgressPercent ());
172158 statsHolder .getProgressTracker ().updatePhase (phaseProgress );
173159 }
174- TrainedModelDefinition .Builder inferenceModelBuilder = result .getInferenceModelBuilder ();
175- if (inferenceModelBuilder != null ) {
176- createAndIndexInferenceModel (inferenceModelBuilder );
160+ ModelSizeInfo modelSize = result .getModelSizeInfo ();
161+ if (modelSize != null ) {
162+ chunkedTrainedModelPersister .createAndIndexInferenceModelMetadata (modelSize );
163+ }
164+ TrainedModelDefinitionChunk trainedModelDefinitionChunk = result .getTrainedModelDefinitionChunk ();
165+ if (trainedModelDefinitionChunk != null ) {
166+ chunkedTrainedModelPersister .createAndIndexInferenceModelDoc (trainedModelDefinitionChunk );
177167 }
178168 MemoryUsage memoryUsage = result .getMemoryUsage ();
179169 if (memoryUsage != null ) {
@@ -197,117 +187,6 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo
197187 }
198188 }
199189
200- private void createAndIndexInferenceModel (TrainedModelDefinition .Builder inferenceModel ) { // Builds the trained model config, starts storing it, then blocks until the returned latch is released or 30 seconds elapse.
201- TrainedModelConfig trainedModelConfig = createTrainedModelConfig (inferenceModel );
202- CountDownLatch latch = storeTrainedModel (trainedModelConfig ); // the store call is issued inside storeTrainedModel; only the wait happens here
203-
204- try {
205- if (latch .await (30 , TimeUnit .SECONDS ) == false ) { // await returns false when the timeout expires first
206- LOGGER .error ("[{}] Timed out (30s) waiting for inference model to be stored" , analytics .getId ()); // NOTE(review): the timeout path only logs; setAndReportFailure is not called here
207- }
208- } catch (InterruptedException e ) {
209- Thread .currentThread ().interrupt (); // restore the interrupt flag before reporting the failure
210- setAndReportFailure (ExceptionsHelper .serverError ("interrupted waiting for inference model to be stored" ));
211- }
212- }
213-
214- private TrainedModelConfig createTrainedModelConfig (TrainedModelDefinition .Builder inferenceModel ) { // Assembles the TrainedModelConfig for the given model definition builder, using the analytics config and the extracted field names.
215- Instant createTime = Instant .now (); // one timestamp is used for both the model id and the create time
216- String modelId = analytics .getId () + "-" + createTime .toEpochMilli (); // model id is "<analytics id>-<creation epoch millis>"
217- TrainedModelDefinition definition = inferenceModel .build (); // built once here to read heap size, operation count and target type below
218- String dependentVariable = getDependentVariable (); // null when the analysis is neither Classification nor Regression
219- List <String > fieldNamesWithoutDependentVariable = fieldNames .stream () // model input: every extracted field name except the dependent variable
220- .map (ExtractedField ::getName )
221- .filter (f -> f .equals (dependentVariable ) == false )
222- .collect (toList ());
223- Map <String , String > defaultFieldMapping = fieldNames .stream () // parent field -> field name, for MultiField entries other than the dependent variable
224- .filter (ef -> ef instanceof MultiField && (ef .getName ().equals (dependentVariable ) == false ))
225- .collect (Collectors .toMap (ExtractedField ::getParentField , ExtractedField ::getName ));
226- return TrainedModelConfig .builder ()
227- .setModelId (modelId )
228- .setCreatedBy (XPackUser .NAME )
229- .setVersion (Version .CURRENT )
230- .setCreateTime (createTime )
231- // NOTE: GET _cat/ml/trained_models relies on the creating analytics ID being in the tags
232- .setTags (Collections .singletonList (analytics .getId ()))
233- .setDescription (analytics .getDescription ())
234- .setMetadata (Collections .singletonMap ("analytics_config" , // the analytics config's string form, parsed as JSON into a map, is stored under "analytics_config"
235- XContentHelper .convertToMap (JsonXContent .jsonXContent , analytics .toString (), true )))
236- .setEstimatedHeapMemory (definition .ramBytesUsed ())
237- .setEstimatedOperations (definition .getTrainedModel ().estimatedNumOperations ())
238- .setParsedDefinition (inferenceModel ) // the builder itself is passed here, not the built definition
239- .setInput (new TrainedModelInput (fieldNamesWithoutDependentVariable ))
240- .setLicenseLevel (License .OperationMode .PLATINUM .description ())
241- .setDefaultFieldMap (defaultFieldMapping )
242- .setInferenceConfig (buildInferenceConfig (definition .getTrainedModel ().targetType ()))
243- .build ();
244- }
245-
246- private InferenceConfig buildInferenceConfig (TargetType targetType ) { // Builds the inference config matching the model's target type, taking its settings from the analytics analysis.
247- switch (targetType ) {
248- case CLASSIFICATION :
249- assert analytics .getAnalysis () instanceof Classification ; // only checked when assertions are enabled; the cast on the next line is unconditional
250- Classification classification = ((Classification )analytics .getAnalysis ());
251- PredictionFieldType predictionFieldType = getPredictionFieldType (classification );
252- return ClassificationConfig .builder ()
253- .setNumTopClasses (classification .getNumTopClasses ())
254- .setNumTopFeatureImportanceValues (classification .getBoostedTreeParams ().getNumTopFeatureImportanceValues ())
255- .setPredictionFieldType (predictionFieldType )
256- .build ();
257- case REGRESSION :
258- assert analytics .getAnalysis () instanceof Regression ; // only checked when assertions are enabled; the cast on the next line is unconditional
259- Regression regression = ((Regression )analytics .getAnalysis ());
260- return RegressionConfig .builder ()
261- .setNumTopFeatureImportanceValues (regression .getBoostedTreeParams ().getNumTopFeatureImportanceValues ())
262- .build ();
263- default : // any other target type is rejected with a server error
264- throw ExceptionsHelper .serverError (
265- "process created a model with an unsupported target type [{}]" ,
266- null , // presumably the (absent) cause argument — confirm against ExceptionsHelper.serverError
267- targetType );
268- }
269- }
270-
271- PredictionFieldType getPredictionFieldType (Classification classification ) { // Derives the prediction field type from the types of the extracted field named after the dependent variable; declared without an access modifier (package-private).
272- String dependentVariable = classification .getDependentVariable ();
273- Optional <ExtractedField > extractedField = fieldNames .stream () // look for an extracted field whose name equals the dependent variable
274- .filter (f -> f .getName ().equals (dependentVariable ))
275- .findAny ();
276- PredictionFieldType predictionFieldType = Classification .getPredictionFieldType (
277- extractedField .isPresent () ? extractedField .get ().getTypes () : null // null is passed when no extracted field matches
278- );
279- return predictionFieldType == null ? PredictionFieldType .STRING : predictionFieldType ; // fall back to STRING when no type was resolved
280- }
281-
282- private String getDependentVariable () { // Returns the dependent variable of the analytics analysis when it is a Classification or a Regression.
283- if (analytics .getAnalysis () instanceof Classification ) {
284- return ((Classification )analytics .getAnalysis ()).getDependentVariable ();
285- }
286- if (analytics .getAnalysis () instanceof Regression ) {
287- return ((Regression )analytics .getAnalysis ()).getDependentVariable ();
288- }
289- return null ; // every other analysis type yields null
290- }
291-
292- private CountDownLatch storeTrainedModel (TrainedModelConfig trainedModelConfig ) { // Asks trainedModelProvider to store the config and returns a latch the caller can wait on.
293- CountDownLatch latch = new CountDownLatch (1 );
294- ActionListener <Boolean > storeListener = ActionListener .wrap (
295- aBoolean -> {
296- if (aBoolean == false ) { // a false response is logged and reported as a failure
297- LOGGER .error ("[{}] Storing trained model responded false" , analytics .getId ());
298- setAndReportFailure (ExceptionsHelper .serverError ("storing trained model responded false" ));
299- } else { // success is written to both the log and the auditor
300- LOGGER .info ("[{}] Stored trained model with id [{}]" , analytics .getId (), trainedModelConfig .getModelId ());
301- auditor .info (analytics .getId (), "Stored trained model with id [" + trainedModelConfig .getModelId () + "]" );
302- }
303- },
304- e -> setAndReportFailure (ExceptionsHelper .serverError ("error storing trained model with id [{}]" , e , // a store error is wrapped with the model id and reported as a failure
305- trainedModelConfig .getModelId ()))
306- );
307- trainedModelProvider .storeTrainedModel (trainedModelConfig , new LatchedActionListener <>(storeListener , latch )); // presumably LatchedActionListener counts the latch down once storeListener has run — confirm
308- return latch ;
309- }
310-
311190 private void setAndReportFailure (Exception e ) {
312191 LOGGER .error (new ParameterizedMessage ("[{}] Error processing results; " , analytics .getId ()), e );
313192 failure = "error processing results; " + e .getMessage ();
0 commit comments