diff --git a/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/DenseVectorRollingUpgradeIT.java b/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/DenseVectorRollingUpgradeIT.java index 8fccb916396cb..c17c249054838 100644 --- a/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/DenseVectorRollingUpgradeIT.java +++ b/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/DenseVectorRollingUpgradeIT.java @@ -26,6 +26,7 @@ import java.util.Collection; import java.util.List; import java.util.OptionalInt; +import java.util.stream.Collectors; import java.util.stream.IntStream; import static org.elasticsearch.rest.action.search.RestSearchAction.TOTAL_HITS_AS_INT_PARAM; @@ -33,7 +34,7 @@ public class DenseVectorRollingUpgradeIT extends AbstractRollingUpgradeTestCase { - private static String generateBulkData(int upgradedNodes, int dimensions) { + private static String generateBulkData(ElementType elementType, int upgradedNodes, int dimensions) { StringBuilder sb = new StringBuilder(); int[] vals = new int[dimensions]; @@ -44,10 +45,16 @@ private static String generateBulkData(int upgradedNodes, int dimensions) { for (var it = docs.iterator(); it.hasNext();) { vals[upgradedNodes]++; + // ensure float values are added as floats + // so ES doesn't get confused in the mapping + String vector = switch (elementType) { + case BYTE, BIT -> Arrays.toString(vals); + case FLOAT, BFLOAT16 -> Arrays.stream(vals).mapToObj(i -> i + ".0").collect(Collectors.joining(",", "[", "]")); + }; sb.append("{\"index\": {\"_id\": \"").append(it.nextInt()).append("\"}}"); sb.append(System.lineSeparator()); - sb.append("{\"embedding\": ").append(Arrays.toString(vals)).append("}"); + sb.append("{\"embedding\": ").append(vector).append("}"); sb.append(System.lineSeparator()); } @@ -95,7 +102,7 @@ public void testDenseVectorMappingUpdateOnOldCluster() throws IOException { client().performRequest(createIndex); Request index = new Request("POST", "/" + indexName + "/_bulk/"); index.addParameter("refresh", "true"); - index.setJsonEntity(generateBulkData(upgradedNodes, 8)); + index.setJsonEntity(generateBulkData(ElementType.FLOAT, upgradedNodes, 8)); client().performRequest(index); } @@ -125,7 +132,7 @@ public void testDenseVectorMappingUpdateOnOldCluster() throws IOException { assertOK(client().performRequest(updateMapping)); Request index = new Request("POST", "/" + indexName + "/_bulk/"); index.addParameter("refresh", "true"); - index.setJsonEntity(generateBulkData(upgradedNodes, 8)); + index.setJsonEntity(generateBulkData(ElementType.FLOAT, upgradedNodes, 8)); assertOK(client().performRequest(index)); expectedCount = 20; assertCount(indexName, expectedCount); @@ -214,7 +221,7 @@ public void testDenseVectorIndexOverUpgrade() throws IOException { Request index = new Request("POST", "/" + indexName + "/_bulk/"); index.addParameter("refresh", "true"); - index.setJsonEntity(generateBulkData(upgradedNodes, dims.getAsInt())); + index.setJsonEntity(generateBulkData(elementType, upgradedNodes, dims.getAsInt())); assertOK(client().performRequest(index)); int count = existingCount + 10;