Skip to content

Commit 40b8053

Browse files
authored
Revert "[ML] Add queue_capacity setting to start deployment API (#79369)"
This reverts commit 637a299.
1 parent 637a299 commit 40b8053

File tree

12 files changed

+27
-108
lines changed

12 files changed

+27
-108
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

Lines changed: 6 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
6060
public static final ParseField WAIT_FOR = new ParseField("wait_for");
6161
public static final ParseField INFERENCE_THREADS = TaskParams.INFERENCE_THREADS;
6262
public static final ParseField MODEL_THREADS = TaskParams.MODEL_THREADS;
63-
public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY;
6463

6564
public static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);
6665

@@ -70,7 +69,6 @@ public static class Request extends MasterNodeRequest<Request> implements ToXCon
7069
PARSER.declareString((request, waitFor) -> request.setWaitForState(AllocationStatus.State.fromString(waitFor)), WAIT_FOR);
7170
PARSER.declareInt(Request::setInferenceThreads, INFERENCE_THREADS);
7271
PARSER.declareInt(Request::setModelThreads, MODEL_THREADS);
73-
PARSER.declareInt(Request::setQueueCapacity, QUEUE_CAPACITY);
7472
}
7573

7674
public static Request parseRequest(String modelId, XContentParser parser) {
@@ -89,7 +87,6 @@ public static Request parseRequest(String modelId, XContentParser parser) {
8987
private AllocationStatus.State waitForState = AllocationStatus.State.STARTED;
9088
private int modelThreads = 1;
9189
private int inferenceThreads = 1;
92-
private int queueCapacity = 1024;
9390

9491
private Request() {}
9592

@@ -104,7 +101,6 @@ public Request(StreamInput in) throws IOException {
104101
waitForState = in.readEnum(AllocationStatus.State.class);
105102
modelThreads = in.readVInt();
106103
inferenceThreads = in.readVInt();
107-
queueCapacity = in.readVInt();
108104
}
109105

110106
public final void setModelId(String modelId) {
@@ -148,14 +144,6 @@ public void setInferenceThreads(int inferenceThreads) {
148144
this.inferenceThreads = inferenceThreads;
149145
}
150146

151-
public int getQueueCapacity() {
152-
return queueCapacity;
153-
}
154-
155-
public void setQueueCapacity(int queueCapacity) {
156-
this.queueCapacity = queueCapacity;
157-
}
158-
159147
@Override
160148
public void writeTo(StreamOutput out) throws IOException {
161149
super.writeTo(out);
@@ -164,7 +152,6 @@ public void writeTo(StreamOutput out) throws IOException {
164152
out.writeEnum(waitForState);
165153
out.writeVInt(modelThreads);
166154
out.writeVInt(inferenceThreads);
167-
out.writeVInt(queueCapacity);
168155
}
169156

170157
@Override
@@ -175,7 +162,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
175162
builder.field(WAIT_FOR.getPreferredName(), waitForState);
176163
builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
177164
builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
178-
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
179165
builder.endObject();
180166
return builder;
181167
}
@@ -197,15 +183,12 @@ public ActionRequestValidationException validate() {
197183
if (inferenceThreads < 1) {
198184
validationException.addValidationError("[" + INFERENCE_THREADS + "] must be a positive integer");
199185
}
200-
if (queueCapacity < 1 || queueCapacity > 10000) {
201-
validationException.addValidationError("[" + QUEUE_CAPACITY + "] must be in [1, 10000]");
202-
}
203186
return validationException.validationErrors().isEmpty() ? null : validationException;
204187
}
205188

206189
@Override
207190
public int hashCode() {
208-
return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads, queueCapacity);
191+
return Objects.hash(modelId, timeout, waitForState, modelThreads, inferenceThreads);
209192
}
210193

211194
@Override
@@ -221,8 +204,7 @@ public boolean equals(Object obj) {
221204
&& Objects.equals(timeout, other.timeout)
222205
&& Objects.equals(waitForState, other.waitForState)
223206
&& modelThreads == other.modelThreads
224-
&& inferenceThreads == other.inferenceThreads
225-
&& queueCapacity == other.queueCapacity;
207+
&& inferenceThreads == other.inferenceThreads;
226208
}
227209

228210
@Override
@@ -244,20 +226,16 @@ public static boolean mayAllocateToNode(DiscoveryNode node) {
244226
private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
245227
public static final ParseField MODEL_THREADS = new ParseField("model_threads");
246228
public static final ParseField INFERENCE_THREADS = new ParseField("inference_threads");
247-
public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");
248-
249229
private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
250230
"trained_model_deployment_params",
251231
true,
252-
a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3], (int) a[4])
232+
a -> new TaskParams((String)a[0], (Long)a[1], (int) a[2], (int) a[3])
253233
);
254-
255234
static {
256235
PARSER.declareString(ConstructingObjectParser.constructorArg(), TrainedModelConfig.MODEL_ID);
257236
PARSER.declareLong(ConstructingObjectParser.constructorArg(), MODEL_BYTES);
258237
PARSER.declareInt(ConstructingObjectParser.constructorArg(), INFERENCE_THREADS);
259238
PARSER.declareInt(ConstructingObjectParser.constructorArg(), MODEL_THREADS);
260-
PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY);
261239
}
262240

