Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/117840.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 117840
summary: Fix timeout ingesting an empty string into a `semantic_text` field
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ public List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings)
*
* @param input Text to chunk
* @param maxNumberWordsPerChunk Maximum size of the chunk
* @return The input text chunked
* @param includePrecedingSentence Include the previous sentence
* @return The input text offsets
*/
public List<ChunkOffset> chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) {
var chunks = new ArrayList<ChunkOffset>();
Expand Down Expand Up @@ -158,6 +159,11 @@ public List<ChunkOffset> chunk(String input, int maxNumberWordsPerChunk, boolean
chunks.add(new ChunkOffset(chunkStart, input.length()));
}

if (chunks.isEmpty()) {
// The input did not chunk, return the entire input
chunks.add(new ChunkOffset(0, input.length()));
}

return chunks;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,6 @@ List<ChunkPosition> chunkPositions(String input, int chunkSize, int overlap) {
throw new IllegalArgumentException("Invalid chunking parameters, overlap [" + overlap + "] must be >= 0");
}

if (input.isEmpty()) {
return List.of();
}

var chunkPositions = new ArrayList<ChunkPosition>();

// This position in the chunk is where the next overlapping chunk will start
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
import org.hamcrest.Matchers;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -31,16 +32,62 @@

public class EmbeddingRequestChunkerTests extends ESTestCase {

public void testEmptyInput() {
public void testEmptyInput_WordChunker() {
    // With no inputs at all there is nothing to embed, so no batch may be created.
    var randomEmbeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var chunker = new EmbeddingRequestChunker(List.of(), 100, 100, 10, randomEmbeddingType);
    assertThat(chunker.batchRequestsWithListeners(testListener()), empty());
}

public void testBlankInput() {
public void testEmptyInput_SentenceChunker() {
    // Same expectation as the word chunker variant, but built with sentence boundary chunking settings.
    var randomEmbeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var sentenceSettings = new SentenceBoundaryChunkingSettings(250, 1);
    var chunker = new EmbeddingRequestChunker(List.of(), 10, randomEmbeddingType, sentenceSettings);
    assertThat(chunker.batchRequestsWithListeners(testListener()), empty());
}

public void testWhitespaceInput_SentenceChunker() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a test for whitespace input for the word chunker?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Covered by an existing test

var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
var batches = new EmbeddingRequestChunker(List.of(" "), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
.batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(" "));
}

public void testBlankInput_WordChunker() {
    // A single empty string must still yield exactly one batch carrying that empty string.
    var randomEmbeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var chunker = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, randomEmbeddingType);
    var batches = chunker.batchRequestsWithListeners(testListener());

    assertThat(batches, hasSize(1));
    var batchInputs = batches.get(0).batch().inputs();
    assertThat(batchInputs, hasSize(1));
    assertThat(batchInputs.get(0), Matchers.is(""));
}

public void testBlankInput_SentenceChunker() {
    // A single empty string must still yield exactly one batch carrying that empty string.
    var randomEmbeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var sentenceSettings = new SentenceBoundaryChunkingSettings(250, 1);
    var chunker = new EmbeddingRequestChunker(List.of(""), 10, randomEmbeddingType, sentenceSettings);
    var batches = chunker.batchRequestsWithListeners(testListener());

    assertThat(batches, hasSize(1));
    var batchInputs = batches.get(0).batch().inputs();
    assertThat(batchInputs, hasSize(1));
    assertThat(batchInputs.get(0), Matchers.is(""));
}

public void testInputThatDoesNotChunk_WordChunker() {
    // An input with no whitespace or punctuation must come back as one batch with the text unchanged.
    var randomEmbeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var chunker = new EmbeddingRequestChunker(List.of("ABBAABBA"), 100, 100, 10, randomEmbeddingType);
    var batches = chunker.batchRequestsWithListeners(testListener());

    assertThat(batches, hasSize(1));
    var batchInputs = batches.get(0).batch().inputs();
    assertThat(batchInputs, hasSize(1));
    assertThat(batchInputs.get(0), Matchers.is("ABBAABBA"));
}

public void testInputThatDoesNotChunk_SentenceChunker() {
    // An input with no whitespace or punctuation must come back as one batch with the text unchanged.
    var randomEmbeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
    var sentenceSettings = new SentenceBoundaryChunkingSettings(250, 1);
    var chunker = new EmbeddingRequestChunker(List.of("ABBAABBA"), 10, randomEmbeddingType, sentenceSettings);
    var batches = chunker.batchRequestsWithListeners(testListener());

    assertThat(batches, hasSize(1));
    var batchInputs = batches.get(0).batch().inputs();
    assertThat(batchInputs, hasSize(1));
    assertThat(batchInputs.get(0), Matchers.is("ABBAABBA"));
}

public void testShortInputsAreSingleBatch() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,41 @@ private List<String> textChunks(
return chunkPositions.stream().map(offset -> input.substring(offset.start(), offset.end())).collect(Collectors.toList());
}

public void testEmptyString() {
    // The empty string must be returned as exactly one (empty) chunk rather than no chunks.
    var emptyInputChunks = textChunks(new SentenceBoundaryChunker(), "", 100, randomBoolean());
    assertThat(emptyInputChunks, Matchers.contains(""));
}

public void testBlankString() {
    // A whitespace-only input must be returned unchanged as exactly one chunk.
    var blankInputChunks = textChunks(new SentenceBoundaryChunker(), " ", 100, randomBoolean());
    assertThat(blankInputChunks, Matchers.contains(" "));
}

public void testSingleChar() {
    // Each very short input (a lone letter or punctuation mark, with or without surrounding
    // whitespace) must be returned unchanged as a single chunk.
    for (String shortInput : new String[] { " b", "b", ". ", " , ", " ," }) {
        var result = textChunks(new SentenceBoundaryChunker(), shortInput, 100, randomBoolean());
        assertThat(result, Matchers.contains(shortInput));
    }
}

public void testSingleCharRepeated() {
    // 32,000 copies of one letter: the whole text is expected back as one chunk.
    final String longRun = "a".repeat(32_000);
    var result = textChunks(new SentenceBoundaryChunker(), longRun, 100, randomBoolean());
    assertThat(result, Matchers.contains(longRun));
}

public void testChunkSplitLargeChunkSizes() {
for (int maxWordsPerChunk : new int[] { 100, 200 }) {
var chunker = new SentenceBoundaryChunker();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.test.ESTestCase;
import org.hamcrest.Matchers;

import java.util.List;
import java.util.Locale;
Expand Down Expand Up @@ -71,10 +72,6 @@ public class WordBoundaryChunkerTests extends ESTestCase {
* Use the chunk functions that return offsets where possible
*/
List<String> textChunks(WordBoundaryChunker chunker, String input, int chunkSize, int overlap) {
// Empty input is answered here with a single empty chunk; the chunker is not called.
if (input.isEmpty()) {
return List.of("");
}

// Ask the chunker for positions, then cut the matching substring out of the input for each one.
var chunkPositions = chunker.chunk(input, chunkSize, overlap);
return chunkPositions.stream().map(p -> input.substring(p.start(), p.end())).collect(Collectors.toList());
}
Expand Down Expand Up @@ -240,6 +237,35 @@ public void testWhitespace() {
assertThat(chunks, contains(" "));
}

public void testBlankString() {
    // A whitespace-only input must be returned unchanged as exactly one chunk.
    var blankInputChunks = textChunks(new WordBoundaryChunker(), " ", 100, 10);
    assertThat(blankInputChunks, Matchers.contains(" "));
}

public void testSingleChar() {
    // Each very short input (a lone letter or punctuation mark, with or without surrounding
    // whitespace) must be returned unchanged as a single chunk.
    for (String shortInput : new String[] { " b", "b", ". ", " , ", " ," }) {
        var result = textChunks(new WordBoundaryChunker(), shortInput, 100, 10);
        assertThat(result, Matchers.contains(shortInput));
    }
}

public void testSingleCharRepeated() {
    // 32,000 copies of one letter: the whole text is expected back as one chunk.
    final String longRun = "a".repeat(32_000);
    var result = textChunks(new WordBoundaryChunker(), longRun, 100, 10);
    assertThat(result, Matchers.contains(longRun));
}

public void testPunctuation() {
int chunkSize = 1;
var chunks = textChunks(new WordBoundaryChunker(), "Comma, separated", chunkSize, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,22 @@ public void testDeploymentThreadsIncludedInUsage() throws IOException {
}
}

public void testInferEmptyInput() throws IOException {
    // Set up and start a deployment for a dedicated model id.
    String modelId = "empty_input";
    createPassThroughModel(modelId);
    putModelDefinition(modelId);
    putVocabulary(List.of("these", "are", "my", "words"), modelId);
    startDeployment(modelId);

    // Send an inference request whose document list is empty.
    Request inferRequest = new Request("POST", "/_ml/trained_models/" + modelId + "/_infer?timeout=30s");
    inferRequest.setJsonEntity("""
        { "docs": [] }
        """);

    // The call must return an empty result list.
    String responseBody = EntityUtils.toString(client().performRequest(inferRequest).getEntity());
    assertThat(responseBody, equalTo("{\"inference_results\":[]}"));
}

/**
 * Convenience overload: delegates to the three-argument {@code putModelDefinition},
 * passing the class constants {@code BASE_64_ENCODED_MODEL} and {@code RAW_MODEL_SIZE}.
 *
 * @param modelId id of the model the definition is stored for
 * @throws IOException propagated from the delegated call
 */
private void putModelDefinition(String modelId) throws IOException {
putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
Response.Builder responseBuilder = Response.builder();
TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId());

if (request.numberOfDocuments() == 0) {
listener.onResponse(responseBuilder.setId(request.getId()).build());
return;
}

if (MachineLearning.INFERENCE_AGG_FEATURE.check(licenseState)) {
responseBuilder.setLicensed(true);
doInfer(task, request, responseBuilder, parentTaskId, listener);
Expand Down