|
15 | 15 |
|
16 | 16 |
|
17 | 17 | @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) |
18 | | -def get_embedding(text: str, engine="text-similarity-davinci-001") -> List[float]: |
| 18 | +def get_embedding(text: str, engine="text-similarity-davinci-001", **kwargs) -> List[float]: |
19 | 19 |
|
20 | 20 | # replace newlines, which can negatively affect performance. |
21 | 21 | text = text.replace("\n", " ") |
22 | 22 |
|
23 | | - return openai.Embedding.create(input=[text], engine=engine)["data"][0]["embedding"] |
| 23 | + return openai.Embedding.create(input=[text], engine=engine, **kwargs)["data"][0]["embedding"] |
24 | 24 |
|
25 | 25 |
|
26 | 26 | @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) |
27 | 27 | async def aget_embedding( |
28 | | - text: str, engine="text-similarity-davinci-001" |
| 28 | + text: str, engine="text-similarity-davinci-001", **kwargs |
29 | 29 | ) -> List[float]: |
30 | 30 |
|
31 | 31 | # replace newlines, which can negatively affect performance. |
32 | 32 | text = text.replace("\n", " ") |
33 | 33 |
|
34 | | - return (await openai.Embedding.acreate(input=[text], engine=engine))["data"][0][ |
| 34 | + return (await openai.Embedding.acreate(input=[text], engine=engine, **kwargs))["data"][0][ |
35 | 35 | "embedding" |
36 | 36 | ] |
37 | 37 |
|
38 | 38 |
|
39 | 39 | @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) |
40 | 40 | def get_embeddings( |
41 | | - list_of_text: List[str], engine="text-similarity-babbage-001" |
| 41 | + list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs |
42 | 42 | ) -> List[List[float]]: |
43 | 43 | assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048." |
44 | 44 |
|
45 | 45 | # replace newlines, which can negatively affect performance. |
46 | 46 | list_of_text = [text.replace("\n", " ") for text in list_of_text] |
47 | 47 |
|
48 | | - data = openai.Embedding.create(input=list_of_text, engine=engine).data |
| 48 | + data = openai.Embedding.create(input=list_of_text, engine=engine, **kwargs).data |
49 | 49 | data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. |
50 | 50 | return [d["embedding"] for d in data] |
51 | 51 |
|
52 | 52 |
|
53 | 53 | @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) |
54 | 54 | async def aget_embeddings( |
55 | | - list_of_text: List[str], engine="text-similarity-babbage-001" |
| 55 | + list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs |
56 | 56 | ) -> List[List[float]]: |
57 | 57 | assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048." |
58 | 58 |
|
59 | 59 | # replace newlines, which can negatively affect performance. |
60 | 60 | list_of_text = [text.replace("\n", " ") for text in list_of_text] |
61 | 61 |
|
62 | | - data = (await openai.Embedding.acreate(input=list_of_text, engine=engine)).data |
| 62 | + data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, **kwargs)).data |
63 | 63 | data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. |
64 | 64 | return [d["embedding"] for d in data] |
65 | 65 |
|
|
0 commit comments