diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py
index c10ffc6040bb0..3d642f3000e1c 100644
--- a/langchain/embeddings/openai.py
+++ b/langchain/embeddings/openai.py
@@ -153,73 +153,64 @@ def _get_len_safe_embeddings(
         embeddings: List[List[float]] = [[] for _ in range(len(texts))]
         try:
             import tiktoken
-
-            tokens = []
-            indices = []
-            encoding = tiktoken.model.encoding_for_model(self.model)
-            for i, text in enumerate(texts):
-                # replace newlines, which can negatively affect performance.
-                text = text.replace("\n", " ")
-                token = encoding.encode(
-                    text,
-                    allowed_special=self.allowed_special,
-                    disallowed_special=self.disallowed_special,
-                )
-                for j in range(0, len(token), self.embedding_ctx_length):
-                    tokens += [token[j : j + self.embedding_ctx_length]]
-                    indices += [i]
-
-            batched_embeddings = []
-            _chunk_size = chunk_size or self.chunk_size
-            for i in range(0, len(tokens), _chunk_size):
-                response = embed_with_retry(
-                    self,
-                    input=tokens[i : i + _chunk_size],
-                    engine=self.deployment,
-                )
-                batched_embeddings += [r["embedding"] for r in response["data"]]
-
-            results: List[List[List[float]]] = [[] for _ in range(len(texts))]
-            num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
-            for i in range(len(indices)):
-                results[indices[i]].append(batched_embeddings[i])
-                num_tokens_in_batch[indices[i]].append(len(tokens[i]))
-
-            for i in range(len(texts)):
-                _result = results[i]
-                if len(_result) == 0:
-                    average = embed_with_retry(self, input="", engine=self.deployment)[
-                        "data"
-                    ][0]["embedding"]
-                else:
-                    average = np.average(
-                        _result, axis=0, weights=num_tokens_in_batch[i]
-                    )
-                embeddings[i] = (average / np.linalg.norm(average)).tolist()
-
-            return embeddings
-
         except ImportError:
             raise ValueError(
                 "Could not import tiktoken python package. "
-                "This is needed in order to for OpenAIEmbeddings. "
+                "This is needed for OpenAIEmbeddings. "
                 "Please install it with `pip install tiktoken`."
             )
 
-    def _embedding_func(self, text: str, *, engine: str) -> List[float]:
-        """Call out to OpenAI's embedding endpoint."""
-        # handle large input text
-        if len(text) > self.embedding_ctx_length:
-            return self._get_len_safe_embeddings([text], engine=engine)[0]
-        else:
+        tokens = []
+        indices = []
+        encoding = tiktoken.encoding_for_model(self.model)
+        for i, text in enumerate(texts):
             # replace newlines, which can negatively affect performance.
             text = text.replace("\n", " ")
-            return embed_with_retry(self, input=[text], engine=engine)["data"][0][
-                "embedding"
-            ]
+            token = encoding.encode(
+                text,
+                allowed_special=self.allowed_special,
+                disallowed_special=self.disallowed_special,
+            )
+            for j in range(0, len(token), self.embedding_ctx_length):
+                tokens += [token[j : j + self.embedding_ctx_length]]
+                indices += [i]
+
+        batched_embeddings = []
+        _chunk_size = chunk_size or self.chunk_size
+        for i in range(0, len(tokens), _chunk_size):
+            response = embed_with_retry(
+                self,
+                input=tokens[i : i + _chunk_size],
+                engine=self.deployment,
+            )
+            batched_embeddings += [r["embedding"] for r in response["data"]]
+
+        results: List[List[List[float]]] = [[] for _ in range(len(texts))]
+        num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
+        for i in range(len(indices)):
+            results[indices[i]].append(batched_embeddings[i])
+            num_tokens_in_batch[indices[i]].append(len(tokens[i]))
+
+        for i in range(len(texts)):
+            _result = results[i]
+            if len(_result) == 0:
+                average = embed_with_retry(self, input="", engine=self.deployment)[
+                    "data"
+                ][0]["embedding"]
+            else:
+                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
+            embeddings[i] = (average / np.linalg.norm(average)).tolist()
+
+        return embeddings
+
+    def _embedding_func(self, text: str, *, engine: str) -> List[float]:
+        """Call out to OpenAI's embedding endpoint."""
+        # NOTE: to keep things simple, we assume the list may contain texts longer
+        #       than the maximum context and use length-safe embedding function.
+        return self._get_len_safe_embeddings([text], engine=engine)[0]
 
     def embed_documents(
-        self, texts: List[str], chunk_size: Optional[int] = 0
+        self, texts: List[str], chunk_size: Optional[int] = None
    ) -> List[List[float]]:
         """Call out to OpenAI's embedding endpoint for embedding search docs.
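For context on what the now-default length-safe path does, here is a minimal, self-contained sketch of the same strategy: tokenize, slice anything longer than the context window, embed the slices in batches, then combine per-text via a token-count-weighted average and L2-normalize. The names `len_safe_embed` and `fake_embed` are illustrative, not part of the library; `fake_embed` is a hypothetical stand-in for the API call (the real code goes through `embed_with_retry`).

```python
import numpy as np
import tiktoken


def fake_embed(token_chunks: list[list[int]]) -> list[list[float]]:
    """Hypothetical stand-in for the embedding API call; returns
    deterministic vectors so the sketch runs offline."""
    rng = np.random.default_rng(0)
    return [rng.normal(size=8).tolist() for _ in token_chunks]


def len_safe_embed(
    texts: list[str],
    model: str = "text-embedding-ada-002",
    ctx_length: int = 8191,
    chunk_size: int = 1000,
) -> list[list[float]]:
    encoding = tiktoken.encoding_for_model(model)

    # 1. Tokenize each text and split anything longer than the context window
    #    into ctx_length-sized slices, remembering which text each slice came from.
    tokens: list[list[int]] = []
    indices: list[int] = []
    for i, text in enumerate(texts):
        token = encoding.encode(text.replace("\n", " "))
        for j in range(0, len(token), ctx_length):
            tokens.append(token[j : j + ctx_length])
            indices.append(i)

    # 2. Embed the slices in batches of at most chunk_size inputs per request.
    batched: list[list[float]] = []
    for i in range(0, len(tokens), chunk_size):
        batched.extend(fake_embed(tokens[i : i + chunk_size]))

    # 3. Regroup slice embeddings by source text, average them weighted by
    #    slice length (longer slices count more), and L2-normalize the result.
    results: list[list[list[float]]] = [[] for _ in texts]
    weights: list[list[int]] = [[] for _ in texts]
    for emb, idx, tok in zip(batched, indices, tokens):
        results[idx].append(emb)
        weights[idx].append(len(tok))

    embeddings = []
    for res, w in zip(results, weights):
        average = np.average(res, axis=0, weights=w)
        embeddings.append((average / np.linalg.norm(average)).tolist())
    return embeddings


# The long text is split into several slices, embedded, and averaged back
# into a single unit-norm vector per input.
vectors = len_safe_embed(["short text", "long document " * 20_000])
print([len(v) for v in vectors])  # -> [8, 8]
```

One edge case the sketch omits: a text that tokenizes to zero tokens produces no slices and would make the weighted average fail, which is why the diff's version falls back to embedding the empty string through the API for that text.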