Skip to content

Commit 7a88f40

Browse files
authored
Use the Lucene Distance Calculation Function in Script Scoring for doing exact search (#1699)
* Use the Lucene Distance Calculation Function in Script Scoring for doing exact search Signed-off-by: Ryan Bogan <[email protected]> * Add Changelog entry Signed-off-by: Ryan Bogan <[email protected]> * Fix failing test Signed-off-by: Ryan Bogan <[email protected]> * fix test Signed-off-by: Ryan Bogan <[email protected]> * Fix test bug and remove unnecessary validation Signed-off-by: Ryan Bogan <[email protected]> * Remove cosineSimilOptimized Signed-off-by: Ryan Bogan <[email protected]> * Revert "Remove cosineSimilOptimized" This reverts commit f872d83. Signed-off-by: Ryan Bogan <[email protected]> --------- Signed-off-by: Ryan Bogan <[email protected]>
1 parent 9023604 commit 7a88f40

File tree

3 files changed

+9
-28
lines changed

3 files changed

+9
-28
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1414

1515
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.14...2.x)
1616
### Features
17+
* Use the Lucene Distance Calculation Function in Script Scoring for doing exact search [#1699](https://github.com/opensearch-project/k-NN/pull/1699)
1718
### Enhancements
1819
* Add KnnCircuitBreakerException and modify exception message [#1688](https://github.com/opensearch-project/k-NN/pull/1688)
1920
* Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684)

src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java

+6-26
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.Objects;
1111
import org.apache.logging.log4j.LogManager;
1212
import org.apache.logging.log4j.Logger;
13+
import org.apache.lucene.util.VectorUtil;
1314
import org.opensearch.knn.index.KNNVectorScriptDocValues;
1415
import org.opensearch.knn.index.SpaceType;
1516
import org.opensearch.knn.index.VectorDataType;
@@ -48,13 +49,7 @@ private static void requireEqualDimension(final float[] queryVector, final float
4849
* @return L2 score
4950
*/
5051
public static float l2Squared(float[] queryVector, float[] inputVector) {
51-
requireEqualDimension(queryVector, inputVector);
52-
float squaredDistance = 0;
53-
for (int i = 0; i < inputVector.length; i++) {
54-
float diff = queryVector[i] - inputVector[i];
55-
squaredDistance += diff * diff;
56-
}
57-
return squaredDistance;
52+
return VectorUtil.squareDistance(queryVector, inputVector);
5853
}
5954

6055
private static float[] toFloat(List<Number> inputVector, VectorDataType vectorDataType) {
@@ -148,20 +143,12 @@ public static float cosineSimilarity(List<Number> queryVector, KNNVectorScriptDo
148143
*/
149144
public static float cosinesimil(float[] queryVector, float[] inputVector) {
150145
requireEqualDimension(queryVector, inputVector);
151-
float dotProduct = 0.0f;
152-
float normQueryVector = 0.0f;
153-
float normInputVector = 0.0f;
154-
for (int i = 0; i < queryVector.length; i++) {
155-
dotProduct += queryVector[i] * inputVector[i];
156-
normQueryVector += queryVector[i] * queryVector[i];
157-
normInputVector += inputVector[i] * inputVector[i];
158-
}
159-
float normalizedProduct = normQueryVector * normInputVector;
160-
if (normalizedProduct == 0) {
146+
try {
147+
return VectorUtil.cosine(queryVector, inputVector);
148+
} catch (IllegalArgumentException | AssertionError e) {
161149
logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end");
162150
return 0.0f;
163151
}
164-
return (float) (dotProduct / (Math.sqrt(normalizedProduct)));
165152
}
166153

167154
/**
@@ -217,7 +204,6 @@ public static float calculateHammingBit(Long queryLong, Long inputLong) {
217204
* @return L1 score
218205
*/
219206
public static float l1Norm(float[] queryVector, float[] inputVector) {
220-
requireEqualDimension(queryVector, inputVector);
221207
float distance = 0;
222208
for (int i = 0; i < inputVector.length; i++) {
223209
float diff = queryVector[i] - inputVector[i];
@@ -255,7 +241,6 @@ public static float l1Norm(List<Number> queryVector, KNNVectorScriptDocValues do
255241
* @return L-inf score
256242
*/
257243
public static float lInfNorm(float[] queryVector, float[] inputVector) {
258-
requireEqualDimension(queryVector, inputVector);
259244
float distance = 0;
260245
for (int i = 0; i < inputVector.length; i++) {
261246
float diff = queryVector[i] - inputVector[i];
@@ -293,12 +278,7 @@ public static float lInfNorm(List<Number> queryVector, KNNVectorScriptDocValues
293278
* @return dot product score
294279
*/
295280
public static float innerProduct(float[] queryVector, float[] inputVector) {
296-
requireEqualDimension(queryVector, inputVector);
297-
float distance = 0;
298-
for (int i = 0; i < inputVector.length; i++) {
299-
distance += queryVector[i] * inputVector[i];
300-
}
301-
return distance;
281+
return VectorUtil.dotProduct(queryVector, inputVector);
302282
}
303283

304284
/**

src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public void testL2() {
4747

4848
public void testCosineSimilarity() {
4949
float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f };
50-
List<Double> arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0));
50+
List<Double> arrayListQueryObject = new ArrayList<>(Arrays.asList(2.0, 4.0, 6.0));
5151
float[] arrayFloat2 = new float[] { 2.0f, 4.0f, 6.0f };
5252
KNNMethodContext knnMethodContext = KNNMethodContext.getDefault();
5353

@@ -59,7 +59,7 @@ public void testCosineSimilarity() {
5959
);
6060
KNNScoringSpace.CosineSimilarity cosineSimilarity = new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, fieldType);
6161

62-
assertEquals(3F, cosineSimilarity.scoringMethod.apply(arrayFloat2, arrayFloat), 0.1F);
62+
assertEquals(2F, cosineSimilarity.scoringMethod.apply(arrayFloat2, arrayFloat), 0.1F);
6363

6464
// invalid zero vector
6565
final List<Float> queryZeroVector = List.of(0.0f, 0.0f, 0.0f);

0 commit comments

Comments
 (0)