Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/88945.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 88945
summary: Address potential bug where trained models get stuck in starting after being
  allocated to node
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,13 @@ void loadQueuedModels() {
} catch (Exception ex) {
logger.warn(() -> "[" + modelId + "] Start deployment failed", ex);
if (ExceptionsHelper.unwrapCause(ex) instanceof ResourceNotFoundException) {
logger.warn(() -> "[" + modelId + "] Start deployment failed", ex);
logger.debug(() -> "[" + modelId + "] Start deployment failed as model was not found", ex);
handleLoadFailure(loadingTask, ExceptionsHelper.missingTrainedModel(modelId, ex));
} else if (ExceptionsHelper.unwrapCause(ex) instanceof SearchPhaseExecutionException) {
logger.trace(() -> "[" + modelId + "] Start deployment failed, will retry", ex);
logger.debug(() -> "[" + modelId + "] Start deployment failed, will retry", ex);
// A search phase execution failure should be retried, push task back to the queue
loadingToRetry.add(loadingTask);
} else {
logger.warn(() -> "[" + modelId + "] Start deployment failed", ex);
handleLoadFailure(loadingTask, ex);
}
}
Expand Down Expand Up @@ -413,7 +412,7 @@ private void updateNumberOfAllocations(TrainedModelAssignmentMetadata assignment
for (TrainedModelAssignment assignment : modelsToUpdate) {
TrainedModelDeploymentTask task = modelIdToTask.get(assignment.getModelId());
if (task == null) {
logger.debug(() -> format("[%s] task was removed whilst updating number of allocations", task.getModelId()));
logger.debug(() -> format("[%s] task was removed whilst updating number of allocations", assignment.getModelId()));
continue;
}
RoutingInfo routingInfo = assignment.getNodeRoutingTable().get(nodeId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;

Expand Down Expand Up @@ -149,34 +150,46 @@ private void doStartDeployment(TrainedModelDeploymentTask task, ActionListener<T
TrainedModelConfig modelConfig = getModelResponse.getResources().results().get(0);
processContext.modelInput.set(modelConfig.getInput());

assert modelConfig.getInferenceConfig() instanceof NlpConfig;
NlpConfig nlpConfig = (NlpConfig) modelConfig.getInferenceConfig();
task.init(nlpConfig);

SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchVocabResponse -> {
if (searchVocabResponse.getHits().getHits().length == 0) {
listener.onFailure(
new ResourceNotFoundException(
Messages.getMessage(
Messages.VOCABULARY_NOT_FOUND,
task.getModelId(),
VocabularyConfig.docId(modelConfig.getModelId())
if (modelConfig.getInferenceConfig()instanceof NlpConfig nlpConfig) {
task.init(nlpConfig);

SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchVocabResponse -> {
if (searchVocabResponse.getHits().getHits().length == 0) {
listener.onFailure(
new ResourceNotFoundException(
Messages.getMessage(
Messages.VOCABULARY_NOT_FOUND,
task.getModelId(),
VocabularyConfig.docId(modelConfig.getModelId())
)
)
)
);
return;
}

Vocabulary vocabulary = parseVocabularyDocLeniently(searchVocabResponse.getHits().getAt(0));
NlpTask nlpTask = new NlpTask(nlpConfig, vocabulary);
NlpTask.Processor processor = nlpTask.createProcessor();
processContext.nlpTaskProcessor.set(processor);
// here, we are being called back on the searching thread, which MAY be a network thread
// `startAndLoad` creates named pipes, blocking the calling thread, better to execute that in our utility
// executor.
executorServiceForDeployment.execute(
() -> startAndLoad(processContext, modelConfig.getLocation(), modelLoadedListener)
);
return;
}

Vocabulary vocabulary = parseVocabularyDocLeniently(searchVocabResponse.getHits().getAt(0));
NlpTask nlpTask = new NlpTask(nlpConfig, vocabulary);
NlpTask.Processor processor = nlpTask.createProcessor();
processContext.nlpTaskProcessor.set(processor);
// here, we are being called back on the searching thread, which MAY be a network thread
// `startAndLoad` creates named pipes, blocking the calling thread, better to execute that in our utility
// executor.
executorServiceForDeployment.execute(() -> startAndLoad(processContext, modelConfig.getLocation(), modelLoadedListener));
}, listener::onFailure));
}, listener::onFailure));
} else {
listener.onFailure(
new IllegalArgumentException(
format(
"[%s] must be an pytorch model found inference config of kind [%s]",
modelConfig.getModelId(),
modelConfig.getInferenceConfig().getWriteableName()
)
)
);
}
}, listener::onFailure);

executeAsyncWithOrigin(
Expand Down Expand Up @@ -404,10 +417,12 @@ private Consumer<String> onProcessCrash() {
}

void loadModel(TrainedModelLocation modelLocation, ActionListener<Boolean> listener) {
if (modelLocation instanceof IndexLocation) {
process.get().loadModel(task.getModelId(), ((IndexLocation) modelLocation).getIndexName(), stateStreamer, listener);
if (modelLocation instanceof IndexLocation indexLocation) {
process.get().loadModel(task.getModelId(), indexLocation.getIndexName(), stateStreamer, listener);
} else {
throw new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]");
listener.onFailure(
new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]")
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.CheckedFunction;
Expand All @@ -38,8 +38,10 @@
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.xpack.ml.MachineLearning.NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME;
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;

/**
* Searches for and emits {@link TrainedModelDefinitionDoc}s in
Expand Down Expand Up @@ -71,7 +73,7 @@ public ChunkedTrainedModelRestorer(
ExecutorService executorService,
NamedXContentRegistry xContentRegistry
) {
this.client = client;
this.client = new OriginSettingClient(client, ML_ORIGIN);
this.executorService = executorService;
this.xContentRegistry = xContentRegistry;
this.modelId = modelId;
Expand Down Expand Up @@ -122,7 +124,6 @@ public void restoreModelDefinition(

logger.debug("[{}] restoring model", modelId);
SearchRequest searchRequest = buildSearch(client, modelId, index, searchSize, null);

executorService.execute(() -> doSearch(searchRequest, modelConsumer, successConsumer, errorConsumer));
}

Expand All @@ -132,8 +133,16 @@ private void doSearch(
Consumer<Boolean> successConsumer,
Consumer<Exception> errorConsumer
) {

executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchResponse -> {
try {
assert Thread.currentThread().getName().contains(NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME)
|| Thread.currentThread().getName().contains(UTILITY_THREAD_POOL_NAME)
: format(
"Must execute from [%s] or [%s] but thread is [%s]",
NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME,
UTILITY_THREAD_POOL_NAME,
Thread.currentThread().getName()
);
SearchResponse searchResponse = client.search(searchRequest).actionGet();
if (searchResponse.getHits().getHits().length == 0) {
errorConsumer.accept(new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
return;
Expand Down Expand Up @@ -182,13 +191,13 @@ private void doSearch(
searchRequestBuilder.searchAfter(new Object[] { lastHit.getIndex(), lastNum });
executorService.execute(() -> doSearch(searchRequestBuilder.request(), modelConsumer, successConsumer, errorConsumer));
}
}, e -> {
} catch (Exception e) {
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
errorConsumer.accept(new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
} else {
errorConsumer.accept(e);
}
}));
}
}

private static SearchRequestBuilder buildSearchBuilder(Client client, String modelId, String index, int searchSize) {
Expand Down