Skip to content

Commit bd0eb7a

Browse files
ryanbogangithub-actions[bot]
authored andcommitted
Add validation for pq m parameter before training starts (#1713)
* Add validation for pq code count before training starts Signed-off-by: Ryan Bogan <[email protected]> * Add integration test Signed-off-by: Ryan Bogan <[email protected]> * Add unit tests Signed-off-by: Ryan Bogan <[email protected]> * Clean up code Signed-off-by: Ryan Bogan <[email protected]> * Remove unnecessary lines Signed-off-by: Ryan Bogan <[email protected]> * Add changelog entry Signed-off-by: Ryan Bogan <[email protected]> * Change framework to add validation with data Signed-off-by: Ryan Bogan <[email protected]> * Remove unused error message Signed-off-by: Ryan Bogan <[email protected]> * Add unit tests Signed-off-by: Ryan Bogan <[email protected]> * Change space type check name for readability Signed-off-by: Ryan Bogan <[email protected]> * Add javadocs Signed-off-by: Ryan Bogan <[email protected]> * Modify validation error wording and add json structure to tests Signed-off-by: Ryan Bogan <[email protected]> * Change TrainingDataSpec to VectorSpaceInfo Signed-off-by: Ryan Bogan <[email protected]> * Add unit tests Signed-off-by: Ryan Bogan <[email protected]> --------- Signed-off-by: Ryan Bogan <[email protected]> (cherry picked from commit 3701d19)
1 parent 1ef0974 commit bd0eb7a

File tree

16 files changed

+599
-15
lines changed

16 files changed

+599
-15
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1919
* Add KnnCircuitBreakerException and modify exception message [#1688](https://github.com/opensearch-project/k-NN/pull/1688)
2020
* Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684)
2121
* Support script score when doc value is disabled and fix misusing DISI [#1696](https://github.com/opensearch-project/k-NN/pull/1696)
22+
* Add validation for pq m parameter before training starts [#1713](https://github.com/opensearch-project/k-NN/pull/1713)
2223
### Bug Fixes
2324
* Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692)
2425
* Update threshold value after new result is added [#1715](https://github.com/opensearch-project/k-NN/pull/1715)

src/main/java/org/opensearch/knn/index/KNNMethod.java

+39-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import lombok.Getter;
1616
import org.opensearch.common.ValidationException;
1717
import org.opensearch.knn.common.KNNConstants;
18+
import org.opensearch.knn.training.VectorSpaceInfo;
1819

1920
import java.util.ArrayList;
2021
import java.util.Arrays;
@@ -41,7 +42,7 @@ public class KNNMethod {
4142
* @param space to be checked
4243
* @return true if the space is supported; false otherwise
4344
*/
44-
public boolean containsSpace(SpaceType space) {
45+
public boolean isSpaceTypeSupported(SpaceType space) {
4546
return spaces.contains(space);
4647
}
4748

@@ -53,7 +54,7 @@ public boolean containsSpace(SpaceType space) {
5354
*/
5455
public ValidationException validate(KNNMethodContext knnMethodContext) {
5556
List<String> errorMessages = new ArrayList<>();
56-
if (!containsSpace(knnMethodContext.getSpaceType())) {
57+
if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) {
5758
errorMessages.add(
5859
String.format(
5960
"\"%s\" configuration does not support space type: " + "\"%s\".",
@@ -77,6 +78,42 @@ public ValidationException validate(KNNMethodContext knnMethodContext) {
7778
return validationException;
7879
}
7980

81+
/**
82+
* Validate that the configured KNNMethodContext is valid for this method, using additional data not present in the method context
83+
*
84+
* @param knnMethodContext to be validated
85+
* @param vectorSpaceInfo additional data not present in the method context
86+
* @return ValidationException produced by validation errors; null if no validations errors.
87+
*/
88+
public ValidationException validateWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo) {
89+
List<String> errorMessages = new ArrayList<>();
90+
if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) {
91+
errorMessages.add(
92+
String.format(
93+
"\"%s\" configuration does not support space type: " + "\"%s\".",
94+
this.methodComponent.getName(),
95+
knnMethodContext.getSpaceType().getValue()
96+
)
97+
);
98+
}
99+
100+
ValidationException methodValidation = methodComponent.validateWithData(
101+
knnMethodContext.getMethodComponentContext(),
102+
vectorSpaceInfo
103+
);
104+
if (methodValidation != null) {
105+
errorMessages.addAll(methodValidation.validationErrors());
106+
}
107+
108+
if (errorMessages.isEmpty()) {
109+
return null;
110+
}
111+
112+
ValidationException validationException = new ValidationException();
113+
validationException.addValidationErrors(errorMessages);
114+
return validationException;
115+
}
116+
80117
/**
81118
* returns whether training is required or not
82119
*

src/main/java/org/opensearch/knn/index/KNNMethodContext.java

+11
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import java.util.stream.Collectors;
3131
import org.apache.commons.lang.builder.EqualsBuilder;
3232
import org.apache.commons.lang.builder.HashCodeBuilder;
33+
import org.opensearch.knn.training.VectorSpaceInfo;
3334

3435
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
3536
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
@@ -86,6 +87,16 @@ public ValidationException validate() {
8687
return knnEngine.validateMethod(this);
8788
}
8889

90+
/**
91+
* This method uses the knnEngine to validate that the method is compatible with the engine, using additional data not present in the method context
92+
*
93+
* @param vectorSpaceInfo additional data not present in the method context
94+
* @return ValidationException produced by validation errors; null if no validations errors.
95+
*/
96+
public ValidationException validateWithData(VectorSpaceInfo vectorSpaceInfo) {
97+
return knnEngine.validateMethodWithData(this, vectorSpaceInfo);
98+
}
99+
89100
/**
90101
* This method returns whether training is requires or not from knnEngine
91102
*

src/main/java/org/opensearch/knn/index/MethodComponent.java

+38
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.opensearch.common.ValidationException;
1818
import org.opensearch.knn.common.KNNConstants;
1919
import org.opensearch.knn.index.util.IndexHyperParametersUtil;
20+
import org.opensearch.knn.training.VectorSpaceInfo;
2021

2122
import java.util.ArrayList;
2223
import java.util.HashMap;
@@ -102,6 +103,43 @@ public ValidationException validate(MethodComponentContext methodComponentContex
102103
return validationException;
103104
}
104105

106+
/**
107+
* Validate that the methodComponentContext is a valid configuration for this methodComponent, using additional data not present in the method component context
108+
*
109+
* @param methodComponentContext to be validated
110+
* @param vectorSpaceInfo additional data not present in the method component context
111+
* @return ValidationException produced by validation errors; null if no validations errors.
112+
*/
113+
public ValidationException validateWithData(MethodComponentContext methodComponentContext, VectorSpaceInfo vectorSpaceInfo) {
114+
Map<String, Object> providedParameters = methodComponentContext.getParameters();
115+
List<String> errorMessages = new ArrayList<>();
116+
117+
if (providedParameters == null) {
118+
return null;
119+
}
120+
121+
ValidationException parameterValidation;
122+
for (Map.Entry<String, Object> parameter : providedParameters.entrySet()) {
123+
if (!parameters.containsKey(parameter.getKey())) {
124+
errorMessages.add(String.format("Invalid parameter for method \"%s\".", getName()));
125+
continue;
126+
}
127+
128+
parameterValidation = parameters.get(parameter.getKey()).validateWithData(parameter.getValue(), vectorSpaceInfo);
129+
if (parameterValidation != null) {
130+
errorMessages.addAll(parameterValidation.validationErrors());
131+
}
132+
}
133+
134+
if (errorMessages.isEmpty()) {
135+
return null;
136+
}
137+
138+
ValidationException validationException = new ValidationException();
139+
validationException.addValidationErrors(errorMessages);
140+
return validationException;
141+
}
142+
105143
/**
106144
* gets requiresTraining value
107145
*

src/main/java/org/opensearch/knn/index/Parameter.java

+148-1
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
package org.opensearch.knn.index;
1313

1414
import org.opensearch.common.ValidationException;
15+
import org.opensearch.knn.training.VectorSpaceInfo;
1516

1617
import java.util.Map;
18+
import java.util.function.BiFunction;
1719
import java.util.function.Predicate;
1820

1921
/**
@@ -26,6 +28,7 @@ public abstract class Parameter<T> {
2628
private String name;
2729
private T defaultValue;
2830
protected Predicate<T> validator;
31+
protected BiFunction<T, VectorSpaceInfo, Boolean> validatorWithData;
2932

3033
/**
3134
* Constructor
@@ -38,6 +41,14 @@ public Parameter(String name, T defaultValue, Predicate<T> validator) {
3841
this.name = name;
3942
this.defaultValue = defaultValue;
4043
this.validator = validator;
44+
this.validatorWithData = null;
45+
}
46+
47+
public Parameter(String name, T defaultValue, Predicate<T> validator, BiFunction<T, VectorSpaceInfo, Boolean> validatorWithData) {
48+
this.name = name;
49+
this.defaultValue = defaultValue;
50+
this.validator = validator;
51+
this.validatorWithData = validatorWithData;
4152
}
4253

4354
/**
@@ -66,6 +77,15 @@ public T getDefaultValue() {
6677
*/
6778
public abstract ValidationException validate(Object value);
6879

80+
/**
81+
* Check if the value passed in is valid, using additional data not present in the value
82+
*
83+
* @param value to be checked
84+
* @param vectorSpaceInfo additional data not present in the value
85+
* @return ValidationException produced by validation errors; null if no validations errors.
86+
*/
87+
public abstract ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo);
88+
6989
/**
7090
* Boolean method parameter
7191
*/
@@ -74,12 +94,23 @@ public BooleanParameter(String name, Boolean defaultValue, Predicate<Boolean> va
7494
super(name, defaultValue, validator);
7595
}
7696

97+
public BooleanParameter(
98+
String name,
99+
Boolean defaultValue,
100+
Predicate<Boolean> validator,
101+
BiFunction<Boolean, VectorSpaceInfo, Boolean> validatorWithData
102+
) {
103+
super(name, defaultValue, validator, validatorWithData);
104+
}
105+
77106
@Override
78107
public ValidationException validate(Object value) {
79108
ValidationException validationException = null;
80109
if (!(value instanceof Boolean)) {
81110
validationException = new ValidationException();
82-
validationException.addValidationError(String.format("value not of type Boolean for Boolean parameter [%s].", getName()));
111+
validationException.addValidationError(
112+
String.format("value is not an instance of Boolean for Boolean parameter [%s].", getName())
113+
);
83114
return validationException;
84115
}
85116

@@ -89,6 +120,27 @@ public ValidationException validate(Object value) {
89120
}
90121
return validationException;
91122
}
123+
124+
@Override
125+
public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) {
126+
ValidationException validationException = null;
127+
if (!(value instanceof Boolean)) {
128+
validationException = new ValidationException();
129+
validationException.addValidationError(String.format("value not of type Boolean for Boolean parameter [%s].", getName()));
130+
return validationException;
131+
}
132+
133+
if (validatorWithData == null) {
134+
return null;
135+
}
136+
137+
if (!validatorWithData.apply((Boolean) value, vectorSpaceInfo)) {
138+
validationException = new ValidationException();
139+
validationException.addValidationError(String.format("parameter validation failed for Boolean parameter [%s].", getName()));
140+
}
141+
142+
return validationException;
143+
}
92144
}
93145

94146
/**
@@ -99,6 +151,15 @@ public IntegerParameter(String name, Integer defaultValue, Predicate<Integer> va
99151
super(name, defaultValue, validator);
100152
}
101153

154+
public IntegerParameter(
155+
String name,
156+
Integer defaultValue,
157+
Predicate<Integer> validator,
158+
BiFunction<Integer, VectorSpaceInfo, Boolean> validatorWithData
159+
) {
160+
super(name, defaultValue, validator, validatorWithData);
161+
}
162+
102163
@Override
103164
public ValidationException validate(Object value) {
104165
ValidationException validationException = null;
@@ -118,6 +179,29 @@ public ValidationException validate(Object value) {
118179
}
119180
return validationException;
120181
}
182+
183+
@Override
184+
public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) {
185+
ValidationException validationException = null;
186+
if (!(value instanceof Integer)) {
187+
validationException = new ValidationException();
188+
validationException.addValidationError(
189+
String.format("value is not an instance of Integer for Integer parameter [%s].", getName())
190+
);
191+
return validationException;
192+
}
193+
194+
if (validatorWithData == null) {
195+
return null;
196+
}
197+
198+
if (!validatorWithData.apply((Integer) value, vectorSpaceInfo)) {
199+
validationException = new ValidationException();
200+
validationException.addValidationError(String.format("parameter validation failed for Integer parameter [%s].", getName()));
201+
}
202+
203+
return validationException;
204+
}
121205
}
122206

123207
/**
@@ -136,6 +220,15 @@ public StringParameter(String name, String defaultValue, Predicate<String> valid
136220
super(name, defaultValue, validator);
137221
}
138222

223+
public StringParameter(
224+
String name,
225+
String defaultValue,
226+
Predicate<String> validator,
227+
BiFunction<String, VectorSpaceInfo, Boolean> validatorWithData
228+
) {
229+
super(name, defaultValue, validator, validatorWithData);
230+
}
231+
139232
/**
140233
* Check if the value passed in is valid
141234
*
@@ -161,6 +254,29 @@ public ValidationException validate(Object value) {
161254
}
162255
return validationException;
163256
}
257+
258+
@Override
259+
public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) {
260+
ValidationException validationException = null;
261+
if (!(value instanceof String)) {
262+
validationException = new ValidationException();
263+
validationException.addValidationError(
264+
String.format("value is not an instance of String for String parameter [%s].", getName())
265+
);
266+
return validationException;
267+
}
268+
269+
if (validatorWithData == null) {
270+
return null;
271+
}
272+
273+
if (!validatorWithData.apply((String) value, vectorSpaceInfo)) {
274+
validationException = new ValidationException();
275+
validationException.addValidationError(String.format("parameter validation failed for String parameter [%s].", getName()));
276+
}
277+
278+
return validationException;
279+
}
164280
}
165281

166282
/**
@@ -190,6 +306,12 @@ public MethodComponentContextParameter(
190306
}
191307

192308
return methodComponents.get(methodComponentContext.getName()).validate(methodComponentContext) == null;
309+
}, (methodComponentContext, vectorSpaceInfo) -> {
310+
if (!methodComponents.containsKey(methodComponentContext.getName())) {
311+
return false;
312+
}
313+
return methodComponents.get(methodComponentContext.getName())
314+
.validateWithData(methodComponentContext, vectorSpaceInfo) == null;
193315
});
194316
this.methodComponents = methodComponents;
195317
}
@@ -216,6 +338,31 @@ public ValidationException validate(Object value) {
216338
return validationException;
217339
}
218340

341+
@Override
342+
public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) {
343+
ValidationException validationException = null;
344+
if (!(value instanceof MethodComponentContext)) {
345+
validationException = new ValidationException();
346+
validationException.addValidationError(
347+
String.format("value is not an instance of for MethodComponentContext parameter [%s].", getName())
348+
);
349+
return validationException;
350+
}
351+
352+
if (validatorWithData == null) {
353+
return null;
354+
}
355+
356+
if (!validatorWithData.apply((MethodComponentContext) value, vectorSpaceInfo)) {
357+
validationException = new ValidationException();
358+
validationException.addValidationError(
359+
String.format("parameter validation failed for MethodComponentContext parameter [%s].", getName())
360+
);
361+
}
362+
363+
return validationException;
364+
}
365+
219366
/**
220367
* Get method component by name
221368
*

0 commit comments

Comments
 (0)