263241
public static TaskParams fromXContent(XContentParser parser) {
@@ -275,9 +253,8 @@ public static TaskParams fromXContent(XContentParser parser) {
275253
private final long modelBytes;
276254
private final int inferenceThreads;
277255
private final int modelThreads;
278-
private final int queueCapacity;
279256

280-
public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads, int queueCapacity) {
257+
public TaskParams(String modelId, long modelBytes, int inferenceThreads, int modelThreads) {
281258
this.modelId = Objects.requireNonNull(modelId);
282259
this.modelBytes = modelBytes;
283260
if (modelBytes < 0) {
@@ -291,18 +268,13 @@ public TaskParams(String modelId, long modelBytes, int inferenceThreads, int mod
291268
if (modelThreads < 1) {
292269
throw new IllegalArgumentException(MODEL_THREADS + " must be positive");
293270
}
294-
this.queueCapacity = queueCapacity;
295-
if (queueCapacity < 1 || queueCapacity > 10000) {
296-
throw new IllegalArgumentException(QUEUE_CAPACITY + " must be in [1, 10000]");
297-
}
298271
}
299272

300273
public TaskParams(StreamInput in) throws IOException {
301274
this.modelId = in.readString();
302275
this.modelBytes = in.readVLong();
303276
this.inferenceThreads = in.readVInt();
304277
this.modelThreads = in.readVInt();
305-
this.queueCapacity = in.readVInt();
306278
}
307279

308280
public String getModelId() {
@@ -324,7 +296,6 @@ public void writeTo(StreamOutput out) throws IOException {
324296
out.writeVLong(modelBytes);
325297
out.writeVInt(inferenceThreads);
326298
out.writeVInt(modelThreads);
327-
out.writeVInt(queueCapacity);
328299
}
329300

330301
@Override
@@ -334,14 +305,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
334305
builder.field(MODEL_BYTES.getPreferredName(), modelBytes);
335306
builder.field(INFERENCE_THREADS.getPreferredName(), inferenceThreads);
336307
builder.field(MODEL_THREADS.getPreferredName(), modelThreads);
337-
builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity);
338308
builder.endObject();
339309
return builder;
340310
}
341311

342312
@Override
343313
public int hashCode() {
344-
return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads, queueCapacity);
314+
return Objects.hash(modelId, modelBytes, inferenceThreads, modelThreads);
345315
}
346316

347317
@Override
@@ -353,8 +323,7 @@ public boolean equals(Object o) {
353323
return Objects.equals(modelId, other.modelId)
354324
&& modelBytes == other.modelBytes
355325
&& inferenceThreads == other.inferenceThreads
356-
&& modelThreads == other.modelThreads
357-
&& queueCapacity == other.queueCapacity;
326+
&& modelThreads == other.modelThreads;
358327
}
359328

360329
@Override
@@ -373,10 +342,6 @@ public int getInferenceThreads() {
373342
public int getModelThreads() {
374343
return modelThreads;
375344
}
376-
377-
public int getQueueCapacity() {
378-
return queueCapacity;
379-
}
380345
}
381346

382347
public interface TaskMatcher {

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAllocationActionRequestTests.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,14 @@ public class CreateTrainedModelAllocationActionRequestTests extends AbstractWire
1414

1515
@Override
1616
protected Request createTestInstance() {
17-
return new Request(StartTrainedModelDeploymentTaskParamsTests.createRandom());
17+
return new Request(
18+
new StartTrainedModelDeploymentAction.TaskParams(
19+
randomAlphaOfLength(10),
20+
randomNonNegativeLong(),
21+
randomIntBetween(1, 8),
22+
randomIntBetween(1, 8)
23+
)
24+
);
1825
}
1926

2027
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import java.io.IOException;
1919

2020
import static org.hamcrest.Matchers.containsString;
21-
import static org.hamcrest.Matchers.equalTo;
2221
import static org.hamcrest.Matchers.is;
2322
import static org.hamcrest.Matchers.not;
2423
import static org.hamcrest.Matchers.nullValue;
@@ -54,9 +53,6 @@ public static Request createRandom() {
5453
if (randomBoolean()) {
5554
request.setModelThreads(randomIntBetween(1, 8));
5655
}
57-
if (randomBoolean()) {
58-
request.setQueueCapacity(randomIntBetween(1, 10000));
59-
}
6056
return request;
6157
}
6258

@@ -99,43 +95,4 @@ public void testValidate_GivenModelThreadsIsNegative() {
9995
assertThat(e, is(not(nullValue())));
10096
assertThat(e.getMessage(), containsString("[model_threads] must be a positive integer"));
10197
}
102-
103-
public void testValidate_GivenQueueCapacityIsZero() {
104-
Request request = createRandom();
105-
request.setQueueCapacity(0);
106-
107-
ActionRequestValidationException e = request.validate();
108-
109-
assertThat(e, is(not(nullValue())));
110-
assertThat(e.getMessage(), containsString("[queue_capacity] must be in [1, 10000]"));
111-
}
112-
113-
public void testValidate_GivenQueueCapacityIsNegative() {
114-
Request request = createRandom();
115-
request.setQueueCapacity(randomIntBetween(Integer.MIN_VALUE, -1));
116-
117-
ActionRequestValidationException e = request.validate();
118-
119-
assertThat(e, is(not(nullValue())));
120-
assertThat(e.getMessage(), containsString("[queue_capacity] must be in [1, 10000]"));
121-
}
122-
123-
public void testValidate_GivenQueueCapacityIsGreaterThan10000() {
124-
Request request = createRandom();
125-
request.setQueueCapacity(randomIntBetween(10001, Integer.MAX_VALUE));
126-
127-
ActionRequestValidationException e = request.validate();
128-
129-
assertThat(e, is(not(nullValue())));
130-
assertThat(e.getMessage(), containsString("[queue_capacity] must be in [1, 10000]"));
131-
}
132-
133-
public void testDefaults() {
134-
Request request = new Request(randomAlphaOfLength(10));
135-
assertThat(request.getTimeout(), equalTo(TimeValue.timeValueSeconds(20)));
136-
assertThat(request.getWaitForState(), equalTo(AllocationStatus.State.STARTED));
137-
assertThat(request.getInferenceThreads(), equalTo(1));
138-
assertThat(request.getModelThreads(), equalTo(1));
139-
assertThat(request.getQueueCapacity(), equalTo(1024));
140-
}
14198
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ public static StartTrainedModelDeploymentAction.TaskParams createRandom() {
3636
randomAlphaOfLength(10),
3737
randomNonNegativeLong(),
3838
randomIntBetween(1, 8),
39-
randomIntBetween(1, 8),
40-
randomIntBetween(1, 10000)
39+
randomIntBetween(1, 8)
4140
);
4241
}
4342
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/TrainedModelAllocationTests.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@
1313
import org.elasticsearch.cluster.node.DiscoveryNode;
1414
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
1515
import org.elasticsearch.common.io.stream.Writeable;
16-
import org.elasticsearch.test.AbstractSerializingTestCase;
1716
import org.elasticsearch.xcontent.XContentParser;
17+
import org.elasticsearch.test.AbstractSerializingTestCase;
1818
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
19-
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentTaskParamsTests;
2019

2120
import java.io.IOException;
2221
import java.util.List;
@@ -32,7 +31,9 @@
3231
public class TrainedModelAllocationTests extends AbstractSerializingTestCase<TrainedModelAllocation> {
3332

3433
public static TrainedModelAllocation randomInstance() {
35-
TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(randomParams());
34+
TrainedModelAllocation.Builder builder = TrainedModelAllocation.Builder.empty(
35+
new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1)
36+
);
3637
List<String> nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).collect(Collectors.toList());
3738
for (String node : nodes) {
3839
if (randomBoolean()) {
@@ -248,7 +249,7 @@ private static DiscoveryNode buildNode() {
248249
}
249250

250251
private static StartTrainedModelDeploymentAction.TaskParams randomParams() {
251-
return StartTrainedModelDeploymentTaskParamsTests.createRandom();
252+
return new StartTrainedModelDeploymentAction.TaskParams(randomAlphaOfLength(10), randomNonNegativeLong(), 1, 1);
252253
}
253254

254255
private static void assertUnchanged(

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.elasticsearch.cluster.service.ClusterService;
2727
import org.elasticsearch.common.inject.Inject;
2828
import org.elasticsearch.common.settings.Settings;
29+
import org.elasticsearch.xcontent.NamedXContentRegistry;
2930
import org.elasticsearch.core.TimeValue;
3031
import org.elasticsearch.license.LicenseUtils;
3132
import org.elasticsearch.license.XPackLicenseState;
@@ -34,7 +35,6 @@
3435
import org.elasticsearch.tasks.Task;
3536
import org.elasticsearch.threadpool.ThreadPool;
3637
import org.elasticsearch.transport.TransportService;
37-
import org.elasticsearch.xcontent.NamedXContentRegistry;
3838
import org.elasticsearch.xpack.core.XPackField;
3939
import org.elasticsearch.xpack.core.ml.MachineLearningField;
4040
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction;
@@ -161,8 +161,7 @@ protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Requ
161161
trainedModelConfig.getModelId(),
162162
modelBytes,
163163
request.getInferenceThreads(),
164-
request.getModelThreads(),
165-
request.getQueueCapacity()
164+
request.getModelThreads()
166165
);
167166
PersistentTasksCustomMetadata persistentTasks = clusterService.state().getMetadata().custom(
168167
PersistentTasksCustomMetadata.TYPE);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,6 @@ public void onFailure(Exception e) {
307307

308308
@Override
309309
protected void doRun() throws Exception {
310-
logger.info("Request [{}] running", requestId);
311310
final String requestIdStr = String.valueOf(requestId);
312311
try {
313312
// The request builder expect a list of inputs which are then batched.
@@ -393,11 +392,7 @@ class ProcessContext {
393392
this.task = Objects.requireNonNull(task);
394393
resultProcessor = new PyTorchResultProcessor(task.getModelId());
395394
this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry);
396-
this.executorService = new ProcessWorkerExecutorService(
397-
threadPool.getThreadContext(),
398-
"pytorch_inference",
399-
task.getParams().getQueueCapacity()
400-
);
395+
this.executorService = new ProcessWorkerExecutorService(threadPool.getThreadContext(), "pytorch_inference", 1024);
401396
}
402397

403398
PyTorchResultProcessor getResultProcessor() {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import static org.elasticsearch.rest.RestRequest.Method.POST;
2424
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.INFERENCE_THREADS;
2525
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.MODEL_THREADS;
26-
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.QUEUE_CAPACITY;
2726
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.TIMEOUT;
2827
import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.WAIT_FOR;
2928

@@ -60,7 +59,6 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
6059
));
6160
request.setInferenceThreads(restRequest.paramAsInt(INFERENCE_THREADS.getPreferredName(), request.getInferenceThreads()));
6261
request.setModelThreads(restRequest.paramAsInt(MODEL_THREADS.getPreferredName(), request.getModelThreads()));
63-
request.setQueueCapacity(restRequest.paramAsInt(QUEUE_CAPACITY.getPreferredName(), request.getQueueCapacity()));
6462
}
6563

6664
return channel -> client.execute(StartTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -940,7 +940,7 @@ private static DiscoveryNode buildOldNode(String name, boolean isML, long native
940940
}
941941

942942
private static StartTrainedModelDeploymentAction.TaskParams newParams(String modelId, long modelSize) {
943-
return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, 1, 1, 1024);
943+
return new StartTrainedModelDeploymentAction.TaskParams(modelId, modelSize, 1, 1);
944944
}
945945

946946
private static void assertNodeState(TrainedModelAllocationMetadata metadata, String modelId, String nodeId, RoutingState routingState) {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationMetadataTests.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,7 @@ private static StartTrainedModelDeploymentAction.TaskParams randomParams(String
9999
modelId,
100100
randomNonNegativeLong(),
101101
randomIntBetween(1, 8),
102-
randomIntBetween(1, 8),
103-
randomIntBetween(1, 10000)
102+
randomIntBetween(1, 8)
104103
);
105104
}
106105

0 commit comments

Comments
 (0)