Skip to content

Commit 3a5fe64

Browse files
authored
Fix few bugs on binary index with Faiss HNSW (#1850)
Signed-off-by: Heemin Kim <[email protected]>
1 parent 881364f commit 3a5fe64

File tree

11 files changed

+169
-10
lines changed

11 files changed

+169
-10
lines changed

src/main/java/org/opensearch/knn/index/IndexUtil.java

+13
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,19 @@ public static ValidationException validateKnnField(
190190
return exception;
191191
}
192192

193+
String vectorDataType = (String) fieldMap.get(VECTOR_DATA_TYPE_FIELD);
194+
if (VectorDataType.BINARY.toString().equalsIgnoreCase(vectorDataType)) {
195+
exception.addValidationError(
196+
String.format(
197+
Locale.ROOT,
198+
"Field \"%s\" is of data type %s. Only FLOAT or BYTE is supported.",
199+
field,
200+
VectorDataType.BINARY
201+
)
202+
);
203+
return exception;
204+
}
205+
193206
// Return if dimension does not need to be checked
194207
if (expectedDimension < 0) {
195208
return null;

src/main/java/org/opensearch/knn/index/KNNMethod.java

+7-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.HashMap;
2323
import java.util.HashSet;
2424
import java.util.List;
25+
import java.util.Locale;
2526
import java.util.Map;
2627
import java.util.Set;
2728

@@ -57,8 +58,10 @@ public ValidationException validate(KNNMethodContext knnMethodContext) {
5758
if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) {
5859
errorMessages.add(
5960
String.format(
60-
"\"%s\" configuration does not support space type: " + "\"%s\".",
61+
Locale.ROOT,
62+
"\"%s\" with \"%s\" configuration does not support space type: " + "\"%s\".",
6163
this.methodComponent.getName(),
64+
knnMethodContext.getKnnEngine().getName().toLowerCase(Locale.ROOT),
6265
knnMethodContext.getSpaceType().getValue()
6366
)
6467
);
@@ -90,8 +93,10 @@ public ValidationException validateWithData(KNNMethodContext knnMethodContext, V
9093
if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) {
9194
errorMessages.add(
9295
String.format(
93-
"\"%s\" configuration does not support space type: " + "\"%s\".",
96+
Locale.ROOT,
97+
"\"%s\" with \"%s\" configuration does not support space type: " + "\"%s\".",
9498
this.methodComponent.getName(),
99+
knnMethodContext.getKnnEngine().getName().toLowerCase(Locale.ROOT),
95100
knnMethodContext.getSpaceType().getValue()
96101
)
97102
);

src/main/java/org/opensearch/knn/index/KNNSettings.java

+1
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ public class KNNSettings {
8686
*/
8787
public static final boolean KNN_DEFAULT_FAISS_AVX2_DISABLED_VALUE = false;
8888
public static final String INDEX_KNN_DEFAULT_SPACE_TYPE = "l2";
89+
public static final String INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY = "hammingbit";
8990
public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_M = 16;
9091
public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH = 100;
9192
public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_EF_CONSTRUCTION = 100;

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ public KNNVectorFieldMapper build(BuilderContext context) {
307307

308308
// Build legacy
309309
if (this.spaceType == null) {
310-
this.spaceType = LegacyFieldMapper.getSpaceType(context.indexSettings());
310+
this.spaceType = LegacyFieldMapper.getSpaceType(context.indexSettings(), vectorDataType.getValue());
311311
}
312312

313313
if (this.m == null) {

src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java

+6-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.opensearch.common.settings.Settings;
1313
import org.opensearch.index.mapper.ParametrizedFieldMapper;
1414
import org.opensearch.knn.index.KNNSettings;
15+
import org.opensearch.knn.index.VectorDataType;
1516
import org.opensearch.knn.index.util.IndexHyperParametersUtil;
1617
import org.opensearch.knn.index.util.KNNEngine;
1718

@@ -78,17 +79,19 @@ public ParametrizedFieldMapper.Builder getMergeBuilder() {
7879
);
7980
}
8081

81-
static String getSpaceType(Settings indexSettings) {
82+
static String getSpaceType(final Settings indexSettings, final VectorDataType vectorDataType) {
8283
String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey());
8384
if (spaceType == null) {
85+
spaceType = VectorDataType.BINARY == vectorDataType
86+
? KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY
87+
: KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE;
8488
log.info(
8589
String.format(
8690
"[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. Setting the setting to the default value=%s",
8791
METHOD_PARAMETER_SPACE_TYPE,
88-
KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE
92+
spaceType
8993
)
9094
);
91-
return KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE;
9295
}
9396
return spaceType;
9497
}

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

+3
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,9 @@ protected Query doToQuery(QueryShardContext context) {
622622
String.format(Locale.ROOT, "Engine [%s] does not support radial search", knnEngine)
623623
);
624624
}
625+
if (vectorDataType == VectorDataType.BINARY) {
626+
throw new UnsupportedOperationException(String.format(Locale.ROOT, "Binary data type does not support radial search"));
627+
}
625628
RNNQueryFactory.CreateQueryRequest createQueryRequest = RNNQueryFactory.CreateQueryRequest.builder()
626629
.knnEngine(knnEngine)
627630
.indexName(indexName)

src/main/java/org/opensearch/knn/index/query/KNNWeight.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -476,8 +476,15 @@ private boolean canDoExactSearch(final int filterIdsCount) {
476476
if (isExactSearchThresholdSettingSet(filterThresholdValue)) {
477477
return filterThresholdValue >= filterIdsCount;
478478
}
479+
479480
// if no setting is set, then use the default max distance computation value to see if we can do exact search.
480-
return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * knnQuery.getQueryVector().length;
481+
/**
482+
* TODO we can have a different MAX_DISTANCE_COMPUTATIONS for binary index as computation cost for binary index
483+
* is cheaper than computation cost for non binary vector
484+
*/
485+
return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * (knnQuery.getVectorDataType() == VectorDataType.FLOAT
486+
? knnQuery.getQueryVector().length
487+
: knnQuery.getByteQueryVector().length);
481488
}
482489

483490
/**

src/test/java/org/opensearch/knn/index/IndexUtilTests.java

+21
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,27 @@ public void testValidateKnnField_EmptyIndexMetadata() {
228228
assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Invalid index. Index does not contain a mapping;"));
229229
}
230230

231+
public void testValidateKnnField_whenBinaryDataType_thenThrowException() {
232+
Map<String, Object> fieldValues = Map.of("type", "knn_vector", "dimension", 8, "data_type", "BINARY");
233+
Map<String, Object> top_level_field = Map.of("top_level_field", fieldValues);
234+
Map<String, Object> properties = Map.of("properties", top_level_field);
235+
String field = "top_level_field";
236+
int dimension = 8;
237+
238+
MappingMetadata mappingMetadata = mock(MappingMetadata.class);
239+
when(mappingMetadata.getSourceAsMap()).thenReturn(properties);
240+
IndexMetadata indexMetadata = mock(IndexMetadata.class);
241+
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
242+
ModelDao modelDao = mock(ModelDao.class);
243+
ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class);
244+
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
245+
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);
246+
247+
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao);
248+
249+
assert (Objects.requireNonNull(e).getMessage().contains("is of data type BINARY. Only FLOAT or BYTE is supported"));
250+
}
251+
231252
public void testIsShareableStateContainedInIndex_whenIndexNotModelBased_thenReturnFalse() {
232253
String modelId = null;
233254
KNNEngine knnEngine = KNNEngine.FAISS;

src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java

+17-3
Original file line numberDiff line numberDiff line change
@@ -192,24 +192,38 @@ public void testBuilder_build_fromLegacy() {
192192
ModelDao modelDao = mock(ModelDao.class);
193193
KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT);
194194

195-
SpaceType spaceType = SpaceType.COSINESIMIL;
196195
int m = 17;
197196
int efConstruction = 17;
198197

199198
// Setup settings
200199
Settings settings = Settings.builder()
201200
.put(settings(CURRENT).build())
202-
.put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue())
203201
.put(KNNSettings.KNN_ALGO_PARAM_M, m)
204202
.put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction)
205203
.build();
206204

207205
Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath());
208206
KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext);
209207
assertTrue(knnVectorFieldMapper instanceof LegacyFieldMapper);
210-
211208
assertNull(knnVectorFieldMapper.modelId);
212209
assertNull(knnVectorFieldMapper.knnMethod);
210+
assertEquals(SpaceType.L2.getValue(), ((LegacyFieldMapper) knnVectorFieldMapper).spaceType);
211+
}
212+
213+
public void testBuilder_whenKnnFalseWithBinary_thenSetHammingAsDefault() {
214+
// Check legacy is picked up if model context and method context are not set
215+
ModelDao modelDao = mock(ModelDao.class);
216+
KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT);
217+
builder.vectorDataType.setValue(VectorDataType.BINARY);
218+
builder.dimension.setValue(8);
219+
220+
// Setup settings
221+
Settings settings = Settings.builder().put(settings(CURRENT).build()).build();
222+
223+
Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath());
224+
KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext);
225+
assertTrue(knnVectorFieldMapper instanceof LegacyFieldMapper);
226+
assertEquals(SpaceType.HAMMING_BIT.getValue(), ((LegacyFieldMapper) knnVectorFieldMapper).spaceType);
213227
}
214228

215229
public void testBuilder_parse_fromKnnMethodContext_luceneEngine() throws IOException {

src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java

+24
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,30 @@ public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_then
718718
expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
719719
}
720720

721+
public void testDoToQuery_whenRadialSearchOnBinaryIndex_thenException() {
722+
float[] queryVector = { 1.0f };
723+
KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder()
724+
.fieldName(FIELD_NAME)
725+
.vector(queryVector)
726+
.maxDistance(MAX_DISTANCE)
727+
.build();
728+
Index dummyIndex = new Index("dummy", "dummy");
729+
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
730+
KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
731+
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
732+
when(mockKNNVectorField.getDimension()).thenReturn(8);
733+
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY);
734+
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
735+
MethodComponentContext methodComponentContext = new MethodComponentContext(
736+
org.opensearch.knn.common.KNNConstants.METHOD_HNSW,
737+
ImmutableMap.of()
738+
);
739+
KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.HAMMING_BIT, methodComponentContext);
740+
when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext);
741+
Exception e = expectThrows(UnsupportedOperationException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
742+
assertTrue(e.getMessage().contains("Binary data type does not support radial search"));
743+
}
744+
721745
public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception {
722746
// Given
723747
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };

src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java

+68
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,74 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces
863863
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
864864
}
865865

866+
/**
867+
* This test ensure that we do the exact search when threshold settings are correct and not using filteredIds<=K
868+
* condition to do exact search on binary index
869+
* FilteredIdThreshold: 10
870+
* FilteredIdThresholdPct: 10%
871+
* FilteredIdsCount: 6
872+
* liveDocs : null, as there is no deleted documents
873+
* MaxDoc: 100
874+
* K : 1
875+
*/
876+
@SneakyThrows
877+
public void testANNWithFilterQuery_whenExactSearchViaThresholdSettingOnBinaryIndex_thenSuccess() {
878+
knnSettingsMockedStatic.when(() -> KNNSettings.getFilteredExactSearchThreshold(INDEX_NAME)).thenReturn(10);
879+
byte[] vector = new byte[] { 1, 3 };
880+
int k = 1;
881+
final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 };
882+
883+
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
884+
final SegmentReader reader = mock(SegmentReader.class);
885+
when(leafReaderContext.reader()).thenReturn(reader);
886+
when(reader.maxDoc()).thenReturn(100);
887+
when(reader.getLiveDocs()).thenReturn(null);
888+
final Weight filterQueryWeight = mock(Weight.class);
889+
final Scorer filterScorer = mock(Scorer.class);
890+
when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer);
891+
892+
when(filterScorer.iterator()).thenReturn(DocIdSetIterator.all(filterDocIds.length));
893+
894+
final KNNQuery query = new KNNQuery(FIELD_NAME, BYTE_QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null, VectorDataType.BINARY);
895+
896+
final float boost = (float) randomDoubleBetween(0, 10, true);
897+
final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight);
898+
final Map<String, String> attributesMap = ImmutableMap.of(
899+
KNN_ENGINE,
900+
KNNEngine.FAISS.getName(),
901+
SPACE_TYPE,
902+
SpaceType.HAMMING_BIT.name(),
903+
PARAMETERS,
904+
String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "BHNSW32")
905+
);
906+
final FieldInfos fieldInfos = mock(FieldInfos.class);
907+
final FieldInfo fieldInfo = mock(FieldInfo.class);
908+
final BinaryDocValues binaryDocValues = mock(BinaryDocValues.class);
909+
when(reader.getFieldInfos()).thenReturn(fieldInfos);
910+
when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo);
911+
when(fieldInfo.attributes()).thenReturn(attributesMap);
912+
when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.HAMMING_BIT.getValue());
913+
when(fieldInfo.getName()).thenReturn(FIELD_NAME);
914+
when(reader.getBinaryDocValues(FIELD_NAME)).thenReturn(binaryDocValues);
915+
when(binaryDocValues.advance(0)).thenReturn(0);
916+
BytesRef vectorByteRef = new BytesRef(vector);
917+
when(binaryDocValues.binaryValue()).thenReturn(vectorByteRef);
918+
919+
final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext);
920+
assertNotNull(knnScorer);
921+
final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
922+
assertNotNull(docIdSetIterator);
923+
assertEquals(EXACT_SEARCH_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost());
924+
925+
final List<Integer> actualDocIds = new ArrayList<>();
926+
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
927+
actualDocIds.add(docId);
928+
assertEquals(BINARY_EXACT_SEARCH_DOC_ID_TO_SCORES.get(docId) * boost, knnScorer.score(), 0.01f);
929+
}
930+
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
931+
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
932+
}
933+
866934
@SneakyThrows
867935
public void testANNWithFilterQuery_whenEmptyFilterIds_thenReturnEarly() {
868936
final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);

0 commit comments

Comments
 (0)