|
10 | 10 | import java.util.Objects;
|
11 | 11 | import org.apache.logging.log4j.LogManager;
|
12 | 12 | import org.apache.logging.log4j.Logger;
|
| 13 | +import org.apache.lucene.util.VectorUtil; |
13 | 14 | import org.opensearch.knn.index.KNNVectorScriptDocValues;
|
14 | 15 | import org.opensearch.knn.index.SpaceType;
|
15 | 16 | import org.opensearch.knn.index.VectorDataType;
|
@@ -48,13 +49,7 @@ private static void requireEqualDimension(final float[] queryVector, final float
|
48 | 49 | * @return L2 score
|
49 | 50 | */
|
50 | 51 | public static float l2Squared(float[] queryVector, float[] inputVector) {
|
51 |
| - requireEqualDimension(queryVector, inputVector); |
52 |
| - float squaredDistance = 0; |
53 |
| - for (int i = 0; i < inputVector.length; i++) { |
54 |
| - float diff = queryVector[i] - inputVector[i]; |
55 |
| - squaredDistance += diff * diff; |
56 |
| - } |
57 |
| - return squaredDistance; |
| 52 | + return VectorUtil.squareDistance(queryVector, inputVector); |
58 | 53 | }
|
59 | 54 |
|
60 | 55 | private static float[] toFloat(List<Number> inputVector, VectorDataType vectorDataType) {
|
@@ -148,20 +143,12 @@ public static float cosineSimilarity(List<Number> queryVector, KNNVectorScriptDo
|
148 | 143 | */
|
149 | 144 | public static float cosinesimil(float[] queryVector, float[] inputVector) {
|
150 | 145 | requireEqualDimension(queryVector, inputVector);
|
151 |
| - float dotProduct = 0.0f; |
152 |
| - float normQueryVector = 0.0f; |
153 |
| - float normInputVector = 0.0f; |
154 |
| - for (int i = 0; i < queryVector.length; i++) { |
155 |
| - dotProduct += queryVector[i] * inputVector[i]; |
156 |
| - normQueryVector += queryVector[i] * queryVector[i]; |
157 |
| - normInputVector += inputVector[i] * inputVector[i]; |
158 |
| - } |
159 |
| - float normalizedProduct = normQueryVector * normInputVector; |
160 |
| - if (normalizedProduct == 0) { |
| 146 | + try { |
| 147 | + return VectorUtil.cosine(queryVector, inputVector); |
| 148 | + } catch (IllegalArgumentException | AssertionError e) { |
161 | 149 | logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end");
|
162 | 150 | return 0.0f;
|
163 | 151 | }
|
164 |
| - return (float) (dotProduct / (Math.sqrt(normalizedProduct))); |
165 | 152 | }
|
166 | 153 |
|
167 | 154 | /**
|
@@ -217,7 +204,6 @@ public static float calculateHammingBit(Long queryLong, Long inputLong) {
|
217 | 204 | * @return L1 score
|
218 | 205 | */
|
219 | 206 | public static float l1Norm(float[] queryVector, float[] inputVector) {
|
220 |
| - requireEqualDimension(queryVector, inputVector); |
221 | 207 | float distance = 0;
|
222 | 208 | for (int i = 0; i < inputVector.length; i++) {
|
223 | 209 | float diff = queryVector[i] - inputVector[i];
|
@@ -255,7 +241,6 @@ public static float l1Norm(List<Number> queryVector, KNNVectorScriptDocValues do
|
255 | 241 | * @return L-inf score
|
256 | 242 | */
|
257 | 243 | public static float lInfNorm(float[] queryVector, float[] inputVector) {
|
258 |
| - requireEqualDimension(queryVector, inputVector); |
259 | 244 | float distance = 0;
|
260 | 245 | for (int i = 0; i < inputVector.length; i++) {
|
261 | 246 | float diff = queryVector[i] - inputVector[i];
|
@@ -293,12 +278,7 @@ public static float lInfNorm(List<Number> queryVector, KNNVectorScriptDocValues
|
293 | 278 | * @return dot product score
|
294 | 279 | */
|
295 | 280 | public static float innerProduct(float[] queryVector, float[] inputVector) {
|
296 |
| - requireEqualDimension(queryVector, inputVector); |
297 |
| - float distance = 0; |
298 |
| - for (int i = 0; i < inputVector.length; i++) { |
299 |
| - distance += queryVector[i] * inputVector[i]; |
300 |
| - } |
301 |
| - return distance; |
| 281 | + return VectorUtil.dotProduct(queryVector, inputVector); |
302 | 282 | }
|
303 | 283 |
|
304 | 284 | /**
|
|
0 commit comments