diff --git a/src/python/paperai/embeddings.py b/src/python/paperai/embeddings.py index 727db1b..7a81f8e 100644 --- a/src/python/paperai/embeddings.py +++ b/src/python/paperai/embeddings.py @@ -128,9 +128,8 @@ def index(self, documents): documents: list of documents """ - ids = [] - embeddings = [] - stream = None + # Initialize local variables + ids, embeddings, dimensions, stream = [], None, None, None # Shared objects with Pool args = (self.config, self.scoring) @@ -140,26 +139,26 @@ def index(self, documents): with tempfile.NamedTemporaryFile(mode="wb", suffix=".npy", delete=False) as output: stream = output.name for uid, embedding in pool.imap(transform, documents): + if not dimensions: + # Set number of dimensions for embeddings + dimensions = embedding.shape[0] + ids.append(uid) pickle.dump(embedding, output) - # Load streamed embeddings + # Load streamed embeddings back to memory + embeddings = np.empty((len(ids), dimensions), dtype=np.float32) with open(stream, "rb") as stream: - while True: - try: - embeddings.append(pickle.load(stream)) - except EOFError: - break - - # Convert embeddings into a numpy array - embeddings = np.array(embeddings) + for x in range(embeddings.shape[0]): + embeddings[x] = pickle.load(stream) # Build LSA model (if enabled). Remove principal components from embeddings. - self.lsa = self.buildLSA(embeddings, self.config["pca"]) if self.config["pca"] else None - embeddings = self.removePC(embeddings) if self.lsa else embeddings + if self.config["pca"]: + self.lsa = self.buildLSA(embeddings, self.config["pca"]) + self.removePC(embeddings) # Normalize embeddings - embeddings = self.normalize(embeddings) + self.normalize(embeddings) # Create embeddings index. Inner product is equal to cosine similarity on normalized vectors. # pylint: disable=E1136 @@ -189,40 +188,35 @@ def buildLSA(self, embeddings, components): def removePC(self, embeddings): """ - Applies a LSA model to embeddings, removed the top n principal components. 
+ Applies an LSA model to embeddings, removing the top n principal components. Operation applied + directly on array. Args: embeddings: input embeddings matrix - - Returns: - embeddings with the principal component(s) removed """ pc = self.lsa.components_ + # Apply LSA model # Calculation is different if n_components = 1 if pc.shape[0] == 1: - return embeddings - embeddings.dot(pc.transpose()) * pc - - # Apply LSA model - return embeddings - embeddings.dot(pc.transpose()).dot(pc) + embeddings -= embeddings.dot(pc.transpose()) * pc + else: + embeddings -= embeddings.dot(pc.transpose()).dot(pc) def normalize(self, embeddings): """ - Normalizes embeddings using L2 normalization. + Normalizes embeddings using L2 normalization. Operation applied directly on array. Args: embeddings: input embeddings matrix - - Returns: - normalized embeddings """ # Calculation is different for matrices vs vectors if len(embeddings.shape) > 1: - return embeddings / np.linalg.norm(embeddings, axis=1).reshape(-1, 1) - - return embeddings / np.linalg.norm(embeddings) + embeddings /= np.linalg.norm(embeddings, axis=1)[:, np.newaxis] + else: + embeddings /= np.linalg.norm(embeddings) def transform(self, document): """ @@ -248,10 +242,14 @@ def transform(self, document): # Reduce the dimensionality of the embeddings. Scale the embeddings using this # model to reduce the noise of common but less relevant terms. - embedding = self.removePC(embedding) if self.lsa else embedding + if self.lsa: + self.removePC(embedding) # Normalize vector if embeddings index exists, normalization is skipped during index builds - return self.normalize(embedding) if self.embeddings else embedding + if self.embeddings: + self.normalize(embedding) + + return embedding def lookup(self, tokens): """