Skip to content

Commit 011775a

Browse files
authored
Block commas in model description (#1692)
* Block commas in model description Signed-off-by: Ryan Bogan <[email protected]> * Add changelog entry Signed-off-by: Ryan Bogan <[email protected]> * Add check in rest handler Signed-off-by: Ryan Bogan <[email protected]> * Extract if statement into ModelUtil method Signed-off-by: Ryan Bogan <[email protected]> * Remove ingestion from integ test Signed-off-by: Ryan Bogan <[email protected]> --------- Signed-off-by: Ryan Bogan <[email protected]>
1 parent 73d5425 commit 011775a

File tree

6 files changed

+118
-14
lines changed

6 files changed

+118
-14
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1616
### Features
1717
### Enhancements
1818
### Bug Fixes
19+
* Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692)
1920
### Infrastructure
2021
### Documentation
2122
### Maintenance

src/main/java/org/opensearch/knn/indices/ModelMetadata.java

+2
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ public ModelMetadata(StreamInput in) throws IOException {
6767
// Description and error may be empty. However, reading the string will work as long as they are not null
6868
// which is checked in constructor and setters
6969
this.description = in.readString();
70+
ModelUtil.blockCommasInModelDescription(this.description);
7071
this.error = in.readString();
7172

7273
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), IndexUtil.MODEL_NODE_ASSIGNMENT_KEY)) {
@@ -123,6 +124,7 @@ public ModelMetadata(
123124
this.state = new AtomicReference<>(Objects.requireNonNull(modelState, "modelState must not be null"));
124125
this.timestamp = Objects.requireNonNull(timestamp, "timestamp must not be null");
125126
this.description = Objects.requireNonNull(description, "description must not be null");
127+
ModelUtil.blockCommasInModelDescription(this.description);
126128
this.error = Objects.requireNonNull(error, "error must not be null");
127129
this.trainingNodeAssignment = Objects.requireNonNull(trainingNodeAssignment, "node assignment must not be null");
128130
this.methodComponentContext = Objects.requireNonNull(methodComponentContext, "method context must not be null");

src/main/java/org/opensearch/knn/indices/ModelUtil.java

+6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@
1616
*/
1717
public class ModelUtil {
1818

19+
public static void blockCommasInModelDescription(String description) {
20+
if (description.contains(",")) {
21+
throw new IllegalArgumentException("Model description cannot contain any commas: ','");
22+
}
23+
}
24+
1925
public static boolean isModelPresent(ModelMetadata modelMetadata) {
2026
return modelMetadata != null;
2127
}

src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.opensearch.core.xcontent.XContentParser;
1717
import org.opensearch.index.mapper.NumberFieldMapper;
1818
import org.opensearch.knn.index.KNNMethodContext;
19+
import org.opensearch.knn.indices.ModelUtil;
1920
import org.opensearch.knn.plugin.KNNPlugin;
2021
import org.opensearch.knn.plugin.transport.TrainingJobRouterAction;
2122
import org.opensearch.knn.plugin.transport.TrainingModelRequest;
@@ -104,6 +105,7 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr
104105
searchSize = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
105106
} else if (MODEL_DESCRIPTION.equals(fieldName) && ensureNotSet(fieldName, description)) {
106107
description = parser.textOrNull();
108+
ModelUtil.blockCommasInModelDescription(description);
107109
} else {
108110
throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is not a valid " + "parameter.");
109111
}

src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java

+46-14
Original file line numberDiff line numberDiff line change
@@ -608,20 +608,7 @@ public void testFromResponseMap() throws IOException {
608608
String description = "test-description";
609609
String error = "test-error";
610610
String nodeAssignment = "test-node";
611-
Map<String, Object> nestedParameters = new HashMap<String, Object>() {
612-
{
613-
put("testNestedKey1", "testNestedString");
614-
put("testNestedKey2", 1);
615-
}
616-
};
617-
Map<String, Object> parameters = new HashMap<>() {
618-
{
619-
put("testKey1", "testString");
620-
put("testKey2", 0);
621-
put("testKey3", new MethodComponentContext("ivf", nestedParameters));
622-
}
623-
};
624-
MethodComponentContext methodComponentContext = new MethodComponentContext("hnsw", parameters);
611+
MethodComponentContext methodComponentContext = getMethodComponentContext();
625612
MethodComponentContext emptyMethodComponentContext = MethodComponentContext.EMPTY;
626613

627614
ModelMetadata expected = new ModelMetadata(
@@ -667,6 +654,51 @@ public void testFromResponseMap() throws IOException {
667654
metadataAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, null);
668655
metadataAsMap.put(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT, null);
669656
assertEquals(expected2, fromMap);
657+
}
658+
659+
public void testBlockCommasInDescription() {
660+
KNNEngine knnEngine = KNNEngine.DEFAULT;
661+
SpaceType spaceType = SpaceType.L2;
662+
int dimension = 128;
663+
ModelState modelState = ModelState.TRAINING;
664+
String timestamp = ZonedDateTime.now(ZoneOffset.UTC).toString();
665+
String description = "Test, comma, description";
666+
String error = "test-error";
667+
String nodeAssignment = "test-node";
668+
MethodComponentContext methodComponentContext = getMethodComponentContext();
669+
670+
Exception e = expectThrows(
671+
IllegalArgumentException.class,
672+
() -> new ModelMetadata(
673+
knnEngine,
674+
spaceType,
675+
dimension,
676+
modelState,
677+
timestamp,
678+
description,
679+
error,
680+
nodeAssignment,
681+
methodComponentContext
682+
)
683+
);
684+
assertEquals("Model description cannot contain any commas: ','", e.getMessage());
685+
}
670686

687+
private static MethodComponentContext getMethodComponentContext() {
688+
Map<String, Object> nestedParameters = new HashMap<String, Object>() {
689+
{
690+
put("testNestedKey1", "testNestedString");
691+
put("testNestedKey2", 1);
692+
}
693+
};
694+
Map<String, Object> parameters = new HashMap<>() {
695+
{
696+
put("testKey1", "testString");
697+
put("testKey2", 0);
698+
put("testKey3", new MethodComponentContext("ivf", nestedParameters));
699+
}
700+
};
701+
MethodComponentContext methodComponentContext = new MethodComponentContext("hnsw", parameters);
702+
return methodComponentContext;
671703
}
672704
}

src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java

+61
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import org.apache.hc.core5.http.io.entity.EntityUtils;
1515
import org.opensearch.client.Response;
16+
import org.opensearch.client.ResponseException;
1617
import org.opensearch.core.xcontent.XContentBuilder;
1718
import org.opensearch.common.xcontent.XContentFactory;
1819
import org.opensearch.core.xcontent.MediaTypeRegistry;
@@ -192,6 +193,66 @@ public void testTrainModel_fail_tooMuchData() throws Exception {
192193
assertTrainingFails(modelId, 30, 1000);
193194
}
194195

196+
public void testTrainModel_fail_commaInDescription() throws Exception {
197+
// Test checks that training when passing in an id succeeds
198+
199+
String modelId = "test-model-id";
200+
String trainingIndexName = "train-index";
201+
String trainingFieldName = "train-field";
202+
int dimension = 8;
203+
204+
// Create a training index and randomly ingest data into it
205+
createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
206+
207+
// Call the train API with this definition:
208+
/*
209+
{
210+
"training_index": "train_index",
211+
"training_field": "train_field",
212+
"dimension": 8,
213+
"description": "this should be allowed to be null",
214+
"method": {
215+
"name":"ivf",
216+
"engine":"faiss",
217+
"space_type": "l2",
218+
"parameters":{
219+
"nlist":1,
220+
"encoder":{
221+
"name":"pq",
222+
"parameters":{
223+
"code_size":2,
224+
"m": 2
225+
}
226+
}
227+
}
228+
}
229+
}
230+
*/
231+
XContentBuilder builder = XContentFactory.jsonBuilder()
232+
.startObject()
233+
.field(NAME, "ivf")
234+
.field(KNN_ENGINE, "faiss")
235+
.field(METHOD_PARAMETER_SPACE_TYPE, "l2")
236+
.startObject(PARAMETERS)
237+
.field(METHOD_PARAMETER_NLIST, 1)
238+
.startObject(METHOD_ENCODER_PARAMETER)
239+
.field(NAME, "pq")
240+
.startObject(PARAMETERS)
241+
.field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2)
242+
.field(ENCODER_PARAMETER_PQ_M, 2)
243+
.endObject()
244+
.endObject()
245+
.endObject()
246+
.endObject();
247+
Map<String, Object> method = xContentBuilderToMap(builder);
248+
249+
Exception e = expectThrows(
250+
ResponseException.class,
251+
() -> trainModel(modelId, trainingIndexName, trainingFieldName, dimension, method, "dummy description, with comma")
252+
);
253+
assertTrue(e.getMessage().contains("Model description cannot contain any commas: ','"));
254+
}
255+
195256
public void testTrainModel_success_withId() throws Exception {
196257
// Test checks that training when passing in an id succeeds
197258

0 commit comments

Comments
 (0)