Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validation for pq m parameter before training starts #1713

Merged
merged 15 commits into from
May 30, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Add KnnCircuitBreakerException and modify exception message [#1688](https://github.com/opensearch-project/k-NN/pull/1688)
* Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684)
* Support script score when doc value is disabled and fix misusing DISI [#1696](https://github.com/opensearch-project/k-NN/pull/1696)
* Add validation for pq m parameter before training starts [#1713](https://github.com/opensearch-project/k-NN/pull/1713)
### Bug Fixes
* Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692)
### Infrastructure
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ public class KNNConstants {
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;

public static final String RADIAL_SEARCH_KEY = "radial_search";
public static final String INVALID_CODE_COUNT_ERROR_MESSAGE =
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
"The dimension of the vector is not a multiple of the number of subquantizers (m)";

// Lucene specific constants
public static final String LUCENE_NAME = "lucene";
Expand Down
13 changes: 13 additions & 0 deletions src/main/java/org/opensearch/knn/training/TrainingJob.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.common.UUIDs;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
Expand Down Expand Up @@ -115,6 +116,18 @@ public void run() {
NativeMemoryAllocation modelAnonymousAllocation = null;
ModelMetadata modelMetadata = model.getModelMetadata();

Map<String, Object> parameters = modelMetadata.getMethodComponentContext().getParameters();
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
// Validate the encoder's "m" parameter (number of subquantizers) up front, before the training data is
// fetched below. The keys come from KNNConstants so they cannot drift from the values used when the method
// context is built.
Object encoderContext = parameters.get(KNNConstants.METHOD_ENCODER_PARAMETER);
if (encoderContext instanceof MethodComponentContext) {
    Map<String, Object> encoderParameters = ((MethodComponentContext) encoderContext).getParameters();
    Object codeCountValue = encoderParameters.get(KNNConstants.ENCODER_PARAMETER_PQ_M);
    if (codeCountValue instanceof Integer) {
        int codeCount = (Integer) codeCountValue;
        // A code count of zero can never divide the dimension; reject it explicitly instead of letting the
        // modulo below fail with an ArithmeticException that bypasses the dedicated error message.
        if (codeCount == 0 || modelMetadata.getDimension() % codeCount != 0) {
            throw new IllegalArgumentException(KNNConstants.INVALID_CODE_COUNT_ERROR_MESSAGE);
        }
    }
}

try {
// Get training data
trainingDataAllocation = nativeMemoryCacheManager.get(trainingDataEntryContext, false);
Expand Down
12 changes: 12 additions & 0 deletions src/main/java/org/opensearch/knn/training/TrainingJobRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
Expand Down Expand Up @@ -139,6 +140,17 @@ private void train(TrainingJob trainingJob) {
} catch (Exception e) {
logger.error("Unable to complete training for \"" + trainingJob.getModelId() + "\": " + e.getMessage());
KNNCounter.TRAINING_ERRORS.increment();
// Compare with the constant as the receiver: Throwable.getMessage() may return null (for example for an
// exception created without a message), and calling equals() on it would throw a NullPointerException from
// inside this catch handler.
if (KNNConstants.INVALID_CODE_COUNT_ERROR_MESSAGE.equals(e.getMessage())) {
    // The job rejected the configuration before training started, so record the failure on the model here.
    ModelMetadata modelMetadata = trainingJob.getModel().getModelMetadata();
    modelMetadata.setState(ModelState.FAILED);
    modelMetadata.setError(KNNConstants.INVALID_CODE_COUNT_ERROR_MESSAGE);

    try {
        serializeModel(trainingJob, loggingListener, true);
    } catch (IOException | ExecutionException | InterruptedException ex) {
        logger.error("Unable to serialize the failure for model \"{}\": ", trainingJob.getModelId(), ex);
    }
}
} finally {
jobCount.decrementAndGet();
semaphore.release();
Expand Down
55 changes: 53 additions & 2 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.client.Response;
import org.opensearch.common.settings.Settings;
import org.opensearch.client.ResponseException;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
Expand Down Expand Up @@ -66,6 +67,7 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MIN_SCORE;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;

Expand Down Expand Up @@ -1311,8 +1313,8 @@ public void testSharedIndexState_whenOneIndexDeleted_thenSecondIndexIsStillSearc
.startObject(METHOD_ENCODER_PARAMETER)
.field(NAME, ENCODER_PQ)
.startObject(PARAMETERS)
.field(ENCODER_PARAMETER_PQ_M, pqCodeSize)
.field(ENCODER_PARAMETER_PQ_CODE_SIZE, pqM)
.field(ENCODER_PARAMETER_PQ_M, pqM)
.field(ENCODER_PARAMETER_PQ_CODE_SIZE, pqCodeSize)
.endObject()
.endObject()
.endObject()
Expand Down Expand Up @@ -1648,6 +1650,55 @@ public void testFiltering_whenUsingFaissExactSearchWithIP_thenMatchExpectedScore
}
}

/**
 * Trains a Faiss HNSW model with a PQ encoder whose "m" value is 3, then checks that training fails and that
 * the stored model reports {@link KNNConstants#INVALID_CODE_COUNT_ERROR_MESSAGE} as its error.
 */
@SneakyThrows
public void testInvalidPQM_thenFail() {
    final String modelId = "test-model";
    final String modelDescription = "test model";
    final String trainingIndexName = "training-index";
    final String trainingFieldName = "training-field";

    // PQ needs at least as many training vectors as centroids: 2^8 = 256, since 8 is the only valid
    // code_size for HNSWPQ.
    final int trainingDataCount = 256;

    SpaceType spaceType = SpaceType.L2;

    // NOTE(review): the test only makes sense if the test-data dimension is not a multiple of 3 - confirm.
    final int invalidPQM = 3;
    Integer dimension = testData.indexData.vectors[0].length;

    // Pick an arbitrary (valid) HNSW "m"; only the encoder's "m" is meant to be invalid.
    List<Integer> hnswMCandidates = ImmutableList.of(16, 32, 64, 128);
    Integer hnswM = hnswMCandidates.get(random().nextInt(hnswMCandidates.size()));

    XContentBuilder methodBuilder = XContentFactory.jsonBuilder().startObject();
    methodBuilder.field(NAME, METHOD_HNSW).field(KNN_ENGINE, FAISS_NAME);
    methodBuilder.startObject(PARAMETERS).field(KNNConstants.METHOD_PARAMETER_M, hnswM);
    methodBuilder.startObject(METHOD_ENCODER_PARAMETER).field(NAME, ENCODER_PQ);
    methodBuilder.startObject(PARAMETERS).field(ENCODER_PARAMETER_PQ_M, invalidPQM).endObject();
    methodBuilder.endObject(); // encoder
    methodBuilder.endObject(); // method parameters
    methodBuilder.endObject(); // root
    Map<String, Object> methodAsMap = xContentBuilderToMap(methodBuilder);

    createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
    ingestDataAndTrainModel(
        modelId,
        trainingIndexName,
        trainingFieldName,
        dimension,
        modelDescription,
        methodAsMap,
        trainingDataCount
    );
    assertTrainingFails(modelId, 360, 1000);

    // Fetch the stored model and compare its error field with the validation message.
    Response modelResponse = getModel(modelId, null);
    String modelResponseBody = EntityUtils.toString(modelResponse.getEntity());
    Map<String, Object> modelAsMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), modelResponseBody).map();

    assertEquals(KNNConstants.INVALID_CODE_COUNT_ERROR_MESSAGE, modelAsMap.get(MODEL_ERROR));
}

protected void setupKNNIndexForFilterQuery() throws Exception {
// Create Mappings
XContentBuilder builder = XContentFactory.jsonBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@

import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.knn.common.KNNConstants.INVALID_CODE_COUNT_ERROR_MESSAGE;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.TRAIN_THREAD_POOL;

Expand Down Expand Up @@ -146,4 +148,62 @@ public void testExecute_failure_rejected() throws IOException, InterruptedExcept
executorService.shutdown();
executorService.awaitTermination(10, TimeUnit.SECONDS);
}

/**
 * Runs a mocked {@link TrainingJob} whose {@code run()} throws an {@link IllegalArgumentException} carrying
 * INVALID_CODE_COUNT_ERROR_MESSAGE through the {@link TrainingJobRunner}, then verifies that the model is put
 * once and updated once.
 */
@SuppressWarnings("unchecked")
public void testExecute_failure_invalidPQM() throws IOException, InterruptedException, ExecutionException {
// This test makes sure we handle the exception thrown by TrainingJob when an invalid m parameter is provided
// and update the model accordingly.

ExecutorService executorService = Executors.newSingleThreadExecutor();

TrainingJobRunner trainingJobRunner = TrainingJobRunner.getInstance();

// The runner obtains its executor from the thread pool; hand it the single-thread executor created above.
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.executor(TRAIN_THREAD_POOL)).thenReturn(executorService);

String modelId = "test-model-id";
Model model = mock(Model.class);
ModelMetadata modelMetadata = mock(ModelMetadata.class);
when(modelMetadata.getState()).thenReturn(ModelState.TRAINING);
when(model.getModelMetadata()).thenReturn(modelMetadata);
TrainingJob trainingJob = mock(TrainingJob.class);
when(trainingJob.getModelId()).thenReturn(modelId);
when(trainingJob.getModel()).thenReturn(model);
// Simulate the validation failure: run() throws with the exact message the runner compares against.
doThrow(new IllegalArgumentException(INVALID_CODE_COUNT_ERROR_MESSAGE)).when(trainingJob).run();

// This gets called right after the initial put, before training begins. Just check that the model id is
// equal
ActionListener<IndexResponse> responseListener = ActionListener.wrap(
indexResponse -> assertEquals(modelId, indexResponse.getId()),
e -> fail("Failure should not have occurred")
);

// After put finishes, it should call the onResponse function that will call responseListener and then kickoff
// training.
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.get(modelId)).thenReturn(model);
doAnswer(invocationOnMock -> {
assertEquals(1, trainingJobRunner.getJobCount()); // Make sure job count is correct
IndexResponse indexResponse = new IndexResponse(new ShardId(MODEL_INDEX_NAME, "uuid", 0), modelId, 0, 0, 0, true);
((ActionListener<IndexResponse>) invocationOnMock.getArguments()[1]).onResponse(indexResponse);
return null;
}).when(modelDao).put(any(Model.class), any(ActionListener.class));

