Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.Locale;
import java.util.Map;

import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilterTestUtil.randomInputCasesForSemanticText;
import static org.hamcrest.Matchers.equalTo;

public class ShardBulkInferenceActionFilterIT extends ESIntegTestCase {
Expand Down Expand Up @@ -93,8 +94,8 @@ public void testBulkOperations() throws Exception {
String id = Long.toString(totalDocs);
boolean isIndexRequest = randomBoolean();
Map<String, Object> source = new HashMap<>();
source.put("sparse_field", isIndexRequest && rarely() ? null : randomAlphaOfLengthBetween(0, 1000));
source.put("dense_field", isIndexRequest && rarely() ? null : randomAlphaOfLengthBetween(0, 1000));
source.put("sparse_field", isIndexRequest && rarely() ? null : randomInputCasesForSemanticText());
source.put("dense_field", isIndexRequest && rarely() ? null : randomInputCasesForSemanticText());
if (isIndexRequest) {
bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(id).setSource(source));
totalDocs++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,12 +492,16 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
* If {@code valueObj} is not a string or a collection of strings, it throws an ElasticsearchStatusException.
*/
private static List<String> nodeStringValues(String field, Object valueObj) {
if (valueObj instanceof String value) {
if (valueObj instanceof Number || valueObj instanceof Boolean) {
return List.of(valueObj.toString());
} else if (valueObj instanceof String value) {
return List.of(value);
} else if (valueObj instanceof Collection<?> values) {
List<String> valuesString = new ArrayList<>();
for (var v : values) {
if (v instanceof String value) {
if (v instanceof Number || v instanceof Boolean) {
valuesString.add(v.toString());
} else if (v instanceof String value) {
valuesString.add(value);
} else {
throw new ElasticsearchStatusException(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.action.filter;

import static org.elasticsearch.test.ESTestCase.randomAlphaOfLengthBetween;
import static org.elasticsearch.test.ESTestCase.randomBoolean;
import static org.elasticsearch.test.ESTestCase.randomDouble;
import static org.elasticsearch.test.ESTestCase.randomFloat;
import static org.elasticsearch.test.ESTestCase.randomInt;
import static org.elasticsearch.test.ESTestCase.randomIntBetween;
import static org.elasticsearch.test.ESTestCase.randomLong;

public class ShardBulkInferenceActionFilterTestUtil {

    /**
     * Returns a randomly generated input value of one of the types accepted by semantic text
     * fields: a string, int, long, float, boolean, or double.
     */
    public static Object randomInputCasesForSemanticText() {
        // Upper bound must be 5 so the default (double) branch is reachable: ESTestCase's
        // randomIntBetween is inclusive on both ends, and with a bound of 4 every value was
        // consumed by the explicit cases, leaving randomDouble() as dead code.
        return switch (randomIntBetween(0, 5)) {
            case 0 -> randomAlphaOfLengthBetween(10, 20);
            case 1 -> randomInt();
            case 2 -> randomLong();
            case 3 -> randomFloat();
            case 4 -> randomBoolean();
            default -> randomDouble();
        };
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.model.TestModel;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.junit.After;
Expand All @@ -55,8 +56,10 @@
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilterTestUtil.randomInputCasesForSemanticText;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSparseEmbeddings;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.semanticTextFieldFromChunkedInferenceResults;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.toChunkedResult;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -331,16 +334,31 @@ private static BulkItemRequest[] randomBulkItemRequest(
for (var entry : fieldInferenceMap.values()) {
String field = entry.getName();
var model = modelMap.get(entry.getInferenceId());
String text = randomAlphaOfLengthBetween(10, 20);
docMap.put(field, text);
expectedDocMap.put(field, text);
Object inputObject = randomInputCasesForSemanticText();
String inputText = inputObject.toString();
docMap.put(field, inputObject);
expectedDocMap.put(field, inputText);
if (model == null) {
// ignore results, the doc should fail with a resource not found exception
continue;
}
var result = randomSemanticText(field, model, List.of(text), requestContentType);
model.putResult(text, toChunkedResult(result));
expectedDocMap.put(field, result);

SemanticTextField semanticTextField;
if (model.hasResult(inputText)) {
ChunkedInferenceServiceResults results = model.getResults(inputText);
semanticTextField = semanticTextFieldFromChunkedInferenceResults(
field,
model,
List.of(inputText),
results,
requestContentType
);
} else {
semanticTextField = randomSemanticText(field, model, List.of(inputText), requestContentType);
model.putResult(inputText, toChunkedResult(semanticTextField));
}
Comment on lines +350 to +362
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@carlosdelest We had to make this change because the inference result cache in model is not field-aware. Now that our input can be of many data types (including Boolean, with only two values), we are nearly guaranteed to hit value collisions across 100+ bulk requests. This caused test failures with the previous logic because different random embeddings would be generated every time we saw the value "true" (for example).

This updated logic checks if the inference result cache already has results for the value, and uses them if it does.

Copy link
Member

@carlosdelest carlosdelest Jul 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see - we could maybe have used ESTestCase.randomValueOtherThanMany() to a similar effect. That would ensure that the random value is not in the model, rather than just trying twice - AFAIU we should loop until we find a value that is not already in the results?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I'm misunderstanding your comment, but I don't think ESTestCase.randomValueOtherThanMany() would help here. The issue is that the previous logic always generated a new embedding for the input, regardless of whether the model already had a cached value for that input. This caused test failures. Consider the following case:

  • We generate a random embedding for the input true
  • We write that embedding to the expected doc map and cache it in model
  • In a later bulk request, the input true is randomly generated again
  • We generate a different random embedding for the input true
  • We overwrite the cached embedding in model with the new embedding
  • After all requests are generated, we assert that the embedding in the expected doc map matches that in the model cache. This fails because the embedding in the model cache was overwritten.

This new logic fixes the problem by first checking if model already has a cached embedding for the input. If it does, we use it. If it doesn't, we generate a new random embedding and add it to the model cache.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point is, wouldn't it be simpler not to generate the duplicate input values, and thus avoid managing the results as it happens?


expectedDocMap.put(field, semanticTextField);
}

int requestId = randomIntBetween(0, Integer.MAX_VALUE);
Expand Down Expand Up @@ -383,5 +401,9 @@ ChunkedInferenceServiceResults getResults(String text) {
// Caches (or overwrites) the canned inference result keyed by the exact input text.
// NOTE(review): the cache is keyed by text only, not by field — collisions across fields
// share one result (see the PR discussion about duplicate random inputs).
void putResult(String text, ChunkedInferenceServiceResults result) {
resultMap.put(text, result);
}

// True if a canned inference result has already been cached for this input text,
// letting callers reuse it instead of generating a new random embedding.
boolean hasResult(String text) {
return resultMap.containsKey(text);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,16 @@ public static SemanticTextField randomSemanticText(String fieldName, Model model
case SPARSE_EMBEDDING -> randomSparseEmbeddings(inputs);
default -> throw new AssertionError("invalid task type: " + model.getTaskType().name());
};
return semanticTextFieldFromChunkedInferenceResults(fieldName, model, inputs, results, contentType);
}

public static SemanticTextField semanticTextFieldFromChunkedInferenceResults(
String fieldName,
Model model,
List<String> inputs,
ChunkedInferenceServiceResults results,
XContentType contentType
) {
return new SemanticTextField(
fieldName,
inputs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,93 @@ setup:
- match: { _source.dense_field.inference.chunks.0.text: "another inference test" }
- match: { _source.non_inference_field: "non inference test" }

---
"Calculates text expansion and embedding results for new documents with integer value":
  # Numeric inputs to semantic text fields are stringified before inference,
  # so the stored text and the chunk text are the string forms of the numbers.
  - do:
      index:
        index: test-index
        id: doc_1
        body:
          sparse_field: 75
          dense_field: 100

  - do:
      get:
        index: test-index
        id: doc_1

  - match: { _source.sparse_field.text: "75" }
  - exists: _source.sparse_field.inference.chunks.0.embeddings
  - match: { _source.sparse_field.inference.chunks.0.text: "75" }
  - match: { _source.dense_field.text: "100" }
  - exists: _source.dense_field.inference.chunks.0.embeddings
  - match: { _source.dense_field.inference.chunks.0.text: "100" }

---
"Calculates text expansion and embedding results for new documents with boolean value":
  # Boolean inputs are stringified ("true"/"false") before inference, matching the
  # Number/Boolean handling in ShardBulkInferenceActionFilter#nodeStringValues.
  - do:
      index:
        index: test-index
        id: doc_1
        body:
          sparse_field: true
          dense_field: false

  - do:
      get:
        index: test-index
        id: doc_1

  - match: { _source.sparse_field.text: "true" }
  - exists: _source.sparse_field.inference.chunks.0.embeddings
  - match: { _source.sparse_field.inference.chunks.0.text: "true" }
  - match: { _source.dense_field.text: "false" }
  - exists: _source.dense_field.inference.chunks.0.embeddings
  - match: { _source.dense_field.inference.chunks.0.text: "false" }

---
"Calculates text expansion and embedding results for new documents with collection":
  # Mixed-type arrays are accepted: each element (boolean, integer, string, float)
  # is converted to its string form and embedded as a separate chunk, preserving order.
  - do:
      index:
        index: test-index
        id: doc_1
        body:
          sparse_field: [false, 75, "inference test", 13.49]
          dense_field: [true, 49.99, "another inference test", 5654]

  - do:
      get:
        index: test-index
        id: doc_1

  - length: { _source.sparse_field.text: 4 }
  - match: { _source.sparse_field.text.0: "false" }
  - match: { _source.sparse_field.text.1: "75" }
  - match: { _source.sparse_field.text.2: "inference test" }
  - match: { _source.sparse_field.text.3: "13.49" }
  - exists: _source.sparse_field.inference.chunks.0.embeddings
  - exists: _source.sparse_field.inference.chunks.1.embeddings
  - exists: _source.sparse_field.inference.chunks.2.embeddings
  - exists: _source.sparse_field.inference.chunks.3.embeddings
  - match: { _source.sparse_field.inference.chunks.0.text: "false" }
  - match: { _source.sparse_field.inference.chunks.1.text: "75" }
  - match: { _source.sparse_field.inference.chunks.2.text: "inference test" }
  - match: { _source.sparse_field.inference.chunks.3.text: "13.49" }

  - length: { _source.dense_field.text: 4 }
  - match: { _source.dense_field.text.0: "true" }
  - match: { _source.dense_field.text.1: "49.99" }
  - match: { _source.dense_field.text.2: "another inference test" }
  - match: { _source.dense_field.text.3: "5654" }
  - exists: _source.dense_field.inference.chunks.0.embeddings
  - exists: _source.dense_field.inference.chunks.1.embeddings
  - exists: _source.dense_field.inference.chunks.2.embeddings
  - exists: _source.dense_field.inference.chunks.3.embeddings
  - match: { _source.dense_field.inference.chunks.0.text: "true" }
  - match: { _source.dense_field.inference.chunks.1.text: "49.99" }
  - match: { _source.dense_field.inference.chunks.2.text: "another inference test" }
  - match: { _source.dense_field.inference.chunks.3.text: "5654" }

---
"Inference fields do not create new mappings":
- do:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,64 @@ setup:
- close_to: { hits.hits.0._score: { value: 3.7837332e17, error: 1e10 } }
- length: { hits.hits.0._source.inference_field.inference.chunks: 2 }

---
"Numeric query using a sparse embedding model":
  # Indexing numeric array elements and querying with the stringified value must match.
  - skip:
      # NOTE(review): close_to is not used by this test's assertions; presumably listed
      # for consistency with sibling tests — confirm before removing.
      features: [ "headers", "close_to" ]

  - do:
      index:
        index: test-sparse-index
        id: doc_1
        body:
          inference_field: [40, 49.678]
        refresh: true

  - do:
      headers:
        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
        Content-Type: application/json
      search:
        index: test-sparse-index
        body:
          query:
            semantic:
              field: "inference_field"
              query: "40"

  - match: { hits.total.value: 1 }
  - match: { hits.hits.0._id: "doc_1" }
  - length: { hits.hits.0._source.inference_field.inference.chunks: 2 }

---
"Boolean query using a sparse embedding model":
- skip:
features: [ "headers", "close_to" ]

- do:
index:
index: test-sparse-index
id: doc_1
body:
inference_field: true
refresh: true

- do:
headers:
# Force JSON content type so that we use a parser that interprets the floating-point score as a double
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is content type needed here, as we're using booleans?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a copy-paste from other YAML tests so that we can compare scores using the YAML assertions. Fun fact: If the test uses the SMILE format (which it will randomly do, unless you force JSON like is done here), then scores in search responses will be parsed as float, breaking the ability to check them using YAML assertions (which take double values).

We don't compare scores in this particular test, but it should be harmless to leave this so as to not create a landmine if we add score comparison in the future.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it - thought it was referring to the actual values we were indexing instead of the score, as I saw no scores involved in the test.

I'd say if it's not needed, don't add it - it confused me and probably will confuse others 🤷

Content-Type: application/json
search:
index: test-sparse-index
body:
query:
semantic:
field: "inference_field"
query: "true"

- match: { hits.total.value: 1 }
- match: { hits.hits.0._id: "doc_1" }
- length: { hits.hits.0._source.inference_field.inference.chunks: 1 }

---
"Query using a dense embedding model":
- skip:
Expand Down Expand Up @@ -121,6 +179,64 @@ setup:
- close_to: { hits.hits.0._score: { value: 1.0, error: 0.0001 } }
- length: { hits.hits.0._source.inference_field.inference.chunks: 2 }

---
"Numeric query using a dense embedding model":
  # Two numeric array elements produce two chunks; the query uses the stringified value.
  - skip:
      # NOTE(review): close_to is not used by this test's assertions; presumably listed
      # for consistency with sibling tests — confirm before removing.
      features: [ "headers", "close_to" ]

  - do:
      index:
        index: test-dense-index
        id: doc_1
        body:
          inference_field: [45.1, 100]
        refresh: true

  - do:
      headers:
        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
        Content-Type: application/json
      search:
        index: test-dense-index
        body:
          query:
            semantic:
              field: "inference_field"
              query: "45.1"

  - match: { hits.total.value: 1 }
  - match: { hits.hits.0._id: "doc_1" }
  - length: { hits.hits.0._source.inference_field.inference.chunks: 2 }

---
"Boolean query using a dense embedding model":
  # A boolean input is stringified ("false") at index time; querying with the same
  # string form must retrieve the single-chunk document.
  - skip:
      # NOTE(review): close_to is not used by this test's assertions; presumably listed
      # for consistency with sibling tests — confirm before removing.
      features: [ "headers", "close_to" ]

  - do:
      index:
        index: test-dense-index
        id: doc_1
        body:
          inference_field: false
        refresh: true

  - do:
      headers:
        # Force JSON content type so that we use a parser that interprets the floating-point score as a double
        Content-Type: application/json
      search:
        index: test-dense-index
        body:
          query:
            semantic:
              field: "inference_field"
              query: "false"

  - match: { hits.total.value: 1 }
  - match: { hits.hits.0._id: "doc_1" }
  - length: { hits.hits.0._source.inference_field.inference.chunks: 1 }

---
"Query using a dense embedding model that uses byte embeddings":
- skip:
Expand Down