diff --git a/docs/changelog/139113.yaml b/docs/changelog/139113.yaml
new file mode 100644
index 0000000000000..b7cb437d8a503
--- /dev/null
+++ b/docs/changelog/139113.yaml
@@ -0,0 +1,5 @@
+pr: 139113
+summary: "[ES|QL]: Update Vector Similarity To Support BFLOAT16"
+area: "ES|QL"
+type: bug
+issues: []
diff --git a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java
index 9e52513629a94..02b5f6e774ae9 100644
--- a/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java
+++ b/x-pack/plugin/esql/qa/server/multi-clusters/src/javaRestTest/java/org/elasticsearch/xpack/esql/ccq/MultiClusterSpecIT.java
@@ -414,6 +414,19 @@ protected boolean supportsTDigestField() {
         }
     }
 
+    @Override
+    protected boolean supportsBFloat16ElementType() {
+        try {
+            return RestEsqlTestCase.hasCapabilities(client(), List.of(EsqlCapabilities.Cap.GENERIC_VECTOR_FORMAT.capabilityName()))
+                && RestEsqlTestCase.hasCapabilities(
+                    remoteClusterClient(),
+                    List.of(EsqlCapabilities.Cap.GENERIC_VECTOR_FORMAT.capabilityName())
+                );
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
     /**
      * Convert index patterns and subqueries in FROM commands to use remote indices for a given test case.
      */
diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java
index cb9477386dbb0..058b73ddd57de 100644
--- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java
+++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java
@@ -186,7 +186,8 @@ public void setup() {
                 supportsSemanticTextInference(),
                 false,
                 supportsExponentialHistograms(),
-                supportsTDigestField()
+                supportsTDigestField(),
+                supportsBFloat16ElementType()
             );
             return null;
         });
@@ -318,6 +319,10 @@ protected boolean supportsTDigestField() {
         return RestEsqlTestCase.hasCapabilities(client(), List.of(EsqlCapabilities.Cap.TDIGEST_FIELD_TYPE_SUPPORT_V3.capabilityName()));
     }
 
+    protected boolean supportsBFloat16ElementType() {
+        return RestEsqlTestCase.hasCapabilities(client(), List.of(EsqlCapabilities.Cap.GENERIC_VECTOR_FORMAT.capabilityName()));
+    }
+
     protected void doTest() throws Throwable {
         doTest(testCase.query);
     }
diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeRestTest.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeRestTest.java
index 88f303144803d..b3a8886a396d9 100644
--- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeRestTest.java
+++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/GenerativeRestTest.java
@@ -265,7 +265,7 @@ private static List originalTypes(Map x) {
     }
 
     private List availableIndices() throws IOException {
-        return availableDatasetsForEs(true, supportsSourceFieldMapping(), false, requiresTimeSeries(), false, false).stream()
+        return availableDatasetsForEs(true, supportsSourceFieldMapping(), false, requiresTimeSeries(), false, false, false).stream()
             .filter(x -> x.requiresInferenceEndpoint() == false)
             .map(x -> x.indexName())
             .toList();
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java
index dafcb1428a3fb..2233635a85f7c 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java
@@ -173,6 +173,7 @@ public class CsvTestsDataLoader {
     private static final TestDataset DENSE_VECTOR_TEXT = new TestDataset("dense_vector_text");
     private static final TestDataset MV_TEXT = new TestDataset("mv_text");
     private static final TestDataset DENSE_VECTOR = new TestDataset("dense_vector");
+    private static final TestDataset DENSE_VECTOR_BFLOAT16 = new TestDataset("dense_vector_bfloat16");
     private static final TestDataset COLORS = new TestDataset("colors");
     private static final TestDataset COLORS_CMYK_LOOKUP = new TestDataset("colors_cmyk").withSetting("lookup-settings.json");
     private static final TestDataset BASE_CONVERSION = new TestDataset("base_conversion");
@@ -248,6 +249,7 @@ public class CsvTestsDataLoader {
         Map.entry(DENSE_VECTOR_TEXT.indexName, DENSE_VECTOR_TEXT),
         Map.entry(MV_TEXT.indexName, MV_TEXT),
         Map.entry(DENSE_VECTOR.indexName, DENSE_VECTOR),
+        Map.entry(DENSE_VECTOR_BFLOAT16.indexName, DENSE_VECTOR_BFLOAT16),
         Map.entry(COLORS.indexName, COLORS),
         Map.entry(COLORS_CMYK_LOOKUP.indexName, COLORS_CMYK_LOOKUP),
         Map.entry(BASE_CONVERSION.indexName, BASE_CONVERSION),
@@ -348,7 +350,7 @@ public static void main(String[] args) throws IOException {
         }
 
         try (RestClient client = builder.build()) {
-            loadDataSetIntoEs(client, true, true, false, false, true, true, (restClient, indexName, indexMapping, indexSettings) -> {
+            loadDataSetIntoEs(client, true, true, false, false, true, true, true, (restClient, indexName, indexMapping, indexSettings) -> {
                 // don't use ESRestTestCase methods here or, if you do, test running the main method before making the change
                 StringBuilder jsonBody = new StringBuilder("{");
                 if (indexSettings != null && indexSettings.isEmpty() == false) {
@@ -373,7 +375,8 @@ public static Set availableDatasetsForEs(
         boolean inferenceEnabled,
         boolean requiresTimeSeries,
         boolean exponentialHistogramFieldSupported,
-        boolean tDigestFieldSupported
+        boolean tDigestFieldSupported,
+        boolean bFloat16ElementTypeSupported
     ) throws IOException {
         Set testDataSets = new HashSet<>();
 
@@ -383,7 +386,8 @@ public static Set availableDatasetsForEs(
                 && (supportsSourceFieldMapping || isSourceMappingDataset(dataset) == false)
                 && (requiresTimeSeries == false || isTimeSeries(dataset))
                 && (exponentialHistogramFieldSupported || containsExponentialHistogramFields(dataset) == false)
-                && (tDigestFieldSupported || containsTDigestFields(dataset) == false)) {
+                && (tDigestFieldSupported || containsTDigestFields(dataset) == false)
+                && (bFloat16ElementTypeSupported || containsBFloat16ElementType(dataset) == false)) {
                 testDataSets.add(dataset);
             }
         }
@@ -408,44 +412,33 @@ private static boolean isSourceMappingDataset(TestDataset dataset) throws IOExce
     }
 
     private static boolean containsExponentialHistogramFields(TestDataset dataset) throws IOException {
-        if (dataset.mappingFileName() == null) {
-            return false;
-        }
-        String mappingJsonText = readTextFile(getResource("/" + dataset.mappingFileName()));
-        JsonNode mappingNode = new ObjectMapper().readTree(mappingJsonText);
-        JsonNode properties = mappingNode.get("properties");
-        if (properties != null) {
-            for (var fieldWithValue : properties.properties()) {
-                JsonNode fieldProperties = fieldWithValue.getValue();
-                if (fieldProperties != null) {
-                    JsonNode typeNode = fieldProperties.get("type");
-                    if (typeNode != null && typeNode.asText().equals("exponential_histogram")) {
-                        return true;
-                    }
-                }
-            }
-        }
-        return false;
+        return containsFieldWithProperties(dataset, Map.of("type", "exponential_histogram"));
     }
 
     private static boolean containsTDigestFields(TestDataset dataset) throws IOException {
-        if (dataset.mappingFileName() == null) {
+        return containsFieldWithProperties(dataset, Map.of("type", "tdigest"));
+    }
+
+    private static boolean containsBFloat16ElementType(TestDataset dataset) throws IOException {
+        return containsFieldWithProperties(dataset, Map.of("element_type", "bfloat16"));
+    }
+
+    private static boolean containsFieldWithProperties(TestDataset dataset, Map properties) throws IOException {
+        if (dataset.mappingFileName() == null || properties.isEmpty()) {
             return false;
         }
+
         String mappingJsonText = readTextFile(getResource("/" + dataset.mappingFileName()));
-        JsonNode mappingNode = new ObjectMapper().readTree(mappingJsonText);
-        JsonNode properties = mappingNode.get("properties");
-        if (properties != null) {
-            for (var fieldWithValue : properties.properties()) {
-                JsonNode fieldProperties = fieldWithValue.getValue();
-                if (fieldProperties != null) {
-                    JsonNode typeNode = fieldProperties.get("type");
-                    if (typeNode != null && typeNode.asText().equals("tdigest")) {
-                        return true;
-                    }
+        Map mappingNode = new ObjectMapper().readValue(mappingJsonText, Map.class);
+        Object mappingProperties = mappingNode.get("properties");
+        if (mappingProperties instanceof Map mappingPropertiesMap) {
+            for (Object field : mappingPropertiesMap.values()) {
+                if (field instanceof Map fieldMap && fieldMap.entrySet().containsAll(properties.entrySet())) {
+                    return true;
                 }
             }
         }
+
         return false;
     }
@@ -461,7 +454,7 @@ public static void loadDataSetIntoEs(
         boolean supportsSourceFieldMapping,
         boolean inferenceEnabled
     ) throws IOException {
-        loadDataSetIntoEs(client, supportsIndexModeLookup, supportsSourceFieldMapping, inferenceEnabled, false, false, false);
+        loadDataSetIntoEs(client, supportsIndexModeLookup, supportsSourceFieldMapping, inferenceEnabled, false, false, false, false);
     }
 
     public static void loadDataSetIntoEs(
@@ -471,7 +464,8 @@ public static void loadDataSetIntoEs(
         boolean inferenceEnabled,
         boolean timeSeriesOnly,
         boolean exponentialHistogramFieldSupported,
-        boolean tDigestFieldSupported
+        boolean tDigestFieldSupported,
+        boolean bFloat16ElementTypeSupported
     ) throws IOException {
         loadDataSetIntoEs(
             client,
@@ -481,6 +475,7 @@ public static void loadDataSetIntoEs(
             timeSeriesOnly,
             exponentialHistogramFieldSupported,
             tDigestFieldSupported,
+            bFloat16ElementTypeSupported,
             (restClient, indexName, indexMapping, indexSettings) -> {
                 ESRestTestCase.createIndex(restClient, indexName, indexSettings, indexMapping, null);
             }
@@ -495,6 +490,7 @@ private static void loadDataSetIntoEs(
         boolean timeSeriesOnly,
         boolean exponentialHistogramFieldSupported,
         boolean tDigestFieldSupported,
+        boolean bFloat16ElementTypeSupported,
         IndexCreator indexCreator
     ) throws IOException {
         Logger logger = LogManager.getLogger(CsvTestsDataLoader.class);
@@ -507,7 +503,8 @@ private static void loadDataSetIntoEs(
             inferenceEnabled,
             timeSeriesOnly,
             exponentialHistogramFieldSupported,
-            tDigestFieldSupported
+            tDigestFieldSupported,
+            bFloat16ElementTypeSupported
         )) {
             load(client, dataset, logger, indexCreator);
             loadedDatasets.add(dataset.indexName);
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/dense_vector.csv b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/dense_vector.csv
index 93284fcda9ec5..8c5d5f8dad129 100644
--- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/dense_vector.csv
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/dense_vector.csv
@@ -1,5 +1,5 @@
 id:l, float_vector:dense_vector, byte_vector:dense_vector, bit_vector:dense_vector
-0, [1.0, 2.0, 3.0], [10, 20, 30], [13, 112]
-1, [4.0, 5.0, 6.0], [40, 50, 60], [45, 9]
-2, [9.0, 8.0, 7.0], [90, 80, 70], [127, 0]
-3, [0.054, 0.032, 0.012], [100, 110, 120], [88, 53]
\ No newline at end of file
+0, [1.0, 2.0, 3.0], [10, 20, 30], [13, 112]
+1, [4.0, 5.0, 6.0], [40, 50, 60], [45, 9]
+2, [9.0, 8.0, 7.0], [90, 80, 70], [127, 0]
+3, [0.054, 0.032, 0.012], [100, 110, 120], [88, 53]
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/dense_vector_bfloat16.csv b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/dense_vector_bfloat16.csv
new file mode 100644
index 0000000000000..a0318d2757a65
--- /dev/null
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/data/dense_vector_bfloat16.csv
@@ -0,0 +1,5 @@
+id:l, bfloat16_vector:dense_vector
+0, [1.0, 2.0, 3.0]
+1, [4.0, 5.0, 6.0]
+2, [9.0, 8.0, 7.0]
+3, [0.5390625, 0.3203125, 0.01202392578125]
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector-bfloat16.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector-bfloat16.csv-spec
new file mode 100644
index 0000000000000..de2957206801b
--- /dev/null
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/dense_vector-bfloat16.csv-spec
@@ -0,0 +1,54 @@
+retrieveDenseBFloat16VectorData
+required_capability: dense_vector_field_type_released
+required_capability: dense_vector_agg_metric_double_if_version
+required_capability: l2_norm_vector_similarity_function
+required_capability: generic_vector_format
+
+FROM dense_vector_bfloat16
+| KEEP id, bfloat16_vector
+| SORT id
+;
+
+id:l | bfloat16_vector:dense_vector
+0 | [1.0, 2.0, 3.0]
+1 | [4.0, 5.0, 6.0]
+2 | [9.0, 8.0, 7.0]
+3 | [0.5390625, 0.3203125, 0.01202392578125]
+;
+
+denseBFloat16VectorWithEval
+required_capability: dense_vector_agg_metric_double_if_version
+required_capability: l2_norm_vector_similarity_function
+required_capability: generic_vector_format
+
+FROM dense_vector_bfloat16
+| EVAL v = bfloat16_vector
+| KEEP id, v
+| SORT id
+;
+
+id:l | v:dense_vector
+0 | [1.0, 2.0, 3.0]
+1 | [4.0, 5.0, 6.0]
+2 | [9.0, 8.0, 7.0]
+3 | [0.5390625, 0.3203125, 0.01202392578125]
+;
+
+denseBFloat16VectorWithRenameAndDrop
+required_capability: dense_vector_agg_metric_double_if_version
+required_capability: l2_norm_vector_similarity_function
+required_capability: generic_vector_format
+
+FROM dense_vector_bfloat16
+| EVAL v = bfloat16_vector
+| RENAME v AS new_vector
+| DROP bfloat16_vector
+| SORT id
+;
+
+id:l | new_vector:dense_vector
+0 | [1.0, 2.0, 3.0]
+1 | [4.0, 5.0, 6.0]
+2 | [9.0, 8.0, 7.0]
+3 | [0.5390625, 0.3203125, 0.01202392578125]
+;
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-dense_vector-all_element_types.json b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-dense_vector-all_element_types.json
new file mode 100644
index 0000000000000..71d30b0de1b56
--- /dev/null
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-dense_vector-all_element_types.json
@@ -0,0 +1,48 @@
+{
+  "properties": {
+    "id": {
+      "type": "long"
+    },
+    "float_vector": {
+      "type": "dense_vector",
+      "similarity": "l2_norm",
+      "index_options": {
+        "type": "hnsw",
+        "m": 16,
+        "ef_construction": 100
+      }
+    },
+    "byte_vector": {
+      "type": "dense_vector",
+      "similarity": "l2_norm",
+      "element_type": "byte",
+      "index_options": {
+        "type": "hnsw",
+        "m": 16,
+        "ef_construction": 100
+      }
+    },
+    "bit_vector": {
+      "type": "dense_vector",
+      "dims": 16,
+      "similarity": "l2_norm",
+      "element_type": "bit",
+      "index_options": {
+        "type": "hnsw",
+        "m": 16,
+        "ef_construction": 100
+      }
+    },
+    "bfloat16_vector": {
+      "type": "dense_vector",
+      "dims": 16,
+      "similarity": "l2_norm",
+      "element_type": "bfloat16",
+      "index_options": {
+        "type": "hnsw",
+        "m": 16,
+        "ef_construction": 100
+      }
+    }
+  }
+}
diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-dense_vector_bfloat16.json b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-dense_vector_bfloat16.json
new file mode 100644
index 0000000000000..2b589d560640a
--- /dev/null
+++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-dense_vector_bfloat16.json
@@ -0,0 +1,17 @@
+{
+  "properties": {
+    "id": {
+      "type": "long"
+    },
+    "bfloat16_vector": {
+      "type": "dense_vector",
+      "similarity": "l2_norm",
+      "element_type": "bfloat16",
+      "index_options": {
+        "type": "hnsw",
+        "m": 16,
+        "ef_construction": 100
+      }
+    }
+  }
+}
diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java
index e102ad8b1b8c8..cfdfcbd4a4af1 100644
--- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java
+++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java
@@ -35,7 +35,6 @@
 import java.util.Collection;
 import java.util.List;
 import java.util.Locale;
-import java.util.Set;
 
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
 import static org.hamcrest.CoreMatchers.containsString;
@@ -49,12 +48,12 @@ public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
     public static Iterable parameters() throws Exception {
         List params = new ArrayList<>();
 
-        for (ElementType elementType : Set.of(ElementType.FLOAT, ElementType.BYTE, ElementType.BIT)) {
+        for (ElementType elementType : ElementType.values()) {
             params.add(new Object[] { "v_cosine", CosineSimilarity.SIMILARITY_FUNCTION, elementType });
             params.add(new Object[] { "v_dot_product", DotProduct.SIMILARITY_FUNCTION, elementType });
             params.add(new Object[] { "v_l1_norm", L1Norm.SIMILARITY_FUNCTION, elementType });
             params.add(new Object[] { "v_l2_norm", L2Norm.SIMILARITY_FUNCTION, elementType });
-            if (elementType != ElementType.FLOAT) {
+            if (elementType != ElementType.FLOAT && elementType != ElementType.BFLOAT16) {
                 params.add(new Object[] { "v_hamming", Hamming.EVALUATOR_SIMILARITY_FUNCTION, elementType });
             }
         }
@@ -236,7 +235,7 @@ private Double calculateSimilarity(
             case BYTE, BIT -> {
                 return (double) similarityFunction.calculateSimilarity(asByteArray(randomVector), asByteArray(vector));
             }
-            case FLOAT -> {
+            case FLOAT, BFLOAT16 -> {
                 return (double) similarityFunction.calculateSimilarity(asFloatArray(randomVector), asFloatArray(vector));
             }
             default -> throw new IllegalArgumentException("Unexpected element type: " + elementType);
@@ -335,7 +334,7 @@ private List randomVector(int numDims, boolean allowNull) {
         List vector = new ArrayList<>(dimensions);
         for (int j = 0; j < dimensions; j++) {
             switch (elementType) {
-                case FLOAT -> {
+                case FLOAT, BFLOAT16 -> {
                     if (dimensions == 1) {
                         vector.add(randomValueOtherThan(0f, () -> randomFloat()));
                     } else {
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
index d2e6f9bccdabe..06e39c81e7260 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
@@ -1418,6 +1418,11 @@ public enum Cap {
          */
         DENSE_VECTOR_FIELD_TYPE_BIT_ELEMENTS,
 
+        /**
+         * Support directIO rescoring and `bfloat16` for `bbq_hnsw` and `bbq_disk`, and `bfloat16` for `hnsw` and `bbq_flat` index types.
+         */
+        GENERIC_VECTOR_FORMAT,
+
         /**
          * Support null elements on vector similarity functions
          */
diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java
index 42a7cddff68fe..de798cd5502c8 100644
--- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java
+++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java
@@ -233,7 +233,9 @@ public final PushedBlockLoaderExpression tryPushToFieldLoading(SearchStats stats
         if (fieldType instanceof DenseVectorFieldMapper.DenseVectorFieldType) {
             elementType = ((DenseVectorFieldMapper.DenseVectorFieldType) fieldType).getElementType();
         }
-        if (elementType == null || elementType == DenseVectorFieldMapper.ElementType.FLOAT) {
+        if (elementType == null
+            || elementType == DenseVectorFieldMapper.ElementType.FLOAT
+            || elementType == DenseVectorFieldMapper.ElementType.BFLOAT16) {
             float[] floatVector = new float[vectorList.size()];
             for (int i = 0; i < vectorList.size(); i++) {
                 floatVector[i] = ((Number) vectorList.get(i)).floatValue();
diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
index 13ff3e7f18ff1..bd43ff784438b 100644
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
+++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
@@ -195,6 +195,8 @@ public class AnalyzerTests extends ESTestCase {
         Settings.EMPTY
     );
 
+    private static final String DENSE_VECTOR_MAPPING_FILE = "mapping-dense_vector-all_element_types.json";
+
     public void testIndexResolution() {
         EsIndex idx = EsIndexGenerator.esIndex("idx");
         Analyzer analyzer = analyzer(IndexResolution.valid(idx));
@@ -2393,12 +2395,14 @@ public void testDenseVectorImplicitCastingKnn() {
         checkDenseVectorCastingKnn("bit_vector");
         checkDenseVectorCastingHexKnn("bit_vector");
         checkDenseVectorEvalCastingKnn("bit_vector");
+        checkDenseVectorEvalCastingKnn("bfloat16_vector");
+        checkDenseVectorCastingHexKnn("bfloat16_vector");
     }
 
     private static void checkDenseVectorCastingKnn(String fieldName) {
         var plan = analyze(String.format(Locale.ROOT, """
             from test | where knn(%s, [0, 1, 2])
-            """, fieldName), "mapping-dense_vector.json");
+            """, fieldName), DENSE_VECTOR_MAPPING_FILE);
 
         var limit = as(plan, Limit.class);
         var filter = as(limit.child(), Filter.class);
@@ -2411,7 +2415,7 @@ private static void checkDenseVectorCastingKnn(String fieldName) {
     private static void checkDenseVectorCastingHexKnn(String fieldName) {
         var plan = analyze(String.format(Locale.ROOT, """
             from test | where knn(%s, "000102")
-            """, fieldName), "mapping-dense_vector.json");
+            """, fieldName), DENSE_VECTOR_MAPPING_FILE);
 
         var limit = as(plan, Limit.class);
         var filter = as(limit.child(), Filter.class);
@@ -2424,7 +2428,7 @@ private static void checkDenseVectorCastingHexKnn(String fieldName) {
     private static void checkDenseVectorEvalCastingKnn(String fieldName) {
         var plan = analyze(String.format(Locale.ROOT, """
             from test | eval query = to_dense_vector([0, 1, 2]) | where knn(%s, query)
-            """, fieldName), "mapping-dense_vector.json");
+            """, fieldName), DENSE_VECTOR_MAPPING_FILE);
 
         var limit = as(plan, Limit.class);
         var filter = as(limit.child(), Filter.class);
@@ -2438,12 +2442,13 @@ public void testDenseVectorImplicitCastingKnnQueryParams() {
         checkDenseVectorCastingKnnQueryParams("float_vector");
         checkDenseVectorCastingKnnQueryParams("byte_vector");
         checkDenseVectorCastingKnnQueryParams("bit_vector");
+        checkDenseVectorCastingKnnQueryParams("bfloat16_vector");
     }
 
     private void checkDenseVectorCastingKnnQueryParams(String fieldName) {
         var plan = analyze(String.format(Locale.ROOT, """
             from test | where knn(%s, ?query_vector)
-            """, fieldName), "mapping-dense_vector.json", new QueryParams(List.of(paramAsConstant("query_vector", List.of(0, 1, 2)))));
+            """, fieldName), DENSE_VECTOR_MAPPING_FILE, new QueryParams(List.of(paramAsConstant("query_vector", List.of(0, 1, 2)))));
 
         var limit = as(plan, Limit.class);
         var filter = as(limit.child(), Filter.class);
@@ -2456,17 +2461,21 @@ public void testDenseVectorImplicitCastingSimilarityFunctions() {
         checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(float_vector, [0.342, 0.164, 0.234])", List.of(0.342, 0.164, 0.234));
         checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(byte_vector, [1, 2, 3])", List.of(1, 2, 3));
+        checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(bfloat16_vector, [1, 2, 3])", List.of(1, 2, 3));
         checkDenseVectorImplicitCastingSimilarityFunction(
             "v_dot_product(float_vector, [0.342, 0.164, 0.234])",
             List.of(0.342, 0.164, 0.234)
         );
         checkDenseVectorImplicitCastingSimilarityFunction("v_dot_product(byte_vector, [1, 2, 3])", List.of(1, 2, 3));
+        checkDenseVectorImplicitCastingSimilarityFunction("v_dot_product(bfloat16_vector, [1, 2, 3])", List.of(1, 2, 3));
         checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(float_vector, [0.342, 0.164, 0.234])", List.of(0.342, 0.164, 0.234));
         checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(byte_vector, [1, 2, 3])", List.of(1, 2, 3));
+        checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(bfloat16_vector, [1, 2, 3])", List.of(1, 2, 3));
         checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(float_vector, [0.342, 0.164, 0.234])", List.of(0.342, 0.164, 0.234));
         checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(float_vector, [1, 2, 3])", List.of(1, 2, 3));
         checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(byte_vector, [1, 2, 3])", List.of(1, 2, 3));
         checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(bit_vector, [1, 2])", List.of(1, 2));
+        checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(bfloat16_vector, [1, 2, 3])", List.of(1, 2, 3));
         checkDenseVectorImplicitCastingSimilarityFunction("v_hamming(byte_vector, [0.342, 0.164, 0.234])", List.of(0.342, 0.164, 0.234));
         checkDenseVectorImplicitCastingSimilarityFunction("v_hamming(byte_vector, [1, 2, 3])", List.of(1, 2, 3));
         checkDenseVectorImplicitCastingSimilarityFunction("v_hamming(bit_vector, [1, 2])", List.of(1, 2));
@@ -2475,7 +2484,7 @@ public void testDenseVectorImplicitCastingSimilarityFunctions() {
     private void checkDenseVectorImplicitCastingSimilarityFunction(String similarityFunction, List expectedElems) {
         var plan = analyze(String.format(Locale.ROOT, """
             from test | eval similarity = %s
-            """, similarityFunction), "mapping-dense_vector.json");
+            """, similarityFunction), DENSE_VECTOR_MAPPING_FILE);
 
         var limit = as(plan, Limit.class);
         var eval = as(limit.child(), Eval.class);
@@ -2483,7 +2492,7 @@ private void checkDenseVectorImplicitCastingSimilarityFunction(String similarity
         assertEquals("similarity", alias.name());
         var similarity = as(alias.child(), VectorSimilarityFunction.class);
         var left = as(similarity.left(), FieldAttribute.class);
-        assertThat(List.of("float_vector", "byte_vector", "bit_vector"), hasItem(left.name()));
+        assertThat(List.of("float_vector", "byte_vector", "bit_vector", "bfloat16_vector"), hasItem(left.name()));
         var right = as(similarity.right(), ToDenseVector.class);
         var literal = as(right.field(), Literal.class);
         assertThat(literal.value(), equalTo(expectedElems));
@@ -2492,20 +2501,25 @@ private void checkDenseVectorImplicitCastingSimilarityFunction(String similarity
     public void testDenseVectorEvalCastingSimilarityFunctions() {
         checkDenseVectorEvalCastingSimilarityFunction("v_cosine(float_vector, query)");
         checkDenseVectorEvalCastingSimilarityFunction("v_cosine(byte_vector, query)");
+        checkDenseVectorEvalCastingSimilarityFunction("v_cosine(bfloat16_vector, query)");
         checkDenseVectorEvalCastingSimilarityFunction("v_dot_product(float_vector, query)");
         checkDenseVectorEvalCastingSimilarityFunction("v_dot_product(byte_vector, query)");
+        checkDenseVectorEvalCastingSimilarityFunction("v_dot_product(bfloat16_vector, query)");
         checkDenseVectorEvalCastingSimilarityFunction("v_l1_norm(float_vector, query)");
         checkDenseVectorEvalCastingSimilarityFunction("v_l1_norm(byte_vector, query)");
+        checkDenseVectorEvalCastingSimilarityFunction("v_l1_norm(bfloat16_vector, query)");
         checkDenseVectorEvalCastingSimilarityFunction("v_l2_norm(float_vector, query)");
-        checkDenseVectorEvalCastingSimilarityFunction("v_l2_norm(float_vector, query)");
-        checkDenseVectorEvalCastingSimilarityFunction("v_hamming(byte_vector, query)");
+        checkDenseVectorEvalCastingSimilarityFunction("v_l2_norm(byte_vector, query)");
+        checkDenseVectorEvalCastingSimilarityFunction("v_l2_norm(bit_vector, query)");
+        checkDenseVectorEvalCastingSimilarityFunction("v_l2_norm(bfloat16_vector, query)");
         checkDenseVectorEvalCastingSimilarityFunction("v_hamming(byte_vector, query)");
+        checkDenseVectorEvalCastingSimilarityFunction("v_hamming(bit_vector, query)");
     }
 
     private void checkDenseVectorEvalCastingSimilarityFunction(String similarityFunction) {
         var plan = analyze(String.format(Locale.ROOT, """
             from test | eval query = to_dense_vector([0.342, 0.164, 0.234]) | eval similarity = %s
-            """, similarityFunction), "mapping-dense_vector.json");
+            """, similarityFunction), DENSE_VECTOR_MAPPING_FILE);
 
         var limit = as(plan, Limit.class);
         var eval = as(limit.child(), Eval.class);
@@ -2513,7 +2527,7 @@ private void checkDenseVectorEvalCastingSimilarityFunc
         assertEquals("similarity", alias.name());
         var similarity = as(alias.child(), VectorSimilarityFunction.class);
         var left = as(similarity.left(), FieldAttribute.class);
-        assertThat(List.of("float_vector", "byte_vector"), hasItem(left.name()));
+        assertThat(List.of("float_vector", "byte_vector", "bit_vector", "bfloat16_vector"), hasItem(left.name()));
         var right = as(similarity.right(), ReferenceAttribute.class);
         assertThat(right.dataType(), is(DENSE_VECTOR));
         assertThat(right.name(), is("query"));
@@ -2529,7 +2543,7 @@ public void testVectorFunctionHexImplicitCastingError() {
 
     private void checkVectorFunctionHexImplicitCastingError(String clause) {
         var query = "from test | " + clause;
-        VerificationException error = expectThrows(VerificationException.class, () -> analyze(query, "mapping-dense_vector.json"));
+        VerificationException error = expectThrows(VerificationException.class, () -> analyze(query, DENSE_VECTOR_MAPPING_FILE));
         assertThat(
             error.getMessage(),
             containsString(
@@ -2544,7 +2558,7 @@ public void testMagnitudePlanWithDenseVectorImplicitCasting() {
         var plan = analyze(String.format(Locale.ROOT, """
            from test | eval scalar = v_magnitude([1, 2, 3])
-            """), "mapping-dense_vector.json");
+            """), DENSE_VECTOR_MAPPING_FILE);
 
         var limit = as(plan, Limit.class);
         var eval = as(limit.child(), Eval.class);
@@ -3896,7 +3910,7 @@ public void testKnnFunctionWithTextEmbedding() {
         LogicalPlan plan = analyze(
             String.format(Locale.ROOT, """
                 from test | where KNN(float_vector, TEXT_EMBEDDING("italian food recipe", "%s"))""", TEXT_EMBEDDING_INFERENCE_ID),
-            "mapping-dense_vector.json"
+            DENSE_VECTOR_MAPPING_FILE
         );
 
         Limit limit = as(plan, Limit.class);