// Function finishes when update is called
doAnswer(invocationOnMock -> null).when(modelDao).update(any(Model.class), any(ActionListener.class));

// Finally, initialize the singleton runner, execute the job.
TrainingJobRunner.initialize(threadPool, modelDao);
trainingJobRunner.execute(trainingJob, responseListener);

// Immediately, we shutdown the executor and await its termination.
executorService.shutdown();
executorService.awaitTermination(10, TimeUnit.SECONDS);

// Make sure these methods get called once
verify(trainingJob, times(1)).run();
verify(modelDao, times(1)).put(any(Model.class), any(ActionListener.class));
// NOTE(review): presumably this single update is the one that persists the FAILED state - confirm.
verify(modelDao, times(1)).update(any(Model.class), any(ActionListener.class));
// NOTE(review): the expected count of 3 depends on how many times the runner calls getModel() internally;
// it will need adjusting if that implementation changes - confirm against TrainingJobRunner.
verify(trainingJob, times(3)).getModel();
}
}
80 changes: 80 additions & 0 deletions src/test/java/org/opensearch/knn/training/TrainingJobTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ;
import static org.opensearch.knn.common.KNNConstants.INDEX_THREAD_QTY;
import static org.opensearch.knn.common.KNNConstants.INVALID_CODE_COUNT_ERROR_MESSAGE;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;

