Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/java-rest/high-level/ml/put-trained-model.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ include::../execution.asciidoc[]
==== Response

The returned +{response}+ contains the newly created trained model.
The +{response}+ will omit the model definition as a precaution against
streaming large model definitions back to the client.

["source","java",subs="attributes,callouts,macros"]
--------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
// We don't store the definition in the same document as the configuration
if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
if (params.paramAsBoolean(DECOMPRESS_DEFINITION, true)) {
if (params.paramAsBoolean(DECOMPRESS_DEFINITION, false)) {
builder.field(DEFINITION.getPreferredName(), definition);
} else {
builder.field(COMPRESSED_DEFINITION.getPreferredName(), definition.getCompressedString());
Expand Down Expand Up @@ -371,6 +371,9 @@ public Builder(TrainedModelConfig config) {
this.tags = config.getTags();
this.metadata = config.getMetadata();
this.input = config.getInput();
this.estimatedOperations = config.estimatedOperations;
this.estimatedHeapMemory = config.estimatedHeapMemory;
this.licenseLevel = config.licenseLevel.description();
}

public Builder setModelId(String modelId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ public void testGetTrainedModels() throws IOException {
assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\""));
assertThat(response, containsString("\"estimated_heap_memory_usage\""));
assertThat(response, containsString("\"definition\""));
assertThat(response, not(containsString("\"compressed_definition\"")));
assertThat(response, containsString("\"count\":1"));

getModel = client().performRequest(new Request("GET",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ protected void masterOperation(Task task,

ActionListener<Void> tagsModelIdCheckListener = ActionListener.wrap(
r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap(
storedConfig -> listener.onResponse(new PutTrainedModelAction.Response(trainedModelConfig)),
bool -> {
TrainedModelConfig configToReturn = new TrainedModelConfig.Builder(trainedModelConfig).clearDefinition().build();
listener.onResponse(new PutTrainedModelAction.Response(configToReturn));
},
listener::onFailure
)),
listener::onFailure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
Expand All @@ -18,7 +24,9 @@

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static java.util.Arrays.asList;
Expand All @@ -34,6 +42,8 @@ public List<Route> routes() {
new Route(GET, MachineLearning.BASE_PATH + "inference"));
}

private static final Map<String, String> DEFAULT_TO_XCONTENT_VALUES =
Collections.singletonMap(TrainedModelConfig.DECOMPRESS_DEFINITION, Boolean.toString(true));
@Override
public String getName() {
return "ml_get_trained_models_action";
Expand All @@ -56,12 +66,33 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
}
request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources()));
return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel));
return channel -> client.execute(GetTrainedModelsAction.INSTANCE,
request,
new RestToXContentListenerWithDefaultValues<>(channel, DEFAULT_TO_XCONTENT_VALUES));
}

@Override
protected Set<String> responseParams() {
return Collections.singleton(TrainedModelConfig.DECOMPRESS_DEFINITION);
}

private static class RestToXContentListenerWithDefaultValues<T extends ToXContentObject> extends RestToXContentListener<T> {
private final Map<String, String> defaultToXContentParamValues;

private RestToXContentListenerWithDefaultValues(RestChannel channel, Map<String, String> defaultToXContentParamValues) {
super(channel);
this.defaultToXContentParamValues = defaultToXContentParamValues;
}

@Override
public RestResponse buildResponse(T response, XContentBuilder builder) throws Exception {
assert response.isFragment() == false; //would be nice if we could make default methods final
Map<String, String> params = new HashMap<>(channel.request().params());
defaultToXContentParamValues.forEach((k, v) ->
params.computeIfAbsent(k, defaultToXContentParamValues::get)
);
response.toXContent(builder, new ToXContent.MapParams(params));
return new BytesRestResponse(getStatus(response), builder);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -460,3 +460,53 @@ setup:
}
}
}
---
"Test put model":
- do:
ml.put_trained_model:
model_id: my-regression-model
body: >
{
"description": "model for tests",
"input": {"field_names": ["field1", "field2"]},
"definition": {
"preprocessors": [],
"trained_model": {
"ensemble": {
"target_type": "regression",
"trained_models": [
{
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2},
{"node_index": 1, "leaf_value": 0},
{"node_index": 2, "leaf_value": 1}
],
"target_type": "regression"
}
},
{
"tree": {
"feature_names": ["field1", "field2"],
"tree_structure": [
{"node_index": 0, "threshold": 2, "left_child": 1, "right_child": 2},
{"node_index": 1, "leaf_value": 0},
{"node_index": 2, "leaf_value": 1}
],
"target_type": "regression"
}
}
]
}
}
}
}
- match: { model_id: my-regression-model }
- match: { estimated_operations: 6 }
- is_false: definition
- is_false: compressed_definition
- is_true: license_level
- is_true: create_time
- is_true: version
- is_true: estimated_heap_memory_usage_bytes