Embeddings index memory performance improvements, closes #3
davidmezzetti committed Aug 4, 2020
1 parent cec0652 commit f9084ac
Showing 1 changed file with 30 additions and 32 deletions: src/python/paperai/embeddings.py
@@ -128,9 +128,8 @@ def index(self, documents):
             documents: list of documents
         """
 
-        ids = []
-        embeddings = []
-        stream = None
+        # Initialize local variables
+        ids, embeddings, dimensions, stream = [], None, None, None
 
         # Shared objects with Pool
         args = (self.config, self.scoring)
@@ -140,26 +139,26 @@ def index(self, documents):
             with tempfile.NamedTemporaryFile(mode="wb", suffix=".npy", delete=False) as output:
                 stream = output.name
                 for uid, embedding in pool.imap(transform, documents):
+                    if not dimensions:
+                        # Set number of dimensions for embeddings
+                        dimensions = embedding.shape[0]
+
                     ids.append(uid)
                     pickle.dump(embedding, output)
 
-        # Load streamed embeddings
+        # Load streamed embeddings back to memory
+        embeddings = np.empty((len(ids), dimensions), dtype=np.float32)
         with open(stream, "rb") as stream:
-            while True:
-                try:
-                    embeddings.append(pickle.load(stream))
-                except EOFError:
-                    break
-
-        # Convert embeddings into a numpy array
-        embeddings = np.array(embeddings)
+            for x in range(embeddings.shape[0]):
+                embeddings[x] = pickle.load(stream)
 
         # Build LSA model (if enabled). Remove principal components from embeddings.
-        self.lsa = self.buildLSA(embeddings, self.config["pca"]) if self.config["pca"] else None
-        embeddings = self.removePC(embeddings) if self.lsa else embeddings
+        if self.config["pca"]:
+            self.lsa = self.buildLSA(embeddings, self.config["pca"])
+            self.removePC(embeddings)
 
         # Normalize embeddings
-        embeddings = self.normalize(embeddings)
+        self.normalize(embeddings)
 
         # Create embeddings index. Inner product is equal to cosine similarity on normalized vectors.
         # pylint: disable=E1136
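
For context on the memory improvement above, here is a minimal standalone sketch of the pattern this hunk adopts: embeddings are pickled to a temporary file as they are produced, then read back into a preallocated numpy array instead of being accumulated in a Python list. The old list + np.array() path briefly held two copies of every vector; the new path holds a single copy of the final matrix. make_embedding and the document count are made-up stand-ins, not part of paperai.

    import pickle
    import tempfile

    import numpy as np

    def make_embedding(doc, dimensions=300):
        # Hypothetical stand-in for the pooled transform() call
        rng = np.random.default_rng(doc)
        return rng.random(dimensions, dtype=np.float32)

    ids, dimensions, stream = [], None, None

    # Stream embeddings to disk instead of holding them all in memory
    with tempfile.NamedTemporaryFile(mode="wb", suffix=".npy", delete=False) as output:
        stream = output.name
        for uid in range(1000):
            embedding = make_embedding(uid)
            if not dimensions:
                # Capture the vector size from the first embedding
                dimensions = embedding.shape[0]

            ids.append(uid)
            pickle.dump(embedding, output)

    # Preallocate the full matrix once, then fill it row by row
    embeddings = np.empty((len(ids), dimensions), dtype=np.float32)
    with open(stream, "rb") as infile:
        for x in range(embeddings.shape[0]):
            embeddings[x] = pickle.load(infile)
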
@@ -189,40 +188,35 @@ def buildLSA(self, embeddings, components):
 
     def removePC(self, embeddings):
         """
-        Applies an LSA model to embeddings, removing the top n principal components.
+        Applies an LSA model to embeddings, removing the top n principal components. Operation applied
+        directly on array.
 
         Args:
             embeddings: input embeddings matrix
-
-        Returns:
-            embeddings with the principal component(s) removed
         """
 
         pc = self.lsa.components_
 
+        # Apply LSA model
         # Calculation is different if n_components = 1
         if pc.shape[0] == 1:
-            return embeddings - embeddings.dot(pc.transpose()) * pc
-
-        # Apply LSA model
-        return embeddings - embeddings.dot(pc.transpose()).dot(pc)
+            embeddings -= embeddings.dot(pc.transpose()) * pc
+        else:
+            embeddings -= embeddings.dot(pc.transpose()).dot(pc)
 
     def normalize(self, embeddings):
         """
-        Normalizes embeddings using L2 normalization.
+        Normalizes embeddings using L2 normalization. Operation applied directly on array.
 
         Args:
            embeddings: input embeddings matrix
-
-        Returns:
-            normalized embeddings
         """
 
         # Calculation is different for matrices vs vectors
         if len(embeddings.shape) > 1:
-            return embeddings / np.linalg.norm(embeddings, axis=1).reshape(-1, 1)
-
-        return embeddings / np.linalg.norm(embeddings)
+            embeddings /= np.linalg.norm(embeddings, axis=1)[:, np.newaxis]
+        else:
+            embeddings /= np.linalg.norm(embeddings)
 
     def transform(self, document):
         """
@@ -248,10 +242,14 @@ def transform(self, document):
 
         # Reduce the dimensionality of the embeddings. Scale the embeddings using this
         # model to reduce the noise of common but less relevant terms.
-        embedding = self.removePC(embedding) if self.lsa else embedding
+        if self.lsa:
+            self.removePC(embedding)
 
         # Normalize vector if embeddings index exists, normalization is skipped during index builds
-        return self.normalize(embedding) if self.embeddings else embedding
+        if self.embeddings:
+            self.normalize(embedding)
+
+        return embedding
 
     def lookup(self, tokens):
         """