Expand Down Expand Up @@ -481,6 +485,82 @@ public void testRun_failure_notEnoughTrainingData() throws ExecutionException {
assertFalse(model.getModelMetadata().getError().isEmpty());
}

/**
 * Verifies that {@link TrainingJob#run()} throws an {@link IllegalArgumentException} with
 * INVALID_CODE_COUNT_ERROR_MESSAGE when the PQ encoder's "m" parameter (3) does not evenly divide the vector
 * dimension (16).
 */
public void testRun_failure_invalidPQM() throws ExecutionException {
// In this test case, we ensure that run() rejects a PQ encoder whose m parameter does not divide the dimension
String modelId = "test-model-id";

// Define the method setup for method that requires training
int dimension = 16;
int invalidPQM = 3; // 16 % 3 != 0, so this value must be rejected
KNNEngine knnEngine = KNNEngine.FAISS;

// IVF method with a nested PQ encoder whose only parameter is the invalid m
KNNMethodContext knnMethodContext = new KNNMethodContext(
knnEngine,
SpaceType.INNER_PRODUCT,
new MethodComponentContext(
METHOD_IVF,
ImmutableMap.of(
METHOD_ENCODER_PARAMETER,
new MethodComponentContext(ENCODER_PQ, ImmutableMap.of(ENCODER_PARAMETER_PQ_M, invalidPQM))
)
)
);
// Set up training data
int tdataPoints = 2;
float[][] trainingData = new float[tdataPoints][dimension];
fillFloatArrayRandomly(trainingData);
long memoryAddress = JNIService.transferVectors(0, trainingData);

// Setup model manager
NativeMemoryCacheManager nativeMemoryCacheManager = mock(NativeMemoryCacheManager.class);

// Setup mock allocation for model
NativeMemoryAllocation modelAllocation = mock(NativeMemoryAllocation.class);
doAnswer(invocationOnMock -> null).when(modelAllocation).readLock();
doAnswer(invocationOnMock -> null).when(modelAllocation).readUnlock();
when(modelAllocation.isClosed()).thenReturn(false);

String modelKey = "model-test-key";
NativeMemoryEntryContext.AnonymousEntryContext modelContext = mock(NativeMemoryEntryContext.AnonymousEntryContext.class);
when(modelContext.getKey()).thenReturn(modelKey);

when(nativeMemoryCacheManager.get(modelContext, false)).thenReturn(modelAllocation);
doAnswer(invocationOnMock -> null).when(nativeMemoryCacheManager).invalidate(modelKey);

// Setup mock allocation
NativeMemoryAllocation nativeMemoryAllocation = mock(NativeMemoryAllocation.class);
doAnswer(invocationOnMock -> null).when(nativeMemoryAllocation).readLock();
doAnswer(invocationOnMock -> null).when(nativeMemoryAllocation).readUnlock();
when(nativeMemoryAllocation.isClosed()).thenReturn(false);
when(nativeMemoryAllocation.getMemoryAddress()).thenReturn(memoryAddress);

String tdataKey = "t-data-key";
NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = mock(
NativeMemoryEntryContext.TrainingDataEntryContext.class
);
when(trainingDataEntryContext.getKey()).thenReturn(tdataKey);

when(nativeMemoryCacheManager.get(trainingDataEntryContext, false)).thenReturn(nativeMemoryAllocation);
// Free the native training vectors when the cache entry for the training data is invalidated.
// NOTE(review): run() is expected to throw during validation, presumably before the training data is fetched
// or invalidated, in which case this stub never fires and memoryAddress is not freed - confirm.
doAnswer(invocationOnMock -> {
JNICommons.freeVectorData(memoryAddress);
return null;
}).when(nativeMemoryCacheManager).invalidate(tdataKey);

TrainingJob trainingJob = new TrainingJob(
modelId,
knnMethodContext,
nativeMemoryCacheManager,
trainingDataEntryContext,
modelContext,
dimension,
"",
"test-node"
);

// The invalid m must surface as an IllegalArgumentException carrying the dedicated error message.
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, trainingJob::run);
assertEquals(INVALID_CODE_COUNT_ERROR_MESSAGE, e.getMessage());
}

private void fillFloatArrayRandomly(float[][] vectors) {
for (int i = 0; i < vectors.length; i++) {
for (int j = 0; j < vectors[i].length; j++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1430,7 +1430,7 @@ public void assertTrainingFails(String modelId, int attempts, int delayInMillis)
assertNotEquals(ModelState.CREATED, modelState);
}

fail("Training did not succeed after " + attempts + " attempts with a delay of " + delayInMillis + " ms.");
fail("Training did not fail after " + attempts + " attempts with a delay of " + delayInMillis + " ms.");
}

protected boolean systemIndexExists(final String indexName) throws IOException {
Expand Down
Loading