Skip to content

Commit 540782c

Browse files
committed
Fix test bug and remove unnecessary validation
Signed-off-by: Ryan Bogan <[email protected]>
1 parent e1ec3b9 commit 540782c

File tree

2 files changed

+8
-21
lines changed

2 files changed

+8
-21
lines changed

src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java

+6-19
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,19 @@ public static float l2Squared(List<Number> queryVector, KNNVectorScriptDocValues
9696
* @return cosine score
9797
*/
9898
public static float cosinesimilOptimized(float[] queryVector, float[] inputVector, float normQueryVector) {
99+
requireEqualDimension(queryVector, inputVector);
100+
float dotProduct = 0.0f;
99101
float normInputVector = 0.0f;
100102
for (int i = 0; i < queryVector.length; i++) {
103+
dotProduct += queryVector[i] * inputVector[i];
101104
normInputVector += inputVector[i] * inputVector[i];
102105
}
103106
float normalizedProduct = normQueryVector * normInputVector;
104107
if (normalizedProduct == 0) {
105108
logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end");
106109
return 0.0f;
107110
}
108-
return (float) (VectorUtil.dotProduct(queryVector, inputVector) / (Math.sqrt(normalizedProduct)));
111+
return (float) (dotProduct / (Math.sqrt(normalizedProduct)));
109112
}
110113

111114
/**
@@ -140,28 +143,12 @@ public static float cosineSimilarity(List<Number> queryVector, KNNVectorScriptDo
140143
*/
141144
public static float cosinesimil(float[] queryVector, float[] inputVector) {
142145
requireEqualDimension(queryVector, inputVector);
143-
int numZeroInInput = 0;
144-
int numZeroInQuery = 0;
145-
float cosine = 0.0f;
146-
for (int i = 0; i < inputVector.length; i++) {
147-
if (inputVector[i] == 0) {
148-
numZeroInInput++;
149-
}
150-
151-
if (queryVector[i] == 0) {
152-
numZeroInQuery++;
153-
}
154-
}
155-
if (numZeroInInput == inputVector.length || numZeroInQuery == queryVector.length) {
156-
return cosine;
157-
}
158146
try {
159-
cosine = VectorUtil.cosine(queryVector, inputVector);
160-
} catch (IllegalArgumentException e) {
147+
return VectorUtil.cosine(queryVector, inputVector);
148+
} catch (IllegalArgumentException | AssertionError e) {
161149
logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end");
162150
return 0.0f;
163151
}
164-
return cosine;
165152
}
166153

167154
/**

src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public void testL2() {
4747

4848
public void testCosineSimilarity() {
4949
float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f };
50-
List<Double> arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0));
50+
List<Double> arrayListQueryObject = new ArrayList<>(Arrays.asList(2.0, 4.0, 6.0));
5151
float[] arrayFloat2 = new float[] { 2.0f, 4.0f, 6.0f };
5252
KNNMethodContext knnMethodContext = KNNMethodContext.getDefault();
5353

@@ -59,7 +59,7 @@ public void testCosineSimilarity() {
5959
);
6060
KNNScoringSpace.CosineSimilarity cosineSimilarity = new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, fieldType);
6161

62-
assertEquals(3F, cosineSimilarity.scoringMethod.apply(arrayFloat2, arrayFloat), 0.1F);
62+
assertEquals(2F, cosineSimilarity.scoringMethod.apply(arrayFloat2, arrayFloat), 0.1F);
6363

6464
// invalid zero vector
6565
final List<Float> queryZeroVector = List.of(0.0f, 0.0f, 0.0f);

0 commit comments

Comments
 (0)