@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
    text: str, engine: str = "text-similarity-davinci-001", **kwargs
) -> List[float]:
    """Return the embedding vector for a single string.

    Args:
        text: The text to embed. Newlines are replaced with spaces before
            the request, since they can negatively affect performance.
        engine: Name of the embedding engine/model to use.
        **kwargs: Extra keyword arguments forwarded verbatim to
            ``openai.Embedding.create`` (e.g. ``user``, request options).

    Returns:
        The embedding as a list of floats.
    """
    # Replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")

    return openai.Embedding.create(input=[text], engine=engine, **kwargs)["data"][0][
        "embedding"
    ]


@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(
    text: str, engine: str = "text-similarity-davinci-001", **kwargs
) -> List[float]:
    """Async variant of :func:`get_embedding`.

    Args:
        text: The text to embed. Newlines are replaced with spaces before
            the request, since they can negatively affect performance.
        engine: Name of the embedding engine/model to use.
        **kwargs: Extra keyword arguments forwarded verbatim to
            ``openai.Embedding.acreate``.

    Returns:
        The embedding as a list of floats.
    """
    # Replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")

    return (await openai.Embedding.acreate(input=[text], engine=engine, **kwargs))[
        "data"
    ][0]["embedding"]


@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
    list_of_text: List[str], engine: str = "text-similarity-babbage-001", **kwargs
) -> List[List[float]]:
    """Return embedding vectors for a batch of strings in input order.

    Args:
        list_of_text: Up to 2048 texts to embed (API batch limit). Newlines
            are replaced with spaces before the request.
        engine: Name of the embedding engine/model to use.
        **kwargs: Extra keyword arguments forwarded verbatim to
            ``openai.Embedding.create``.

    Returns:
        One embedding (list of floats) per input text, in the same order
        as ``list_of_text``.

    Raises:
        AssertionError: If more than 2048 texts are supplied.
    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # Replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = openai.Embedding.create(input=list_of_text, engine=engine, **kwargs).data
    # The API may return items out of order; re-sort by index to maintain
    # the same order as input.
    data = sorted(data, key=lambda x: x["index"])
    return [d["embedding"] for d in data]


@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
    list_of_text: List[str], engine: str = "text-similarity-babbage-001", **kwargs
) -> List[List[float]]:
    """Async variant of :func:`get_embeddings`.

    Args:
        list_of_text: Up to 2048 texts to embed (API batch limit). Newlines
            are replaced with spaces before the request.
        engine: Name of the embedding engine/model to use.
        **kwargs: Extra keyword arguments forwarded verbatim to
            ``openai.Embedding.acreate``.

    Returns:
        One embedding (list of floats) per input text, in the same order
        as ``list_of_text``.

    Raises:
        AssertionError: If more than 2048 texts are supplied.
    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # Replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = (
        await openai.Embedding.acreate(input=list_of_text, engine=engine, **kwargs)
    ).data
    # The API may return items out of order; re-sort by index to maintain
    # the same order as input.
    data = sorted(data, key=lambda x: x["index"])
    return [d["embedding"] for d in data]