feat: Async open ai llm generate #4879
@@ -10,7 +10,7 @@ def llm_generate(
     model: str = "facebook/opt-125m",
     provider: Literal["vllm", "openai"] = "vllm",
     concurrency: int = 1,
-    batch_size: int = 1024,
+    batch_size: int | None = None,
     num_cpus: int | None = None,
     num_gpus: int | None = None,
     **generation_config: dict[str, Any],
@@ -27,8 +27,8 @@ def llm_generate(
         The LLM provider to use for generation. Supported values: "vllm", "openai"
     concurrency: int, default=1
         The number of concurrent instances of the model to run
-    batch_size: int, default=1024
-        The batch size for the UDF
+    batch_size: int, default=None
+        The batch size for the UDF. If None, the batch size will be determined by defaults based on the provider.
     num_cpus: float, default=None
         The number of CPUs to use for the UDF
     num_gpus: float, default=None
@@ -40,14 +40,35 @@ def llm_generate(
     Use vLLM provider:
     >>> import daft
     >>> from daft import col
-    >>> from daft.functions import llm_generate
-    >>> df = daft.read_csv("prompts.csv")
-    >>> df = df.with_column("response", llm_generate(col("prompt"), model="facebook/opt-125m"))
+    >>> from daft.functions import llm_generate, format
+    >>>
+    >>> df = daft.from_pydict({"city": ["Paris", "Tokyo", "New York"]})
+    >>> df = df.with_column(
+    ...     "description",
+    ...     llm_generate(
+    ...         format(
+    ...             "Describe the main attractions and unique features of this city: {}.",
+    ...             col("city"),
+    ...         ),
+    ...         model="facebook/opt-125m",
+    ...     ),
+    ... )
+    >>> df.collect()

     Use OpenAI provider:
-    >>> df = daft.read_csv("prompts.csv")
-    >>> df = df.with_column("response", llm_generate(col("prompt"), model="gpt-4o", api_key="xxx", provider="openai"))
+    >>> df = daft.from_pydict({"city": ["Paris", "Tokyo", "New York"]})
+    >>> df = df.with_column(
+    ...     "description",
+    ...     llm_generate(
+    ...         format(
+    ...             "Describe the main attractions and unique features of this city: {}.",
+    ...             col("city"),
+    ...         ),
+    ...         model="gpt-4o",
+    ...         api_key="xxx",
+    ...         provider="openai",
+    ...     ),
+    ... )
+    >>> df.collect()

     Note:
@@ -56,8 +77,12 @@ def llm_generate(
     cls: Any = None
     if provider == "vllm":
         cls = _vLLMGenerator
+        if batch_size is None:
+            batch_size = 1024
     elif provider == "openai":
         cls = _OpenAIGenerator
+        if batch_size is None:
+            batch_size = 128
     else:
         raise ValueError(f"Unsupported provider: {provider}")
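For readers skimming the diff, the new behavior is a small resolution rule: an explicitly passed `batch_size` always wins, otherwise the provider picks its own default. A standalone sketch of that rule (the helper name is made up; the values 1024 and 128 are the ones in the branches above):

```python
# Illustrative helper, not part of the PR: mirrors the batch_size resolution above.
def resolve_batch_size(provider: str, batch_size: int | None) -> int:
    if batch_size is not None:
        return batch_size  # an explicitly passed value always wins
    if provider == "vllm":
        return 1024  # local vLLM engines tolerate large batches
    if provider == "openai":
        return 128  # keeps the number of in-flight API requests per batch bounded
    raise ValueError(f"Unsupported provider: {provider}")


assert resolve_batch_size("openai", None) == 128
assert resolve_batch_size("vllm", 256) == 256
```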
@@ -103,8 +128,10 @@ def __init__(
         model: str = "gpt-4o",
         generation_config: dict[str, Any] = {},
     ) -> None:
+        import asyncio
+
         try:
-            from openai import OpenAI
+            from openai import AsyncOpenAI
         except ImportError:
             raise ImportError("Please install the openai package to use this provider.")
         self.model = model
@@ -113,17 +140,36 @@ def __init__(

         self.generation_config = {k: v for k, v in generation_config.items() if k not in client_params_keys}

-        self.llm = OpenAI(**client_params_opts)
+        self.llm = AsyncOpenAI(**client_params_opts)
+        try:
+            self.loop = asyncio.get_running_loop()
+        except RuntimeError:
+            self.loop = asyncio.new_event_loop()

     def __call__(self, input_prompt_column: Series) -> list[str]:
-        prompts = input_prompt_column.to_pylist()
-        outputs = []
-        for prompt in prompts:
+        import asyncio
+
+        async def get_completion(prompt: str) -> str:
             messages = [{"role": "user", "content": prompt}]
-            completion = self.llm.chat.completions.create(
+            completion = await self.llm.chat.completions.create(
                 model=self.model,
                 messages=messages,
                 **self.generation_config,
             )
-            outputs.append(completion.choices[0].message.content)
+            return completion.choices[0].message.content
+
+        prompts = input_prompt_column.to_pylist()
+
+        async def gather_completions() -> list[str]:
+            tasks = [get_completion(prompt) for prompt in prompts]
+            return await asyncio.gather(*tasks)
+
+        try:
+            outputs = self.loop.run_until_complete(gather_completions())
Review thread on `outputs = self.loop.run_until_complete(gather_completions())`:

- logic: Using …
- Watch out for these and do some testing from within Jupyter? I did see some users complaining about needing to do some shenanigans with nested_async to get our stuff working in a notebook. Generally needing to manually handle …
- Yeah, can confirm that it works in Jupyter (native + Ray).
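For context on the Jupyter concern: `loop.run_until_complete` raises if the loop is already running, which is the normal state inside a notebook cell. A minimal sketch of one possible workaround, not part of this PR (`run_coro_blocking` is a hypothetical helper; applying `nest_asyncio` is the other common escape hatch):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


# Hypothetical helper: run a coroutine to completion from synchronous code,
# whether or not an event loop is already running on this thread.
def run_coro_blocking(coro):
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop (e.g. plain scripts): asyncio.run manages its own loop.
        return asyncio.run(coro)
    # A loop is already running (e.g. a Jupyter cell): re-entering it would raise,
    # so run the coroutine on a separate thread that gets a fresh loop.
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
```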
+        except Exception as e:
+            import logging
+
+            logger = logging.getLogger(__name__)
+            logger.exception(e)
Review comment on the exception handling above: Currently, exceptions that are unserializable are not handled properly (#4881), so this is a workaround until that is fixed.
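The PR logs the exception and re-raises it unchanged; a related pattern, shown here only as a hypothetical sketch while #4881 is open, is to also convert the error into a plain, always-picklable exception once the worker-side traceback has been logged:

```python
import logging

logger = logging.getLogger(__name__)


# Hypothetical variation on the workaround above, not the PR's code.
def reraise_picklable(e: Exception) -> None:
    logger.exception(e)  # keep the full traceback in the worker's log
    # A RuntimeError carrying a plain string message always serializes cleanly.
    raise RuntimeError(f"{type(e).__name__}: {e}") from None
```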
+            raise e
+
         return outputs

colin-ho marked this conversation as resolved.
Review comment: logic: Event loop management is problematic. Creating a new event loop in `__init__` without setting it as the current loop can cause issues. The loop may not be properly configured for the execution context.
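A minimal sketch of what addressing this could look like, under the assumption that the generator is constructed and called on the same worker thread (the class and method names below are made up): create the private loop, register it with `asyncio.set_event_loop` so that code calling `asyncio.get_event_loop()` sees it, and close it explicitly on teardown.

```python
import asyncio


# Hypothetical loop owner, not the PR's code.
class _PrivateLoop:
    def __init__(self) -> None:
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)  # make it the thread's current loop

    def run(self, coro):
        return self.loop.run_until_complete(coro)

    def close(self) -> None:
        self.loop.close()
        asyncio.set_event_loop(None)  # avoid leaving a closed loop registered
```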