Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@
import java.util.Collection;
import java.util.List;
import java.util.OptionalInt;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.elasticsearch.rest.action.search.RestSearchAction.TOTAL_HITS_AS_INT_PARAM;
import static org.hamcrest.Matchers.equalTo;

public class DenseVectorRollingUpgradeIT extends AbstractRollingUpgradeTestCase {

private static String generateBulkData(int upgradedNodes, int dimensions) {
private static String generateBulkData(ElementType elementType, int upgradedNodes, int dimensions) {
StringBuilder sb = new StringBuilder();

int[] vals = new int[dimensions];
Expand All @@ -44,10 +45,16 @@ private static String generateBulkData(int upgradedNodes, int dimensions) {

for (var it = docs.iterator(); it.hasNext();) {
vals[upgradedNodes]++;
// ensure float values are added as floats
// so ES doesn't get confused in the mapping
String vector = switch (elementType) {
case BYTE, BIT -> Arrays.toString(vals);
case FLOAT, BFLOAT16 -> Arrays.stream(vals).mapToObj(i -> i + ".0").collect(Collectors.joining(",", "[", "]"));
};

sb.append("{\"index\": {\"_id\": \"").append(it.nextInt()).append("\"}}");
sb.append(System.lineSeparator());
sb.append("{\"embedding\": ").append(Arrays.toString(vals)).append("}");
sb.append("{\"embedding\": ").append(vector).append("}");
sb.append(System.lineSeparator());
}

Expand Down Expand Up @@ -95,7 +102,7 @@ public void testDenseVectorMappingUpdateOnOldCluster() throws IOException {
client().performRequest(createIndex);
Request index = new Request("POST", "/" + indexName + "/_bulk/");
index.addParameter("refresh", "true");
index.setJsonEntity(generateBulkData(upgradedNodes, 8));
index.setJsonEntity(generateBulkData(ElementType.FLOAT, upgradedNodes, 8));
client().performRequest(index);
}

Expand Down Expand Up @@ -125,7 +132,7 @@ public void testDenseVectorMappingUpdateOnOldCluster() throws IOException {
assertOK(client().performRequest(updateMapping));
Request index = new Request("POST", "/" + indexName + "/_bulk/");
index.addParameter("refresh", "true");
index.setJsonEntity(generateBulkData(upgradedNodes, 8));
index.setJsonEntity(generateBulkData(ElementType.FLOAT, upgradedNodes, 8));
assertOK(client().performRequest(index));
expectedCount = 20;
assertCount(indexName, expectedCount);
Expand Down Expand Up @@ -214,7 +221,7 @@ public void testDenseVectorIndexOverUpgrade() throws IOException {

Request index = new Request("POST", "/" + indexName + "/_bulk/");
index.addParameter("refresh", "true");
index.setJsonEntity(generateBulkData(upgradedNodes, dims.getAsInt()));
index.setJsonEntity(generateBulkData(elementType, upgradedNodes, dims.getAsInt()));
assertOK(client().performRequest(index));

int count = existingCount + 10;
Expand Down
Loading