Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/137883.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 137883
summary: Addressing vector similarity concurrency issue with byte vectors
area: ES|QL
type: bug
issues:
- 137625
23 changes: 0 additions & 23 deletions muted-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,6 @@ tests:
- class: org.elasticsearch.xpack.ilm.CCRIndexLifecycleIT
method: testTsdbLeaderIndexRolloverAndSyncAfterWaitUntilEndTime {targetCluster=FOLLOWER}
issue: https://github.com/elastic/elasticsearch/issues/137565
- class: org.elasticsearch.xpack.esql.vector.VectorSimilarityFunctionsIT
method: testTopNSimilarityBetweenConstantVectorAndField {functionName=v_l2_norm
similarityFunction=org.elasticsearch.xpack.esql.expression.function.vector.L2Norm$1@5b068087 elementType=byte}
issue: https://github.com/elastic/elasticsearch/issues/137625
- class: org.elasticsearch.xpack.esql.qa.single_node.GenerativeMetricsIT
method: test
issue: https://github.com/elastic/elasticsearch/issues/137655
Expand All @@ -436,14 +432,6 @@ tests:
- class: org.elasticsearch.indices.mapping.UpdateMappingIntegrationIT
method: testUpdateMappingConcurrently
issue: https://github.com/elastic/elasticsearch/issues/137758
- class: org.elasticsearch.xpack.esql.vector.VectorSimilarityFunctionsIT
method: testSimilarityWithOneDimVector {functionName=v_cosine
similarityFunction=org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity$1@70ab2d48 elementType=byte}
issue: https://github.com/elastic/elasticsearch/issues/137774
- class: org.elasticsearch.xpack.esql.vector.VectorSimilarityFunctionsIT
method: testSimilarityWithOneDimVector {functionName=v_cosine
similarityFunction=org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity$1@5b068087 elementType=byte}
issue: https://github.com/elastic/elasticsearch/issues/137778
- class: org.elasticsearch.xpack.inference.integration.CCMPersistentStorageServiceIT
method: testDelete_RemovesCCMConfiguration
issue: https://github.com/elastic/elasticsearch/issues/137786
Expand All @@ -453,10 +441,6 @@ tests:
- class: org.elasticsearch.xpack.inference.integration.CCMServiceIT
method: testIsEnabled_ReturnsTrue_WhenCCMConfigurationIsPresent
issue: https://github.com/elastic/elasticsearch/issues/137798
- class: org.elasticsearch.xpack.esql.vector.VectorSimilarityFunctionsIT
method: testSimilarityWithOneDimVector {functionName=v_cosine
similarityFunction=org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity$1@3300f4fd elementType=byte}
issue: https://github.com/elastic/elasticsearch/issues/137812
- class: org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceTests
method: testChangingCapacity_DoesNotRejectsOverflowTasks_BecauseOfQueueFull
issue: https://github.com/elastic/elasticsearch/issues/137823
Expand All @@ -466,16 +450,9 @@ tests:
- class: org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorMultipleNodesIT
method: testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRunningItShutsDown
issue: https://github.com/elastic/elasticsearch/issues/137911
- class: org.elasticsearch.xpack.esql.vector.VectorSimilarityFunctionsIT
method: testSimilarityWithOneDimVector {functionName=v_cosine
similarityFunction=org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity$1@ebb6851 elementType=byte}
issue: https://github.com/elastic/elasticsearch/issues/137915
- class: org.elasticsearch.xpack.esql.qa.multi_node.GenerativeIT
method: test
issue: https://github.com/elastic/elasticsearch/issues/137909
- class: org.elasticsearch.xpack.esql.vector.VectorSimilarityFunctionsIT
method: testSimilarityWithOneDimVector {functionName=v_cosine similarityFunction=V_COSINE elementType=byte}
issue: https://github.com/elastic/elasticsearch/issues/137975
- class: org.elasticsearch.search.TelemetryMetrics.ShardSearchPhaseAPMMetricsTests
method: testTimeRangeFilterAllResults
issue: https://github.com/elastic/elasticsearch/issues/137979
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2937,9 +2937,6 @@ public BlockLoader blockLoader(MappedFieldType.BlockLoaderContext blContext) {
return switch (cfg.function()) {
case V_COSINE, V_DOT_PRODUCT, V_HAMMING, V_L1NORM, V_L2NORM -> {
VectorSimilarityFunctionConfig similarityConfig = (VectorSimilarityFunctionConfig) cfg;
if (getElementType() == ElementType.BYTE || getElementType() == ElementType.BIT) {
similarityConfig = similarityConfig.forByteVector();
}
yield new DenseVectorBlockLoader<>(
name(),
dims,
Expand Down Expand Up @@ -3474,22 +3471,18 @@ public static class VectorSimilarityFunctionConfig implements BlockLoaderFunctio

private final SimilarityFunction similarityFunction;
private final float[] vector;
private byte[] vectorAsBytes;
private final byte[] vectorAsBytes;

/**
 * Builds a configuration around a float query vector.
 * The byte representation is left unset ({@code null}) for instances created this way.
 *
 * @param similarityFunction the similarity function to store in this config
 * @param vector the query vector as floats; stored as-is, without copying
 */
public VectorSimilarityFunctionConfig(SimilarityFunction similarityFunction, float[] vector) {
    // Exactly one of the two vector representations is populated per instance.
    this.vectorAsBytes = null;
    this.vector = vector;
    this.similarityFunction = similarityFunction;
}

/**
* Call before calculating byte vector similarities
*/
public VectorSimilarityFunctionConfig forByteVector() {
vectorAsBytes = new byte[vector.length];
for (int i = 0; i < vector.length; i++) {
vectorAsBytes[i] = (byte) vector[i];
}
return this;
/**
 * Builds a configuration around a byte query vector.
 * The float representation is left unset ({@code null}) for instances created this way.
 *
 * @param similarityFunction the similarity function to store in this config
 * @param vectorAsBytes the query vector as bytes; stored as-is, without copying
 */
public VectorSimilarityFunctionConfig(SimilarityFunction similarityFunction, byte[] vectorAsBytes) {
    // Exactly one of the two vector representations is populated per instance.
    this.vectorAsBytes = vectorAsBytes;
    this.vector = null;
    this.similarityFunction = similarityFunction;
}

@Override
Expand All @@ -3498,11 +3491,12 @@ public Function function() {
}

public byte[] vectorAsBytes() {
assert vectorAsBytes != null : "vectorAsBytes is null, call forByteVector() first";
assert vectorAsBytes != null : "vectorAsBytes is null, maybe incorrect element type during construction?";
return vectorAsBytes;
}

/**
 * Returns the float query vector held by this config.
 * When assertions are enabled, this fails if the float vector is {@code null},
 * i.e. the instance was not constructed with a float vector.
 *
 * @return the stored float vector (the internal array, not a copy)
 */
public float[] vector() {
assert vector != null : "vector is null, maybe incorrect element type during construction?";
return vector;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
public static Iterable<Object[]> parameters() throws Exception {
List<Object[]> params = new ArrayList<>();

for (ElementType elementType : Set.of(ElementType.FLOAT, ElementType.BYTE)) {
for (ElementType elementType : Set.of(ElementType.FLOAT, ElementType.BYTE, ElementType.BIT)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `bit` element type requires `dims % 8 == 0`; I'm not sure whether that will cause flakiness here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When setting up the docs we have:

        for (int j = 0; j < numDims; j++) {
            switch (elementType) {
                case FLOAT -> vector.add(randomFloat());
                case BYTE, BIT -> vector.add((byte) randomIntBetween(-128, 127));

So although `dims` is a bit misleading here (we add one byte for every dim), and given that we don't specify `dims` in the mapping, I think we should be OK. I'll update it anyway to make this clearer.

if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
params.add(new Object[] { "v_cosine", CosineSimilarity.SIMILARITY_FUNCTION, elementType });
}
Expand Down Expand Up @@ -154,7 +154,7 @@ public void testSimilarityBetweenConstantVectorAndField() {

@SuppressWarnings("unchecked")
public void testSimilarityWithOneDimVector() {
var randomVector = randomVector(1);
var randomVector = randomVector(elementType == ElementType.BIT ? Byte.SIZE : 1);
var query = String.format(Locale.ROOT, """
FROM test
| EVAL similarity = %s(one_dim_vector, %s)
Expand Down Expand Up @@ -256,7 +256,7 @@ public void testDifferentDimensions() {
// edge case where this might not throw is if all `left_vector` are null, but the chance is (hopefully!) low enough to ignore
var randomVector = randomValueOtherThan(
null,
() -> randomVector(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * 2))
() -> randomVector(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * (elementType == ElementType.BIT ? 8 : 2)))
);
var query = String.format(Locale.ROOT, """
FROM test
Expand Down Expand Up @@ -306,16 +306,16 @@ private static float[] readVector(List<Number> leftVector) {

@Before
public void setup() throws IOException {
numDims = randomIntBetween(10, 20) * (elementType == ElementType.BIT ? 8 : 2);
createIndexWithDenseVector("test");

numDims = randomIntBetween(10, 20) * 2; // even number
int numDocs = randomIntBetween(10, 100);
this.leftVectors = new ArrayList<>();
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
for (int i = 0; i < numDocs; i++) {
List<Number> leftVector = randomVector();
List<Number> rightVector = randomVector();
List<Number> oneDimVector = randomVector(1);
List<Number> oneDimVector = randomVector(elementType == ElementType.BIT ? Byte.SIZE : 1);
docs[i] = prepareIndex("test").setId("" + i)
.setSource("id", String.valueOf(i), "left_vector", leftVector, "right_vector", rightVector, "one_dim_vector", oneDimVector);
leftVectors.add(leftVector);
Expand All @@ -333,11 +333,28 @@ private List<Number> randomVector(int numDims) {
if (rarely()) {
return null;
}
List<Number> vector = new ArrayList<>(numDims);
for (int j = 0; j < numDims; j++) {
int dimensions = numDims;
if (elementType == ElementType.BIT) {
assert dimensions % 8 == 0 : "dimensions must be multiple of 8 for BIT element type but was " + dimensions;
dimensions = dimensions / 8;
}
List<Number> vector = new ArrayList<>(dimensions);
for (int j = 0; j < dimensions; j++) {
switch (elementType) {
case FLOAT -> vector.add(randomFloat());
case BYTE, BIT -> vector.add((byte) randomIntBetween(-128, 127));
case FLOAT -> {
if (dimensions == 1) {
vector.add(randomValueOtherThan(0f, () -> randomFloat()));
} else {
vector.add(randomFloat());
}
}
case BYTE, BIT -> {
if (dimensions == 1) {
vector.add(randomValueOtherThan((byte) 0, () -> (byte) randomIntBetween(-128, 127)));
} else {
vector.add((byte) randomIntBetween(-128, 127));
}
}
default -> throw new IllegalArgumentException("Unexpected element type: " + elementType);
}
}
Expand All @@ -352,9 +369,9 @@ private void createIndexWithDenseVector(String indexName) throws IOException {
.startObject("id")
.field("type", "integer")
.endObject();
createDenseVectorField(mapping, "left_vector");
createDenseVectorField(mapping, "right_vector");
createDenseVectorField(mapping, "one_dim_vector");
createDenseVectorField(mapping, "left_vector", elementType, numDims);
createDenseVectorField(mapping, "right_vector", elementType, numDims);
createDenseVectorField(mapping, "one_dim_vector", elementType, elementType == ElementType.BIT ? Byte.SIZE : 1);
mapping.endObject().endObject();
Settings.Builder settingsBuilder = Settings.builder()
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
Expand All @@ -364,9 +381,11 @@ private void createIndexWithDenseVector(String indexName) throws IOException {
assertAcked(CreateRequest);
}

private void createDenseVectorField(XContentBuilder mapping, String fieldName) throws IOException {
private static void createDenseVectorField(XContentBuilder mapping, String fieldName, ElementType elementType, int dims)
throws IOException {
mapping.startObject(fieldName)
.field("type", "dense_vector")
.field("dims", dims)
.field("similarity", "l2_norm")
.field("element_type", elementType.toString().toLowerCase(Locale.ROOT))
.startObject("index_options")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,15 +228,30 @@ public final PushedBlockLoaderExpression tryPushToFieldLoading(SearchStats stats
}

List<?> vectorList = (List<?>) literal.value();
float[] vectorArray = new float[vectorList.size()];
for (int i = 0; i < vectorList.size(); i++) {
vectorArray[i] = ((Number) vectorList.get(i)).floatValue();
DenseVectorFieldMapper.ElementType elementType = null;
var fieldType = stats.fieldType(field.fieldName());
if (fieldType instanceof DenseVectorFieldMapper.DenseVectorFieldType) {
elementType = ((DenseVectorFieldMapper.DenseVectorFieldType) fieldType).getElementType();
}
if (elementType == null || elementType == DenseVectorFieldMapper.ElementType.FLOAT) {
float[] floatVector = new float[vectorList.size()];
for (int i = 0; i < vectorList.size(); i++) {
floatVector[i] = ((Number) vectorList.get(i)).floatValue();
}
return new PushedBlockLoaderExpression(
field,
new DenseVectorFieldMapper.VectorSimilarityFunctionConfig(getSimilarityFunction(), floatVector)
);
} else {
byte[] byteVector = new byte[vectorList.size()];
for (int i = 0; i < vectorList.size(); i++) {
byteVector[i] = ((Number) vectorList.get(i)).byteValue();
}
return new PushedBlockLoaderExpression(
field,
new DenseVectorFieldMapper.VectorSimilarityFunctionConfig(getSimilarityFunction(), byteVector)
);
}

return new PushedBlockLoaderExpression(
field,
new DenseVectorFieldMapper.VectorSimilarityFunctionConfig(getSimilarityFunction(), vectorArray)
);
}

interface VectorValueProvider extends Releasable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,11 @@ public String constantValue(FieldAttribute.FieldName name) {
return val;
}

@Override
public MappedFieldType fieldType(FieldName field) {
    // Fetch the stats entry for this field, building it through makeFieldStats
    // when the cache has no entry for the field's string name yet.
    var fieldStats = cache.computeIfAbsent(field.string(), this::makeFieldStats);
    return fieldStats.config.fieldType;
}

private interface DocCountTester {
Boolean test(LeafReader leafReader) throws IOException;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ default String constantValue(FieldName name) {
return null;
}

/**
 * Returns the mapped field type for the given field name, or null if the field is not found.
 * <p>
 * This default implementation always returns {@code null}; implementations that can
 * resolve mappings are expected to override it. Callers must handle a {@code null} result.
 *
 * @param name the field to look up
 * @return the mapped field type, or {@code null} when it is unavailable
 */
default MappedFieldType fieldType(FieldName name) {
return null;
}

/**
* When there are no search stats available, for example when there are no search contexts, we have static results.
*/
Expand Down