docs/changelog/117840.yaml (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
pr: 117840
summary: Fix timeout ingesting an empty string into a `semantic_text` field
area: Machine Learning
type: bug
issues: []
@@ -66,6 +66,10 @@ public List<String> chunk(String input, ChunkingSettings chunkingSettings) {
* @return The input text chunked
*/
public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) {
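// An empty input has no sentence boundaries to iterate over; return it as a
// single empty chunk so the caller still receives a result instead of timing out.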
if (input.isEmpty()) {
return List.of("");
}
Contributor (@Mikep86):
Do we need to trim the input before checking if it's empty? How does the chunker handle input that is only whitespace?

Member Author:
Good catch @Mikep86, thanks.

There's a whole class of bugs for "things that don't chunk": pure whitespace, a single character, or the same character repeated thousands of times won't chunk. I've added test cases for these situations, and the fix I've implemented here is to return the original input if it did not chunk. This applies to both the Word and Sentence chunkers.

This makes me wonder if there should be an upper limit on the chunk size in terms of number of characters. A badly formed input containing Latin characters but no whitespace would result in a single large chunk; think of a base64-encoded binary file.
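A minimal sketch of the kind of character cap floated above; this is purely illustrative, not part of this PR, and MAX_CHUNK_CHARS and splitByCharBudget are assumed names rather than existing Elasticsearch code:

import java.util.ArrayList;
import java.util.List;

// Hypothetical post-processing guard: split any oversized chunk into
// fixed-size character windows so a single unbreakable input (e.g. a
// base64-encoded blob) cannot produce an arbitrarily large chunk.
class CharBudget {
    private static final int MAX_CHUNK_CHARS = 10_000; // assumed budget

    static List<String> splitByCharBudget(String chunk) {
        if (chunk.length() <= MAX_CHUNK_CHARS) {
            return List.of(chunk);
        }
        var parts = new ArrayList<String>();
        for (int start = 0; start < chunk.length(); start += MAX_CHUNK_CHARS) {
            parts.add(chunk.substring(start, Math.min(start + MAX_CHUNK_CHARS, chunk.length())));
        }
        return parts;
    }
}

Each chunker could run its output through such a guard before returning, at the cost of occasionally splitting mid-word.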


var chunks = new ArrayList<String>();

sentenceIterator.setText(input);
@@ -154,6 +158,11 @@ public List<String> chunk(String input, int maxNumberWordsPerChunk, boolean incl
chunks.add(input.substring(chunkStart));
}

if (chunks.isEmpty()) {
// The input did not chunk, return the entire input
chunks.add(input);
}

return chunks;
}

@@ -18,6 +18,7 @@
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
import org.hamcrest.Matchers;

import java.util.ArrayList;
import java.util.List;
@@ -31,16 +32,62 @@

public class EmbeddingRequestChunkerTests extends ESTestCase {

- public void testEmptyInput() {
+ public void testEmptyInput_WordChunker() {
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
assertThat(batches, empty());
}

- public void testBlankInput() {
+ public void testEmptyInput_SentenceChunker() {
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
var batches = new EmbeddingRequestChunker(List.of(), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
.batchRequestsWithListeners(testListener());
assertThat(batches, empty());
}

public void testWhitespaceInput_SentenceChunker() {
Contributor:
Do we need a test for whitespace input for the word chunker?

Member Author:
Covered by an existing test (testWhitespace in the word chunker tests).

var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
var batches = new EmbeddingRequestChunker(List.of(" "), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
.batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(" "));
}

public void testBlankInput_WordChunker() {
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
}

public void testBlankInput_SentenceChunker() {
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
var batches = new EmbeddingRequestChunker(List.of(""), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
.batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
}

public void testInputThatDoesNotChunk_WordChunker() {
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 100, 100, 10, embeddingType).batchRequestsWithListeners(
testListener()
);
assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
}

public void testInputThatDoesNotChunk_SentenceChunker() {
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
.batchRequestsWithListeners(testListener());
assertThat(batches, hasSize(1));
assertThat(batches.get(0).batch().inputs(), hasSize(1));
assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
}

public void testShortInputsAreSingleBatch() {
@@ -27,6 +27,41 @@

public class SentenceBoundaryChunkerTests extends ESTestCase {

public void testEmptyString() {
var chunks = new SentenceBoundaryChunker().chunk("", 100, randomBoolean());
assertThat(chunks, hasSize(1));
assertThat(chunks.get(0), Matchers.is(""));
}

public void testBlankString() {
var chunks = new SentenceBoundaryChunker().chunk(" ", 100, randomBoolean());
assertThat(chunks, hasSize(1));
assertThat(chunks.get(0), Matchers.is(" "));
}

public void testSingleChar() {
var chunks = new SentenceBoundaryChunker().chunk(" b", 100, randomBoolean());
assertThat(chunks, Matchers.contains(" b"));

chunks = new SentenceBoundaryChunker().chunk("b", 100, randomBoolean());
assertThat(chunks, Matchers.contains("b"));

chunks = new SentenceBoundaryChunker().chunk(". ", 100, randomBoolean());
assertThat(chunks, Matchers.contains(". "));

chunks = new SentenceBoundaryChunker().chunk(" , ", 100, randomBoolean());
assertThat(chunks, Matchers.contains(" , "));

chunks = new SentenceBoundaryChunker().chunk(" ,", 100, randomBoolean());
assertThat(chunks, Matchers.contains(" ,"));
}

public void testSingleCharRepeated() {
var input = "a".repeat(32_000);
var chunks = new SentenceBoundaryChunker().chunk(input, 100, randomBoolean());
assertThat(chunks, Matchers.contains(input));
}

public void testChunkSplitLargeChunkSizes() {
for (int maxWordsPerChunk : new int[] { 100, 200 }) {
var chunker = new SentenceBoundaryChunker();
@@ -11,6 +11,7 @@

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.test.ESTestCase;
import org.hamcrest.Matchers;

import java.util.List;
import java.util.Locale;
@@ -226,6 +227,35 @@ public void testWhitespace() {
assertThat(chunks, contains(" "));
}

public void testBlankString() {
var chunks = new WordBoundaryChunker().chunk(" ", 100, 10);
assertThat(chunks, hasSize(1));
assertThat(chunks.get(0), Matchers.is(" "));
}

public void testSingleChar() {
var chunks = new WordBoundaryChunker().chunk(" b", 100, 10);
assertThat(chunks, Matchers.contains(" b"));

chunks = new WordBoundaryChunker().chunk("b", 100, 10);
assertThat(chunks, Matchers.contains("b"));

chunks = new WordBoundaryChunker().chunk(". ", 100, 10);
assertThat(chunks, Matchers.contains(". "));

chunks = new WordBoundaryChunker().chunk(" , ", 100, 10);
assertThat(chunks, Matchers.contains(" , "));

chunks = new WordBoundaryChunker().chunk(" ,", 100, 10);
assertThat(chunks, Matchers.contains(" ,"));
}

public void testSingleCharRepeated() {
var input = "a".repeat(32_000);
var chunks = new WordBoundaryChunker().chunk(input, 100, 10);
assertThat(chunks, Matchers.contains(input));
}

public void testPunctuation() {
int chunkSize = 1;
var chunks = new WordBoundaryChunker().chunk("Comma, separated", chunkSize, 0);
@@ -1142,6 +1142,22 @@ public void testDeploymentThreadsIncludedInUsage() throws IOException {
}
}

public void testInferEmptyInput() throws IOException {
String modelId = "empty_input";
createPassThroughModel(modelId);
putModelDefinition(modelId);
putVocabulary(List.of("these", "are", "my", "words"), modelId);
startDeployment(modelId);

Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/_infer?timeout=30s");
request.setJsonEntity("""
{ "docs": [] }
""");

var inferenceResponse = client().performRequest(request);
assertThat(EntityUtils.toString(inferenceResponse.getEntity()), equalTo("{\"inference_results\":[]}"));
}

private void putModelDefinition(String modelId) throws IOException {
putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
}
@@ -132,6 +132,11 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
Response.Builder responseBuilder = Response.builder();
TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId());

if (request.numberOfDocuments() == 0) {
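// No documents were supplied; respond immediately with an empty (but well-formed)
// response instead of dispatching an inference request that would hang until the timeout.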
listener.onResponse(responseBuilder.setId(request.getId()).build());
return;
}

if (MachineLearning.INFERENCE_AGG_FEATURE.check(licenseState)) {
responseBuilder.setLicensed(true);
doInfer(task, request, responseBuilder, parentTaskId, listener);