Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
public static final ParseField WAIT_FOR = new ParseField("wait_for");
public static final ParseField INFERENCE_THREADS = TaskParams.INFERENCE_THREADS;
public static final ParseField MODEL_THREADS = TaskParams.MODEL_THREADS;
public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;

public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);

Expand All @@ -70,7 +69,6 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
PARSER.declareString((request, waitFor) -> request.setWaitForState(AllocationStatus.State.fromString(waitFor)), WAIT_FOR);
PARSER.declareInt(Request::setInferenceThreads, INFERENCE_THREADS);
PARSER.declareInt(Request::setModelThreads, MODEL_THREADS);
PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
}

public static Request parseRequest(String modelId, XContentParser parser) {
Expand All @@ -89,7 +87,6 @@ public static Request parseRequest(String modelId, XContentParser parser) {
private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
private int modelThreads = 1;
private int inferenceThreads = 1;
private int queueCapacity = 1024;

private Request() {}

Expand All @@ -104,7 +101,6 @@ public Request(StreamInput in) throws IOException {
waitForState = in.readEnum(AllocationStatus.State.class);
modelThreads = in.readVInt();
inferenceThreads = in.readVInt();
queueCapacity = in.readVInt();
}

public final void setModelId(String modelId) {
Expand Down Expand Up @@ -148,14 +144,6 @@ public void setInferenceThreads(int inferenceThreads) {
this.inferenceThreads = inferenceThreads;
}

public int getQueueCapacity() {
return queueCapacity;
}

public void setQueueCapacity(int queueCapacity) {
this.queueCapacity = queueCapacity;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
Expand All @@ -164,7 +152,6 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(waitForState);
out.writeVInt(modelThreads);
out.writeVInt(inferenceThreads);
out.writeVInt(queueCapacity);
}

@Override
Expand All @@ -175,7 +162,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(WAIT_FOR.getPreferredName(), waitForState);
builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
builder.endObject();
return builder;
}
Expand All @@ -197,15 +183,12 @@ public ActionRequestValidationException validate() {
if (inferenceThreads < 1) {
validationException.addValidationError("[" + INFERENCE_THREADS + "] must be a positive integer");
}
if (queueCapacity < 1 || queueCapacity > 10000) {
validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be in [1, 10000]");
}
return validationException.validationErrors().isEmpty() ? null : validationException;
}

@Override
public int hashCode() {
return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads, queueCapacity);
return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads);
}

@Override
Expand All @@ -221,8 +204,7 @@ public boolean equals(Object obj) {
&& Objects.equals(timeout, other.timeout)
&& Objects.equals(waitForState, other.waitForState)
&& modelThreads == other.modelThreads
&& inferenceThreads == other.inferenceThreads
&& queueCapacity == other.queueCapacity;
&& inferenceThreads == other.inferenceThreads;
}

@Override
Expand All @@ -244,20 +226,16 @@ public static boolean mayAllocateToNode(DiscoveryNode node) {
private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
public static final ParseField MODEL_THREADS = new ParseField("model_threads");
public static final ParseField INFERENCE_THREADS = new ParseField("inference_threads");
public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");

private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
"trained_model_deployment_params",
true,
a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3], (int) a[4])
a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3])
);

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
PARSER.declareInt(ConstructingObjectParser.constructorArg(), INFERENCE_THREADS);
PARSER.declareInt(ConstructingObjectParser.constructorArg(), MODEL_THREADS);
PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
}

