 @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-def get_embedding(text: str, engine="text-similarity-davinci-001") -> List[float]:
+def get_embedding(text: str, engine="text-similarity-davinci-001", **kwargs) -> List[float]:

     # replace newlines, which can negatively affect performance.
     text = text.replace("\n", " ")

-    return openai.Embedding.create(input=[text], engine=engine)["data"][0]["embedding"]
+    return openai.Embedding.create(input=[text], engine=engine, **kwargs)["data"][0]["embedding"]


 @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 async def aget_embedding(
-    text: str, engine="text-similarity-davinci-001"
+    text: str, engine="text-similarity-davinci-001", **kwargs
 ) -> List[float]:

     # replace newlines, which can negatively affect performance.
     text = text.replace("\n", " ")

-    return (await openai.Embedding.acreate(input=[text], engine=engine))["data"][0][
+    return (await openai.Embedding.acreate(input=[text], engine=engine, **kwargs))["data"][0][
         "embedding"
     ]


 @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 def get_embeddings(
-    list_of_text: List[str], engine="text-similarity-babbage-001"
+    list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs
 ) -> List[List[float]]:
     assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

     # replace newlines, which can negatively affect performance.
     list_of_text = [text.replace("\n", " ") for text in list_of_text]

-    data = openai.Embedding.create(input=list_of_text, engine=engine).data
+    data = openai.Embedding.create(input=list_of_text, engine=engine, **kwargs).data
     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
     return [d["embedding"] for d in data]


 @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 async def aget_embeddings(
-    list_of_text: List[str], engine="text-similarity-babbage-001"
+    list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs
 ) -> List[List[float]]:
     assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

     # replace newlines, which can negatively affect performance.
     list_of_text = [text.replace("\n", " ") for text in list_of_text]

-    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine)).data
+    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, **kwargs)).data
     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
     return [d["embedding"] for d in data]
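For reference, a minimal usage sketch of the new pass-through behaviour: any extra keyword arguments given to these helpers are forwarded unchanged to openai.Embedding.create / acreate. The sketch assumes the pre-1.0 openai Python package (where these helpers live in openai.embeddings_utils, which may require optional extras to import), a valid API key, and that request_timeout and user are request options the endpoint accepts; all values are placeholders.

import openai
from openai.embeddings_utils import get_embedding, get_embeddings

openai.api_key = "sk-..."  # placeholder; normally read from the OPENAI_API_KEY environment variable

# Keyword arguments beyond `engine` now flow straight through to openai.Embedding.create.
vector = get_embedding(
    "hello world",
    engine="text-similarity-davinci-001",
    request_timeout=10,   # assumed pass-through option: per-request timeout in seconds
    user="example-user",  # assumed pass-through option: end-user identifier
)

vectors = get_embeddings(
    ["first document", "second document"],
    engine="text-similarity-babbage-001",
    request_timeout=10,
)

print(len(vector), len(vectors[0]))  # each result is a list of floats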