diff --git a/docs/painless/painless-api-reference/painless-api-reference-score/index.asciidoc b/docs/painless/painless-api-reference/painless-api-reference-score/index.asciidoc index d355a495e0625..ea272a3e392c5 100644 --- a/docs/painless/painless-api-reference/painless-api-reference-score/index.asciidoc +++ b/docs/painless/painless-api-reference/painless-api-reference-score/index.asciidoc @@ -10,8 +10,8 @@ The following specialized API is available in the Score context. ==== Static Methods The following methods are directly callable without a class/instance qualifier. Note parameters denoted by a (*) are treated as read-only values. -* double cosineSimilarity(List *, VectorScriptDocValues.DenseVectorScriptDocValues) -* double cosineSimilaritySparse(Map *, VectorScriptDocValues.SparseVectorScriptDocValues) +* double cosineSimilarity(List *, String) +* double cosineSimilaritySparse(Map *, String) * double decayDateExp(String *, String *, String *, double *, JodaCompatibleZonedDateTime) * double decayDateGauss(String *, String *, String *, double *, JodaCompatibleZonedDateTime) * double decayDateLinear(String *, String *, String *, double *, JodaCompatibleZonedDateTime) @@ -21,8 +21,8 @@ The following methods are directly callable without a class/instance qualifier. * double decayNumericExp(double *, double *, double *, double *, double) * double decayNumericGauss(double *, double *, double *, double *, double) * double decayNumericLinear(double *, double *, double *, double *, double) -* double dotProduct(List, VectorScriptDocValues.DenseVectorScriptDocValues) -* double dotProductSparse(Map *, VectorScriptDocValues.SparseVectorScriptDocValues) +* double dotProduct(List, String) +* double dotProductSparse(Map *, String) * double randomScore(int *) * double randomScore(int *, String *) * double saturation(double, double) diff --git a/docs/reference/vectors/vector-functions.asciidoc b/docs/reference/vectors/vector-functions.asciidoc index 4a23703b7ae6c..9db4757f03579 100644 --- a/docs/reference/vectors/vector-functions.asciidoc +++ b/docs/reference/vectors/vector-functions.asciidoc @@ -68,7 +68,7 @@ GET my_index/_search } }, "script": { - "source": "cosineSimilarity(params.query_vector, doc['my_dense_vector']) + 1.0", <2> + "source": "cosineSimilarity(params.query_vector, 'my_dense_vector') + 1.0", <2> "params": { "query_vector": [4, 3.4, -0.2] <3> } @@ -105,7 +105,7 @@ GET my_index/_search }, "script": { "source": """ - double value = dotProduct(params.query_vector, doc['my_dense_vector']); + double value = dotProduct(params.query_vector, 'my_dense_vector'); return sigmoid(1, Math.E, -value); <1> """, "params": { @@ -139,7 +139,7 @@ GET my_index/_search } }, "script": { - "source": "1 / (1 + l1norm(params.queryVector, doc['my_dense_vector']))", <1> + "source": "1 / (1 + l1norm(params.queryVector, 'my_dense_vector'))", <1> "params": { "queryVector": [4, 3.4, -0.2] } @@ -178,7 +178,7 @@ GET my_index/_search } }, "script": { - "source": "1 / (1 + l2norm(params.queryVector, doc['my_dense_vector']))", + "source": "1 / (1 + l2norm(params.queryVector, 'my_dense_vector'))", "params": { "queryVector": [4, 3.4, -0.2] } @@ -196,7 +196,7 @@ You can check if a document has a value for the field `my_vector` by [source,js] -------------------------------------------------- -"source": "doc['my_vector'].size() == 0 ? 0 : cosineSimilarity(params.queryVector, doc['my_vector'])" +"source": "doc['my_vector'].size() == 0 ? 0 : cosineSimilarity(params.queryVector, 'my_vector')" -------------------------------------------------- // NOTCONSOLE @@ -262,7 +262,7 @@ GET my_sparse_index/_search } }, "script": { - "source": "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector']) + 1.0", + "source": "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector') + 1.0", "params": { "query_vector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0} } @@ -294,7 +294,7 @@ GET my_sparse_index/_search }, "script": { "source": """ - double value = dotProductSparse(params.query_vector, doc['my_sparse_vector']); + double value = dotProductSparse(params.query_vector, 'my_sparse_vector'); return sigmoid(1, Math.E, -value); """, "params": { @@ -327,7 +327,7 @@ GET my_sparse_index/_search } }, "script": { - "source": "1 / (1 + l1normSparse(params.queryVector, doc['my_sparse_vector']))", + "source": "1 / (1 + l1normSparse(params.queryVector, 'my_sparse_vector'))", "params": { "queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0} } @@ -358,7 +358,7 @@ GET my_sparse_index/_search } }, "script": { - "source": "1 / (1 + l2normSparse(params.queryVector, doc['my_sparse_vector']))", + "source": "1 / (1 + l2normSparse(params.queryVector, 'my_sparse_vector'))", "params": { "queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0} } diff --git a/server/src/main/java/org/elasticsearch/script/ScoreScript.java b/server/src/main/java/org/elasticsearch/script/ScoreScript.java index faad66fc1479b..7c2c09d17afe7 100644 --- a/server/src/main/java/org/elasticsearch/script/ScoreScript.java +++ b/server/src/main/java/org/elasticsearch/script/ScoreScript.java @@ -109,7 +109,7 @@ public Map getParams() { } /** The doc lookup for the Lucene segment this script was created for. */ - public final Map> getDoc() { + public Map> getDoc() { return leafLookup.doc(); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml index 903b9dc3de3b0..0d46b46163855 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/10_dense_vector_basic.yml @@ -1,6 +1,6 @@ setup: - skip: - features: headers + features: [headers, warnings] version: " - 7.2.99" reason: "dense_vector functions were added from 7.3" @@ -52,7 +52,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProduct(params.query_vector, doc['my_dense_vector'])" + source: "dotProduct(params.query_vector, 'my_dense_vector')" params: query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] @@ -82,7 +82,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] @@ -99,3 +99,26 @@ setup: - match: {hits.hits.2._id: "1"} - gte: {hits.hits.2._score: 0.78} - lte: {hits.hits.2._score: 0.791} + +--- +"Deprecated function signature": + - do: + headers: + Content-Type: application/json + warnings: + - The vector functions of the form function(query, doc['field']) are deprecated, and the form function(query, 'field')` should be used instead. For example, cosineSimilarity(query, doc['field'] is replaced by cosineSimilarity(query, 'field'). + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + params: + query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] + + - match: {hits.total: 3} + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.2._id: "1"} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/15_dense_vector_l1l2.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/15_dense_vector_l1l2.yml index dbb274d077645..882d11566dfaa 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/15_dense_vector_l1l2.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/15_dense_vector_l1l2.yml @@ -53,7 +53,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l1norm(params.query_vector, doc['my_dense_vector'])" + source: "l1norm(params.query_vector, 'my_dense_vector')" params: query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] @@ -83,7 +83,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l2norm(params.query_vector, doc['my_dense_vector'])" + source: "l2norm(params.query_vector, 'my_dense_vector')" params: query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml index 98a68cab9ca0a..cfec55095ad9d 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml @@ -62,7 +62,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [10, 10, 10] @@ -81,7 +81,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [10.0, 10.0, 10.0] @@ -111,7 +111,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [1, 2, 3, 4] - match: { error.root_cause.0.type: "script_exception" } @@ -125,7 +125,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProduct(params.query_vector, doc['my_dense_vector'])" + source: "dotProduct(params.query_vector, 'my_dense_vector')" params: query_vector: [1, 2, 3, 4] - match: { error.root_cause.0.type: "script_exception" } @@ -161,7 +161,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [10.0, 10.0, 10.0] - match: { error.root_cause.0.type: "script_exception" } @@ -177,7 +177,7 @@ setup: script_score: query: {match_all: {} } script: - source: "doc['my_dense_vector'].size() == 0 ? 0 : cosineSimilarity(params.query_vector, doc['my_dense_vector'])" + source: "doc['my_dense_vector'].size() == 0 ? 0 : cosineSimilarity(params.query_vector, 'my_dense_vector')" params: query_vector: [10.0, 10.0, 10.0] @@ -209,7 +209,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProductSparse(params.query_vector, doc['my_dense_vector'])" + source: "dotProductSparse(params.query_vector, 'my_dense_vector')" params: query_vector: {"2": 0.5, "10" : 111.3, "3": 44} - match: { error.root_cause.0.type: "script_exception" } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/30_sparse_vector_basic.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/30_sparse_vector_basic.yml index e184fd0ce9333..52e937547ff37 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/30_sparse_vector_basic.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/30_sparse_vector_basic.yml @@ -55,7 +55,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProductSparse(params.query_vector, doc['my_sparse_vector'])" + source: "dotProductSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} @@ -87,7 +87,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} @@ -104,3 +104,27 @@ setup: - match: {hits.hits.2._id: "1"} - gte: {hits.hits.2._score: 0.78} - lte: {hits.hits.2._score: 0.791} + +--- +"Deprecated function signature": + - do: + headers: + Content-Type: application/json + warnings: + - The [sparse_vector] field type is deprecated and will be removed in 8.0. + - The vector functions of the form function(query, doc['field']) are deprecated, and the form function(query, 'field')` should be used instead. For example, cosineSimilarity(query, doc['field'] is replaced by cosineSimilarity(query, 'field'). + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: {match_all: {} } + script: + source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + params: + query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} + + - match: {hits.total: 3} + - match: {hits.hits.0._id: "3"} + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.2._id: "1"} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/35_sparse_vector_l1l2.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/35_sparse_vector_l1l2.yml index 3a6ed9fd561e9..8a1ec0d3cdde3 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/35_sparse_vector_l1l2.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/35_sparse_vector_l1l2.yml @@ -55,7 +55,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l1normSparse(params.query_vector, doc['my_sparse_vector'])" + source: "l1normSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} @@ -88,7 +88,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l2normSparse(params.query_vector, doc['my_sparse_vector'])" + source: "l2normSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/40_sparse_vector_special_cases.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/40_sparse_vector_special_cases.yml index 90a28eeb1eeae..c49413097807c 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/40_sparse_vector_special_cases.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/40_sparse_vector_special_cases.yml @@ -61,7 +61,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10} @@ -83,7 +83,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10.0} @@ -127,7 +127,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10.0} - match: { error.root_cause.0.type: "script_exception" } @@ -145,7 +145,7 @@ setup: script_score: query: {match_all: {} } script: - source: "doc['my_sparse_vector'].size() == 0 ? 0 : cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "doc['my_sparse_vector'].size() == 0 ? 0 : cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10.0} @@ -194,7 +194,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"100": -200.0, "11" : 300.33, "12": -34.8988, "2": 230.0, "30": 15.555} @@ -230,7 +230,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProduct(params.query_vector, doc['my_sparse_vector'])" + source: "dotProduct(params.query_vector, 'my_sparse_vector')" params: query_vector: [0.5, 111] - match: { error.root_cause.0.type: "script_exception" } @@ -273,7 +273,7 @@ setup: script_score: query: {match_all: {} } script: - source: "dotProductSparse(params.query_vector, doc['my_sparse_vector'])" + source: "dotProductSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10, "5": 5} @@ -304,7 +304,7 @@ setup: script_score: query: {match_all: {} } script: - source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" + source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10, "5" : 5} @@ -334,7 +334,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l1normSparse(params.query_vector, doc['my_sparse_vector'])" + source: "l1normSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10, "5": 5} @@ -361,7 +361,7 @@ setup: script_score: query: {match_all: {} } script: - source: "l2normSparse(params.query_vector, doc['my_sparse_vector'])" + source: "l2normSparse(params.query_vector, 'my_sparse_vector')" params: query_vector: {"1": 10, "5": 5} diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java index 91f2fc343b113..35e6b78598dc3 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java @@ -9,12 +9,16 @@ import org.apache.logging.log4j.LogManager; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.script.ScoreScript; import org.elasticsearch.xpack.vectors.mapper.SparseVectorFieldMapper; import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder; +import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues.DenseVectorScriptDocValues; +import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues.SparseVectorScriptDocValues; +import java.io.IOException; import java.nio.ByteBuffer; import java.util.List; import java.util.Map; @@ -22,6 +26,10 @@ import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder.sortSparseDimsFloatValues; public class ScoreScriptUtils { + private static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(ScoreScriptUtils.class)); + static final String DEPRECATION_MESSAGE = "The vector functions of the form function(query, doc['field']) are deprecated, and " + + "the form function(query, 'field')` should be used instead. For example, cosineSimilarity(query, doc['field'] is replaced by " + + "cosineSimilarity(query, 'field')."; //**************FUNCTIONS FOR DENSE VECTORS // Functions are implemented as classes to accept a hidden parameter scoreScript that contains some index settings. @@ -31,9 +39,12 @@ public class ScoreScriptUtils { public static class DenseVectorFunction { final ScoreScript scoreScript; final float[] queryVector; + final VectorScriptDocValues.DenseVectorScriptDocValues docValues; - public DenseVectorFunction(ScoreScript scoreScript, List queryVector) { - this(scoreScript, queryVector, false); + public DenseVectorFunction(ScoreScript scoreScript, + List queryVector, + Object field) { + this(scoreScript, queryVector, field, false); } /** @@ -45,6 +56,7 @@ public DenseVectorFunction(ScoreScript scoreScript, List queryVector) { */ public DenseVectorFunction(ScoreScript scoreScript, List queryVector, + Object field, boolean normalizeQuery) { this.scoreScript = scoreScript; @@ -62,9 +74,28 @@ public DenseVectorFunction(ScoreScript scoreScript, this.queryVector[dim] /= queryMagnitude; } } + + if (field instanceof String) { + String fieldName = (String) field; + docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(fieldName); + } else if (field instanceof DenseVectorScriptDocValues) { + docValues = (DenseVectorScriptDocValues) field; + deprecationLogger.deprecatedAndMaybeLog("vector_function_signature", DEPRECATION_MESSAGE); + } else { + throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or " + + "VectorScriptDocValues"); + } } - public void validateDocVector(BytesRef vector) { + BytesRef getEncodedVector() { + try { + docValues.setNextDocId(scoreScript._getDocId()); + } catch (IOException e) { + throw ExceptionsHelper.convertToElastic(e); + } + + // Validate the encoded vector's length. + BytesRef vector = docValues.getEncodedValue(); if (vector == null) { throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); } @@ -74,20 +105,21 @@ public void validateDocVector(BytesRef vector) { throw new IllegalArgumentException("The query vector has a different number of dimensions [" + queryVector.length + "] than the document vectors [" + vectorLength + "]."); } + return vector; } } // Calculate l1 norm (Manhattan distance) between a query's dense vector and documents' dense vectors public static final class L1Norm extends DenseVectorFunction { - public L1Norm(ScoreScript scoreScript, List queryVector) { - super(scoreScript, queryVector); + public L1Norm(ScoreScript scoreScript, List queryVector, Object field) { + super(scoreScript, queryVector, field); } - public double l1norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); + public double l1norm() { + BytesRef vector = getEncodedVector(); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); + double l1norm = 0; for (float queryValue : queryVector) { @@ -100,13 +132,12 @@ public double l1norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { // Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors public static final class L2Norm extends DenseVectorFunction { - public L2Norm(ScoreScript scoreScript, List queryVector) { - super(scoreScript, queryVector); + public L2Norm(ScoreScript scoreScript, List queryVector, Object field) { + super(scoreScript, queryVector, field); } - public double l2norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); + public double l2norm() { + BytesRef vector = getEncodedVector(); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); double l2norm = 0; @@ -121,13 +152,12 @@ public double l2norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { // Calculate a dot product between a query's dense vector and documents' dense vectors public static final class DotProduct extends DenseVectorFunction { - public DotProduct(ScoreScript scoreScript, List queryVector) { - super(scoreScript, queryVector); + public DotProduct(ScoreScript scoreScript, List queryVector, Object field) { + super(scoreScript, queryVector, field); } - public double dotProduct(VectorScriptDocValues.DenseVectorScriptDocValues dvs){ - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); + public double dotProduct() { + BytesRef vector = getEncodedVector(); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); double dotProduct = 0; @@ -141,14 +171,12 @@ public double dotProduct(VectorScriptDocValues.DenseVectorScriptDocValues dvs){ // Calculate cosine similarity between a query's dense vector and documents' dense vectors public static final class CosineSimilarity extends DenseVectorFunction { - public CosineSimilarity(ScoreScript scoreScript, List queryVector) { - super(scoreScript, queryVector, true); + public CosineSimilarity(ScoreScript scoreScript, List queryVector, Object field) { + super(scoreScript, queryVector, field, true); } - public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); - + public double cosineSimilarity() { + BytesRef vector = getEncodedVector(); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); double dotProduct = 0.0; @@ -176,15 +204,17 @@ public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues // per script execution for all documents. public static class SparseVectorFunction { - static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(SparseVectorFunction.class)); - final ScoreScript scoreScript; final float[] queryValues; final int[] queryDims; + final VectorScriptDocValues.SparseVectorScriptDocValues docValues; + // prepare queryVector once per script execution // queryVector represents a map of dimensions to values - public SparseVectorFunction(ScoreScript scoreScript, Map queryVector) { + public SparseVectorFunction(ScoreScript scoreScript, + Map queryVector, + Object field) { this.scoreScript = scoreScript; //break vector into two arrays dims and values int n = queryVector.size(); @@ -203,28 +233,46 @@ public SparseVectorFunction(ScoreScript scoreScript, Map queryVe // Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions sortSparseDimsFloatValues(queryDims, queryValues, n); + if (field instanceof String) { + String fieldName = (String) field; + docValues = (SparseVectorScriptDocValues) scoreScript.getDoc().get(fieldName); + } else if (field instanceof SparseVectorScriptDocValues) { + docValues = (SparseVectorScriptDocValues) field; + deprecationLogger.deprecatedAndMaybeLog("vector_function_signature", DEPRECATION_MESSAGE); + } else { + throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or " + + "VectorScriptDocValues"); + } + deprecationLogger.deprecatedAndMaybeLog("sparse_vector_function", SparseVectorFieldMapper.DEPRECATION_MESSAGE); } - public void validateDocVector(BytesRef vector) { + BytesRef getEncodedVector() { + try { + docValues.setNextDocId(scoreScript._getDocId()); + } catch (IOException e) { + throw ExceptionsHelper.convertToElastic(e); + } + + BytesRef vector = docValues.getEncodedValue(); if (vector == null) { throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); } + return vector; } } // Calculate l1 norm (Manhattan distance) between a query's sparse vector and documents' sparse vectors public static final class L1NormSparse extends SparseVectorFunction { - public L1NormSparse(ScoreScript scoreScript,Map queryVector) { - super(scoreScript, queryVector); + public L1NormSparse(ScoreScript scoreScript,Map queryVector, Object docVector) { + super(scoreScript, queryVector, docVector); } - public double l1normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); - + public double l1normSparse() { + BytesRef vector = getEncodedVector(); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); + int queryIndex = 0; int docIndex = 0; double l1norm = 0; @@ -255,16 +303,15 @@ public double l1normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs // Calculate l2 norm (Euclidean distance) between a query's sparse vector and documents' sparse vectors public static final class L2NormSparse extends SparseVectorFunction { - public L2NormSparse(ScoreScript scoreScript, Map queryVector) { - super(scoreScript, queryVector); + public L2NormSparse(ScoreScript scoreScript, Map queryVector, Object docVector) { + super(scoreScript, queryVector, docVector); } - public double l2normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); - + public double l2normSparse() { + BytesRef vector = getEncodedVector(); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); + int queryIndex = 0; int docIndex = 0; double l2norm = 0; @@ -298,16 +345,15 @@ public double l2normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs // Calculate a dot product between a query's sparse vector and documents' sparse vectors public static final class DotProductSparse extends SparseVectorFunction { - public DotProductSparse(ScoreScript scoreScript, Map queryVector) { - super(scoreScript, queryVector); + public DotProductSparse(ScoreScript scoreScript, Map queryVector, Object docVector) { + super(scoreScript, queryVector, docVector); } - public double dotProductSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); - + public double dotProductSparse() { + BytesRef vector = getEncodedVector(); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); + return intDotProductSparse(queryValues, queryDims, docValues, docDims); } } @@ -316,8 +362,8 @@ public double dotProductSparse(VectorScriptDocValues.SparseVectorScriptDocValues public static final class CosineSimilaritySparse extends SparseVectorFunction { final double queryVectorMagnitude; - public CosineSimilaritySparse(ScoreScript scoreScript, Map queryVector) { - super(scoreScript, queryVector); + public CosineSimilaritySparse(ScoreScript scoreScript, Map queryVector, Object docVector) { + super(scoreScript, queryVector, docVector); double dotProduct = 0; for (int i = 0; i< queryDims.length; i++) { dotProduct += queryValues[i] * queryValues[i]; @@ -325,10 +371,8 @@ public CosineSimilaritySparse(ScoreScript scoreScript, Map query this.queryVectorMagnitude = Math.sqrt(dotProduct); } - public double cosineSimilaritySparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { - BytesRef vector = dvs.getEncodedValue(); - validateDocVector(vector); - + public double cosineSimilaritySparse() { + BytesRef vector = getEncodedVector(); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); diff --git a/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt b/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt index 42d6e6d0b0f7a..73155bf1333b6 100644 --- a/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt +++ b/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt @@ -13,12 +13,12 @@ class org.elasticsearch.script.ScoreScript @no_import { } static_import { - double l1norm(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm - double l2norm(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm - double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilarity - double dotProduct(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProduct - double l1normSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1NormSparse - double l2normSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2NormSparse - double dotProductSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProductSparse - double cosineSimilaritySparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilaritySparse + double l1norm(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm + double l2norm(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm + double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilarity + double dotProduct(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProduct + double l1normSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1NormSparse + double l2normSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2NormSparse + double dotProductSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProductSparse + double cosineSimilaritySparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilaritySparse } \ No newline at end of file diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java index bff87a5ac472c..9aff40b359aad 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2NormSparse; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -39,6 +40,7 @@ public void testDenseVectorFunctions() { } private void testDenseVectorFunctions(Version indexVersion) { + String field = "vector"; float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion); VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class); @@ -46,65 +48,136 @@ private void testDenseVectorFunctions(Version indexVersion) { ScoreScript scoreScript = mock(ScoreScript.class); when(scoreScript._getIndexVersion()).thenReturn(indexVersion); + when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs)); List queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f); // test dotProduct - DotProduct dotProduct = new DotProduct(scoreScript, queryVector); - double result = dotProduct.dotProduct(dvs); + DotProduct dotProduct = new DotProduct(scoreScript, queryVector, field); + double result = dotProduct.dotProduct(); assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001); // test cosineSimilarity - CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector); - double result2 = cosineSimilarity.cosineSimilarity(dvs); + CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector, field); + double result2 = cosineSimilarity.cosineSimilarity(); assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result2, 0.001); // test l1Norm - L1Norm l1norm = new L1Norm(scoreScript, queryVector); - double result3 = l1norm.l1norm(dvs); + L1Norm l1norm = new L1Norm(scoreScript, queryVector, field); + double result3 = l1norm.l1norm(); assertEquals("l1norm result is not equal to the expected value!", 485.184, result3, 0.001); // test l2norm - L2Norm l2norm = new L2Norm(scoreScript, queryVector); - double result4 = l2norm.l2norm(dvs); + L2Norm l2norm = new L2Norm(scoreScript, queryVector, field); + double result4 = l2norm.l2norm(); assertEquals("l2norm result is not equal to the expected value!", 301.361, result4, 0.001); // test dotProduct fails when queryVector has wrong number of dims List invalidQueryVector = Arrays.asList(0.5, 111.3); - DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> dotProduct2.dotProduct(dvs)); + DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector, field); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, dotProduct2::dotProduct); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); // test cosineSimilarity fails when queryVector has wrong number of dims - CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector); - e = expectThrows(IllegalArgumentException.class, () -> cosineSimilarity2.cosineSimilarity(dvs)); + CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector, field); + e = expectThrows(IllegalArgumentException.class, cosineSimilarity2::cosineSimilarity); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); // test l1norm fails when queryVector has wrong number of dims - L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector); - e = expectThrows(IllegalArgumentException.class, () -> l1norm2.l1norm(dvs)); + L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector, field); + e = expectThrows(IllegalArgumentException.class, l1norm2::l1norm); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); // test l2norm fails when queryVector has wrong number of dims - L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector); - e = expectThrows(IllegalArgumentException.class, () -> l2norm2.l2norm(dvs)); + L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector, field); + e = expectThrows(IllegalArgumentException.class, l2norm2::l2norm); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); } + public void testDeprecatedDenseVectorFunctions() { + testDeprecatedDenseVectorFunctions(Version.V_7_4_0); + testDeprecatedDenseVectorFunctions(Version.CURRENT); + } + + private void testDeprecatedDenseVectorFunctions(Version indexVersion) { + float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; + BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion); + VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class); + when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + + ScoreScript scoreScript = mock(ScoreScript.class); + when(scoreScript._getIndexVersion()).thenReturn(indexVersion); + + List queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f); + + // test dotProduct + DotProduct dotProduct = new DotProduct(scoreScript, queryVector, dvs); + double result = dotProduct.dotProduct(); + assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001); + assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); + + // test cosineSimilarity + CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector, dvs); + double result2 = cosineSimilarity.cosineSimilarity(); + assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result2, 0.001); + assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); + + // test l1Norm + L1Norm l1norm = new L1Norm(scoreScript, queryVector, dvs); + double result3 = l1norm.l1norm(); + assertEquals("l1norm result is not equal to the expected value!", 485.184, result3, 0.001); + assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); + + // test l2norm + L2Norm l2norm = new L2Norm(scoreScript, queryVector, dvs); + double result4 = l2norm.l2norm(); + assertEquals("l2norm result is not equal to the expected value!", 301.361, result4, 0.001); + assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); + + // test dotProduct fails when queryVector has wrong number of dims + List invalidQueryVector = Arrays.asList(0.5, 111.3); + DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector, dvs); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, dotProduct2::dotProduct); + assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); + assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); + + // test cosineSimilarity fails when queryVector has wrong number of dims + CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector, dvs); + e = expectThrows(IllegalArgumentException.class, cosineSimilarity2::cosineSimilarity); + assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); + assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); + + // test l1norm fails when queryVector has wrong number of dims + L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector, dvs); + e = expectThrows(IllegalArgumentException.class, l1norm2::l1norm); + assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); + assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); + + // test l2norm fails when queryVector has wrong number of dims + L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector, dvs); + e = expectThrows(IllegalArgumentException.class, l2norm2::l2norm); + assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); + assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE); + } + public void testSparseVectorFunctions() { testSparseVectorFunctions(Version.V_7_4_0); testSparseVectorFunctions(Version.CURRENT); } private void testSparseVectorFunctions(Version indexVersion) { + String field = "vector"; + int[] docVectorDims = {2, 10, 50, 113, 4545}; float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector( indexVersion, docVectorDims, docVectorValues, docVectorDims.length); VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + ScoreScript scoreScript = mock(ScoreScript.class); when(scoreScript._getIndexVersion()).thenReturn(indexVersion); + when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs)); Map queryVector = new HashMap() {{ put("2", 0.5); @@ -115,29 +188,79 @@ private void testSparseVectorFunctions(Version indexVersion) { }}; // test dotProduct - DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector); - double result = docProductSparse.dotProductSparse(dvs); + DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field); + double result = docProductSparse.dotProductSparse(); assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); // test cosineSimilarity - CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector); - double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); + CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field); + double result2 = cosineSimilaritySparse.cosineSimilaritySparse(); assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result2, 0.001); // test l1norm - L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector); - double result3 = l1Norm.l1normSparse(dvs); + L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field); + double result3 = l1Norm.l1normSparse(); assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result3, 0.001); // test l2norm - L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector); - double result4 = l2Norm.l2normSparse(dvs); + L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field); + double result4 = l2Norm.l2normSparse(); assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result4, 0.001); - assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); } + public void testDeprecatedSparseVectorFunctions() { + testDeprecatedSparseVectorFunctions(Version.V_7_4_0); + testDeprecatedSparseVectorFunctions(Version.CURRENT); + } + + private void testDeprecatedSparseVectorFunctions(Version indexVersion) { + int[] docVectorDims = {2, 10, 50, 113, 4545}; + float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; + BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector( + indexVersion, docVectorDims, docVectorValues, docVectorDims.length); + VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); + when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + + ScoreScript scoreScript = mock(ScoreScript.class); + when(scoreScript._getIndexVersion()).thenReturn(indexVersion); + + Map queryVector = new HashMap() {{ + put("2", 0.5); + put("10", 111.3); + put("50", -13.0); + put("113", 14.8); + put("4545", -156.0); + }}; + + // test dotProduct + DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, dvs); + double result = docProductSparse.dotProductSparse(); + assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE); + + // test cosineSimilarity + CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, dvs); + double result2 = cosineSimilaritySparse.cosineSimilaritySparse(); + assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result2, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE); + + // test l1norm + L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, dvs); + double result3 = l1Norm.l1normSparse(); + assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result3, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE); + + // test l2norm + L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, dvs); + double result4 = l2Norm.l2normSparse(); + assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result4, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE); + } + public void testSparseVectorMissingDimensions1() { + String field = "vector"; + // Document vector's biggest dimension > query vector's biggest dimension int[] docVectorDims = {2, 10, 50, 113, 4545, 4546}; float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f}; @@ -145,8 +268,11 @@ public void testSparseVectorMissingDimensions1() { Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length); VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + ScoreScript scoreScript = mock(ScoreScript.class); when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT); + when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs)); + Map queryVector = new HashMap() {{ put("2", 0.5); put("10", 111.3); @@ -157,29 +283,33 @@ public void testSparseVectorMissingDimensions1() { }}; // test dotProduct - DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector); - double result = docProductSparse.dotProductSparse(dvs); + DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field); + double result = docProductSparse.dotProductSparse(); assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); // test cosineSimilarity - CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector); - double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); + CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field); + double result2 = cosineSimilaritySparse.cosineSimilaritySparse(); assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); // test l1norm - L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector); - double result3 = l1Norm.l1normSparse(dvs); + L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field); + double result3 = l1Norm.l1normSparse(); assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); // test l2norm - L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector); - double result4 = l2Norm.l2normSparse(dvs); + L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field); + double result4 = l2Norm.l2normSparse(); assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001); - assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); } public void testSparseVectorMissingDimensions2() { + String field = "vector"; + // Document vector's biggest dimension < query vector's biggest dimension int[] docVectorDims = {2, 10, 50, 113, 4545, 4546}; float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f}; @@ -187,8 +317,11 @@ public void testSparseVectorMissingDimensions2() { Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length); VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + ScoreScript scoreScript = mock(ScoreScript.class); when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT); + when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs)); + Map queryVector = new HashMap() {{ put("2", 0.5); put("10", 111.3); @@ -199,25 +332,27 @@ public void testSparseVectorMissingDimensions2() { }}; // test dotProduct - DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector); - double result = docProductSparse.dotProductSparse(dvs); + DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field); + double result = docProductSparse.dotProductSparse(); assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); // test cosineSimilarity - CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector); - double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); + CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field); + double result2 = cosineSimilaritySparse.cosineSimilaritySparse(); assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); // test l1norm - L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector); - double result3 = l1Norm.l1normSparse(dvs); + L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field); + double result3 = l1Norm.l1normSparse(); assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001); + assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); // test l2norm - L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector); - double result4 = l2Norm.l2normSparse(dvs); + L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field); + double result4 = l2Norm.l2normSparse(); assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001); - assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); } }