public static TaskParams fromXContent(XContentParser parser) {
Expand All @@ -275,9 +253,8 @@ public static TaskParams fromXContent(XContentParser parser) {
private final long modelBytes;
private final int inferenceThreads;
private final int modelThreads;
private final int queueCapacity;

public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads, int queueCapacity) {
public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads) {
this.modelId = Objects.requireNonNull(modelId);
this.modelBytes = modelBytes;
if (modelBytes < 0) {
Expand All @@ -291,18 +268,13 @@ public TaskParams(String modelId, long modelBytes, int inferenceThreads, int mod
if (modelThreads < 1) {
throw new IllegalArgumentException(MODEL_THREADS + " must be positive");
}
this.queueCapacity = queueCapacity;
if (queueCapacity < 1 || queueCapacity > 10000) {
throw new IllegalArgumentException(QUEUE_CAPACITY + " must be in [1, 10000]");
}
}

public TaskParams(StreamInput in) throws IOException {
this.modelId = in.readString();
this.modelBytes = in.readVLong();
this.inferenceThreads = in.readVInt();
this.modelThreads = in.readVInt();
this.queueCapacity = in.readVInt();
}

public String getModelId() {
Expand All @@ -324,7 +296,6 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeVLong(modelBytes);
out.writeVInt(inferenceThreads);
out.writeVInt(modelThreads);
out.writeVInt(queueCapacity);
}

@Override
Expand All @@ -334,14 +305,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(MODEL_BYTES.getPreferredName(), modelBytes);
builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads, queueCapacity);
return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads);
}

@Override
Expand All @@ -353,8 +323,7 @@ public boolean equals(Object o) {
return Objects.equals(modelId, other.modelId)
&& modelBytes == other.modelBytes
&& inferenceThreads == other.inferenceThreads
&& modelThreads == other.modelThreads
&& queueCapacity == other.queueCapacity;
&& modelThreads == other.modelThreads;
}

@Override
Expand All @@ -373,10 +342,6 @@ public int getInferenceThreads() {
public int getModelThreads() {
return modelThreads;
}

public int getQueueCapacity() {
return queueCapacity;
}
}

