Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Java: Integrate EmbeddingVector into Embedding #2328

Merged
merged 3 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.microsoft.semantickernel.ai.EmbeddingVector;
import com.microsoft.semantickernel.ai.embeddings.Embedding;
import com.microsoft.semantickernel.memory.MemoryRecord;
import com.microsoft.semantickernel.memory.MemoryStore;
Expand Down Expand Up @@ -131,51 +130,49 @@ public Mono<List<String>> getCollectionsAsync()
}

@Override
public Mono<Tuple2<MemoryRecord, ? extends Number>> getNearestMatchAsync(String collectionName, Embedding embedding, double minRelevanceScore,
public Mono<Tuple2<MemoryRecord, Float>> getNearestMatchAsync(String collectionName, Embedding embedding, double minRelevanceScore,
boolean withEmbedding)
{
// Note: with this simple implementation, the MemoryRecord will always contain the embedding.
final EmbeddingVector embeddingVector = new EmbeddingVector(embedding.getVector());
return Mono.justOrEmpty(
Arrays.stream(this._memoryRecords)
.map(it -> {
EmbeddingVector memoryRecordEmbeddingVector =
new EmbeddingVector(it.getEmbedding().getVector());
double cosineSimilarity = -1d;
Embedding memoryRecordEmbedding =
it.getEmbedding();
float cosineSimilarity = -1f;
try {
cosineSimilarity = memoryRecordEmbeddingVector.cosineSimilarity(embeddingVector);
cosineSimilarity = embedding.cosineSimilarity(memoryRecordEmbedding);
} catch (IllegalArgumentException e) {
// Vectors cannot have zero norm
}
return Tuples.of(it, cosineSimilarity);
})
.filter(it -> it.getT2() >= minRelevanceScore)
.max(Comparator.comparing(Tuple2::getT2, Double::compare))
.max(Comparator.comparing(Tuple2::getT2, Float::compare))
);
}

@Override
public Mono<Collection<Tuple2<MemoryRecord, Number>>> getNearestMatchesAsync(String collectionName, Embedding embedding, int limit,
public Mono<Collection<Tuple2<MemoryRecord, Float>>> getNearestMatchesAsync(String collectionName, Embedding embedding, int limit,
double minRelevanceScore, boolean withEmbeddings)
{
// Note: with this simple implementation, the MemoryRecord will always contain the embedding.
final EmbeddingVector embeddingVector = new EmbeddingVector(embedding.getVector());
return Mono.justOrEmpty(
Arrays.stream(this._memoryRecords)
.map(it -> {
EmbeddingVector memoryRecordEmbeddingVector =
new EmbeddingVector(it.getEmbedding().getVector());
double cosineSimilarity = -1d;
Embedding memoryRecordEmbedding =
it.getEmbedding();
float cosineSimilarity = -1f;
try {
cosineSimilarity = memoryRecordEmbeddingVector.cosineSimilarity(embeddingVector);
cosineSimilarity = embedding.cosineSimilarity(memoryRecordEmbedding);
} catch (IllegalArgumentException e) {
// Vectors cannot have zero norm
}
return Tuples.of(it, (Number)cosineSimilarity);
return Tuples.of(it, cosineSimilarity);
})
.filter(it -> it.getT2().doubleValue() >= minRelevanceScore)
.filter(it -> it.getT2() >= (float)minRelevanceScore)
// sort by similarity score, descending
.sorted(Comparator.comparing(it -> it.getT2().doubleValue(), (a,b) -> Double.compare(b, a)))
.sorted(Comparator.comparing(Tuple2::getT2, (a,b) -> Float.compare(b, a)))
.limit(limit)
.collect(Collectors.toList())
);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,67 +1,103 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.ai.embeddings;

import java.util.ArrayList;
import com.microsoft.semantickernel.ai.embeddings.VectorOperations;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.Objects;
import javax.annotation.Nonnull;

/** Represents a strongly typed vector of numeric data. */
public class Embedding {

// vector is immutable!
private final List<Float> vector;

public List<Float> getVector() {
return Collections.unmodifiableList(vector);
}
private static final Embedding EMPTY = new Embedding();

private final List<Float> vector;
public static Embedding empty() {
return EMPTY;
}

private static final Embedding EMPTY =
new Embedding(Collections.unmodifiableList(new ArrayList<>()));
/** Initializes a new instance of the Embedding class. */
public Embedding() {
this.vector = Collections.emptyList();
}

public static Embedding empty() {
return EMPTY;
}
/**
* Initializes a new instance of the Embedding class that contains numeric elements copied from
* the specified collection
*
* @param vector The collection whose elements are copied to the new Embedding
*/
public Embedding(@Nonnull List<Float> vector) {
Objects.requireNonNull(vector);
this.vector = Collections.unmodifiableList(vector);
}

/** Initializes a new instance of the Embedding class. */
public Embedding() {
this.vector = Collections.emptyList();
}
/**
* Return the embedding vector as a read-only list.
* @return The embedding vector as a read-only list.
*/
public List<Float> getVector() {
return vector;
}

/**
* Initializes a new instance of the Embedding class that contains numeric elements copied from
* the specified collection
*
* @param vector The collection whose elements are copied to the new Embedding
*/
public Embedding(List<Float> vector) {
// Verify.NotNull(vector, nameof(vector));
this.vector =
vector != null ? Collections.unmodifiableList(vector) : Collections.emptyList();
}
/**
* Calculates the dot product of this {@code Embedding} with another.
*
* @param other The other {@code Embedding} to compute the dot product with
* @return The dot product between the {@code Embedding} vectors
*/
public float dot(@Nonnull Embedding other) {
Objects.requireNonNull(other);
return VectorOperations.dot(this.vector, other.getVector());
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof Embedding)) return false;
/**
* Calculates the Euclidean length of this vector.
*
* @return Euclidean length
*/
public float euclideanLength() {
return VectorOperations.euclideanLength(this.vector);
}

Embedding embedding = (Embedding) o;
/**
* Calculates the cosine similarity of this vector with another.
*
* @param other The other vector to compute cosine similarity with.
* @return Cosine similarity between vectors
*/
public float cosineSimilarity(@Nonnull Embedding other) {
Objects.requireNonNull(other);
return VectorOperations.cosineSimilarity(this.vector, other.getVector());
}

return vector.equals(embedding.vector);
}
/**
* Multiply the {@code Embedding} vector by a multiplier.
* @param multiplier The multiplier to multiply the {@code Embedding} vector by
* @return A new {@code Embedding} with the vector multiplied by the multiplier
*/
public Embedding multiply(float multiplier) {
return new Embedding(VectorOperations.multiply(this.vector, multiplier));
}

@Override
public int hashCode() {
return vector.hashCode();
}
/**
* Divide the {@code Embedding} vector by a divisor.
* @param divisor The divisor to divide the {@code Embedding} vector by
* @return A new {@code Embedding} with the vector divided by the divisor
*/
public Embedding divide(float divisor) {
return new Embedding(VectorOperations.divide(this.vector, divisor));
}

@Override
public String toString() {
return "Embedding{"
+ "vector="
+ vector.stream()
.limit(3)
.map(String::valueOf)
.collect(Collectors.joining(", ", "[", vector.size() > 3 ? "...]" : "]"))
+ '}';
}
/**
* Normalizes the underlying vector, such that the Euclidean length is 1.
*
* @return A new {@code Embedding} with the normalized vector
*/
public Embedding normalize() {
return new Embedding(VectorOperations.normalize(this.vector));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package com.microsoft.semantickernel.ai.embeddings;

import javax.annotation.Nonnull;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

final class VectorOperations {

/**
* Calculates the cosine similarity of two vectors. The vectors must be equal
* in length and have non-zero norm.
*
* @param x First vector, which is not modified
* @param y Second vector, which is not modified
* @return The cosine similarity of the two vectors
*/
static float cosineSimilarity(@Nonnull List<Float> x, @Nonnull List<Float> y) {
Objects.requireNonNull(x);
Objects.requireNonNull(y);

if (x.size() != y.size()) {
throw new IllegalArgumentException("Vectors lengths must be equal");
}

float dotProduct = dot(x,y);
float normX = dot(x,x);
float normY = dot(y,y);

if (normX == 0 || normY == 0) {
throw new IllegalArgumentException("Vectors cannot have zero norm");
}

return dotProduct / (float) (Math.sqrt(normX) * Math.sqrt(normY));
}

/**
* Divides the elements of the vector by the divisor.
* @param vector Vector to divide, which is not modified
* @param divisor Divisor to apply to each element of the vector
* @return A new vector with the elements divided by the divisor
*/
static List<Float> divide(@Nonnull List<Float> vector, float divisor) {
Objects.requireNonNull(vector);
if (Float.isNaN(divisor)) {
throw new IllegalArgumentException("Divisor cannot be NaN");
}
if (divisor == 0f) {
throw new IllegalArgumentException("Divisor cannot be zero");
}

return vector.stream()
.map(x -> x / divisor)
.collect(Collectors.toList());
}

static float dot(@Nonnull List<Float> x, @Nonnull List<Float> y) {
Objects.requireNonNull(x);
Objects.requireNonNull(y);

if (x.size() != y.size()) {
throw new IllegalArgumentException("Vectors lengths must be equal");
}

float result = 0;
for (int i = 0; i < x.size(); ++i) {
result += x.get(i) * y.get(i);
}

return result;
}

/**
* Calculates the Euclidean length of a vector.
* @param vector Vector to calculate the length of, which is not modified
* @return The Euclidean length of the vector
*/
static float euclideanLength(@Nonnull List<Float> vector) {
Objects.requireNonNull(vector);
return (float) Math.sqrt(dot(vector, vector));
}

/**
* Multiplies the elements of the vector by the multiplier.
* @param vector Vector to multiply, which is not modified
* @param multiplier Multiplier to apply to each element of the vector
* @return A new vector with the elements multiplied by the multiplier
*/
static List<Float> multiply(@Nonnull List<Float> vector, float multiplier) {
Objects.requireNonNull(vector);
if (Float.isNaN(multiplier)) {
throw new IllegalArgumentException("Multiplier cannot be NaN");
}
if (Float.isInfinite(multiplier)) {
throw new IllegalArgumentException("Multiplier cannot be infinite");
}

return vector.stream()
.map(x -> x * multiplier)
.collect(Collectors.toList());
}

/**
* Normalizes the vector such that the Euclidean length is 1.
*
* @param vector Vector to normalize, which is not modified
* @return A new, normalized vector
*/
static List<Float> normalize(@Nonnull List<Float> vector) {
Objects.requireNonNull(vector);
return divide(vector, euclideanLength(vector));
}

}
Loading
Loading