public interface TaskMatcher {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@ public class CreateTrainedModelAllocationActionRequestTests extends AbstractWire

@Override
protected Request createTestInstance() {
return new Request(StartTrainedModelDeploymentTaskParamsTests.createRandom());
return new Request(
new StartTrainedModelDeploymentAction.TaskParams(
randomAlphaOfLength(10),
randomNonNegativeLong(),
randomIntBetween(1, 8),
randomIntBetween(1, 8)
)
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import java.io.IOException;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
Expand Down Expand Up @@ -54,9 +53,6 @@ public static Request createRandom() {
if (randomBoolean()) {
request.setModelThreads(randomIntBetween(1, 8));
}
if (randomBoolean()) {
request.setQueueCapacity(randomIntBetween(1, 10000));
}
return request;
}

Expand Down Expand Up @@ -99,43 +95,4 @@ public void testValidate_GivenModelThreadsIsNegative() {
assertThat(e, is(not(nullValue())));
assertThat(e.getMessage(), containsString("[model_threads] must be a positive integer"));
}

public void testValidate_GivenQueueCapacityIsZero() {
Request request = createRandom();
request.setQueueCapacity(0);

ActionRequestValidationException e = request.validate();

assertThat(e, is(not(nullValue())));
assertThat(e.getMessage(), containsString("[queue_capacity] must be in [1, 10000]"));
}

public void testValidate_GivenQueueCapacityIsNegative() {
Request request = createRandom();
request.setQueueCapacity(randomIntBetween(Integer.MIN_VALUE, -1));

ActionRequestValidationException e = request.validate();

assertThat(e, is(not(nullValue())));
assertThat(e.getMessage(), containsString("[queue_capacity] must be in [1, 10000]"));
}

public void testValidate_GivenQueueCapacityIsGreaterThan10000() {
Request request = createRandom();
request.setQueueCapacity(randomIntBetween(10001, Integer.MAX_VALUE));

ActionRequestValidationException e = request.validate();

assertThat(e, is(not(nullValue())));
assertThat(e.getMessage(), containsString("[queue_capacity] must be in [1, 10000]"));
}

public void testDefaults() {
Request request = new Request(randomAlphaOfLength(10));
assertThat(request.getTimeout(), equalTo(TimeValue.timeValueSeconds(20)));
assertThat(request.getWaitForState(), equalTo(AllocationStatus.State.STARTED));
assertThat(request.getInferenceThreads(), equalTo(1));
assertThat(request.getModelThreads(), equalTo(1));
assertThat(request.getQueueCapacity(), equalTo(1024));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ public static StartTrainedModelDeploymentAction.TaskParams createRandom() {
randomAlphaOfLength(10),
randomNonNegativeLong(),
randomIntBetween(1, 8),
randomIntBetween(1, 8),
randomIntBetween(1, 10000)
randomIntBetween(1, 8)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentTaskParamsTests;

import java.io.IOException;
import java.util.List;
Expand All @@ -32,7 +31,9 @@
public class TrainedModelAllocationTests extends AbstractSerializingTestCase<TrainedModelAllocation> {

public static TrainedModelAllocation randomInstance() {
TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(randomParams());
TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(
new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1)
);
List<String> nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList());
for (String node : nodes) {
if (randomBoolean()) {
Expand Down Expand Up @@ -248,7 +249,7 @@ private static DiscoveryNode buildNode() {
}

private static StartTrainedModelDeploymentAction.TaskParams randomParams() {
return StartTrainedModelDeploymentTaskParamsTests.createRandom();
return new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1);
}

private static void assertUnchanged(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
Expand All @@ -34,7 +35,6 @@
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction;
Expand Down Expand Up @@ -161,8 +161,7 @@ protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Requ
trainedModelConfig.getModelId(),
modelBytes,
request.getInferenceThreads(),
request.getModelThreads(),
request.getQueueCapacity()
request.getModelThreads()
);
PersistentTasksCustomMetadata persistentTasks = clusterService.state().getMetadata().custom(
PersistentTasksCustomMetadata.TYPE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,6 @@ public void onFailure(Exception e) {

@Override
protected void doRun() throws Exception {
logger.info("Request [{}] running", requestId);
final String requestIdStr = String.valueOf(requestId);
try {
// The request builder expect a list of inputs which are then batched.
Expand Down Expand Up @@ -393,11 +392,7 @@ class ProcessContext {
this.task = Objects.requireNonNull(task);
resultProcessor = new PyTorchResultProcessor(task.getModelId());
this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry);
this.executorService = new ProcessWorkerExecutorService(
threadPool.getThreadContext(),
"pytorch_inference",
task.getParams().getQueueCapacity()
);
this.executorService = new ProcessWorkerExecutorService(threadPool.getThreadContext(), "pytorch_inference", 1024);
}

PyTorchResultProcessor getResultProcessor() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import static org.elasticsearch.rest.RestRequest.Method.POST;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.INFERENCE_THREADS;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.MODEL_THREADS;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.QUEUE_CAPACITY;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.TIMEOUT;
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.WAIT_FOR;

Expand Down Expand Up @@ -60,7 +59,6 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
));
request.setInferenceThreads(restRequest.paramAsInt(INFERENCE_THREADS.getPreferredName(), request.getInferenceThreads()));
request.setModelThreads(restRequest.paramAsInt(MODEL_THREADS.getPreferredName(), request.getModelThreads()));
request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
}

return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,7 @@ private static DiscoveryNode buildOldNode(String name, boolean isML, long native
}

private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId, long modelSize) {
return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, 1, 1, 1024);
return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, 1, 1);
}

private static void assertNodeState(TrainedModelAllocationMetadata metadata, String modelId, String nodeId, RoutingState routingState) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,7 @@ private static StartTrainedModelDeploymentAction.TaskParams randomParams(String
modelId,
randomNonNegativeLong(),
randomIntBetween(1, 8),
randomIntBetween(1, 8),
randomIntBetween(1, 10000)
randomIntBetween(1, 8)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ private void withSearchingLoadFailure(String modelId) {
}

private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId) {
return new StartTrainedModelDeploymentAction.TaskParams(modelId, randomNonNegativeLong(), 1, 1, 1024);
return new StartTrainedModelDeploymentAction.TaskParams(modelId, randomNonNegativeLong(), 1, 1);
}

private TrainedModelAllocationNodeService createService() {
Expand Down
Loading