76 changes: 61 additions & 15 deletions daft/functions/llm.py
@@ -10,7 +10,7 @@ def llm_generate(
     model: str = "facebook/opt-125m",
     provider: Literal["vllm", "openai"] = "vllm",
     concurrency: int = 1,
-    batch_size: int = 1024,
+    batch_size: int | None = None,
     num_cpus: int | None = None,
     num_gpus: int | None = None,
     **generation_config: dict[str, Any],
@@ -27,8 +27,8 @@ def llm_generate(
             The LLM provider to use for generation. Supported values: "vllm", "openai"
         concurrency: int, default=1
             The number of concurrent instances of the model to run
-        batch_size: int, default=1024
-            The batch size for the UDF
+        batch_size: int | None, default=None
+            The batch size for the UDF. If None, a provider-specific default is used (1024 for vLLM, 128 for OpenAI).
         num_cpus: float, default=None
             The number of CPUs to use for the UDF
         num_gpus: float, default=None
@@ -40,14 +40,35 @@ def llm_generate(
     Use vLLM provider:
     >>> import daft
     >>> from daft import col
-    >>> from daft.functions import llm_generate
-    >>> df = daft.read_csv("prompts.csv")
-    >>> df = df.with_column("response", llm_generate(col("prompt"), model="facebook/opt-125m"))
+    >>> from daft.functions import llm_generate, format
+    >>>
+    >>> df = daft.from_pydict({"city": ["Paris", "Tokyo", "New York"]})
+    >>> df = df.with_column(
+    ...     "description",
+    ...     llm_generate(
+    ...         format(
+    ...             "Describe the main attractions and unique features of this city: {}.",
+    ...             col("city"),
+    ...         ),
+    ...         model="facebook/opt-125m",
+    ...     ),
+    ... )
+    >>> df.collect()

     Use OpenAI provider:
-    >>> df = daft.read_csv("prompts.csv")
-    >>> df = df.with_column("response", llm_generate(col("prompt"), model="gpt-4o", api_key="xxx", provider="openai"))
+    >>> df = daft.from_pydict({"city": ["Paris", "Tokyo", "New York"]})
+    >>> df = df.with_column(
+    ...     "description",
+    ...     llm_generate(
+    ...         format(
+    ...             "Describe the main attractions and unique features of this city: {}.",
+    ...             col("city"),
+    ...         ),
+    ...         model="gpt-4o",
+    ...         api_key="xxx",
+    ...         provider="openai",
+    ...     ),
+    ... )
+    >>> df.collect()

     Note:
@@ -56,8 +77,12 @@ def llm_generate(
     cls: Any = None
     if provider == "vllm":
         cls = _vLLMGenerator
+        if batch_size is None:
+            batch_size = 1024
     elif provider == "openai":
         cls = _OpenAIGenerator
+        if batch_size is None:
+            batch_size = 128
     else:
         raise ValueError(f"Unsupported provider: {provider}")

@@ -103,8 +128,10 @@ def __init__(
         model: str = "gpt-4o",
         generation_config: dict[str, Any] = {},
     ) -> None:
+        import asyncio
+
         try:
-            from openai import OpenAI
+            from openai import AsyncOpenAI
         except ImportError:
             raise ImportError("Please install the openai package to use this provider.")
         self.model = model
Expand All @@ -113,17 +140,36 @@ def __init__(

         self.generation_config = {k: v for k, v in generation_config.items() if k not in client_params_keys}

-        self.llm = OpenAI(**client_params_opts)
+        self.llm = AsyncOpenAI(**client_params_opts)
+        try:
+            self.loop = asyncio.get_running_loop()
+        except RuntimeError:
+            self.loop = asyncio.new_event_loop()
Comment on lines +144 to +147

Contributor

logic: Event loop management is problematic. Creating a new event loop in __init__ without setting it as the current loop can cause issues. The loop may not be properly configured for the execution context.
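
One illustrative way to address this (a sketch only, not part of the PR; the helper name is hypothetical) is to register a newly created loop as the thread's current loop:

import asyncio


def _get_or_create_loop() -> asyncio.AbstractEventLoop:
    # Reuse the running loop when one exists; otherwise create one and
    # register it for this thread so later asyncio calls see the same loop.
    try:
        return asyncio.get_running_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop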


     def __call__(self, input_prompt_column: Series) -> list[str]:
-        prompts = input_prompt_column.to_pylist()
-        outputs = []
-        for prompt in prompts:
+        import asyncio
+
+        async def get_completion(prompt: str) -> str:
             messages = [{"role": "user", "content": prompt}]
-            completion = self.llm.chat.completions.create(
+            completion = await self.llm.chat.completions.create(
                 model=self.model,
                 messages=messages,
                 **self.generation_config,
             )
-            outputs.append(completion.choices[0].message.content)
+            return completion.choices[0].message.content
+
+        prompts = input_prompt_column.to_pylist()
+
+        async def gather_completions() -> list[str]:
+            tasks = [get_completion(prompt) for prompt in prompts]
+            return await asyncio.gather(*tasks)
+
+        try:
+            outputs = self.loop.run_until_complete(gather_completions())
Contributor

logic: Using run_until_complete() on a potentially running loop will raise RuntimeError if called from within an async context. This will break in environments like Jupyter notebooks or async UDF execution.

Contributor

Watch out for these and do some testing from within Jupyter? I did see some users complaining about needing to do some shenanigans with nest_asyncio to get our stuff working in a notebook.

Generally needing to manually handle self.loop feels kinda icky to me. Is it not possible to just grab the currently available loop?
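
As a sketch of the failure mode under discussion (illustrative only, assuming the UDF is entered from synchronous code; run_sync is a hypothetical helper, not in the PR):

import asyncio


def run_sync(coro):
    # If no loop is running in this thread, asyncio.run() creates one,
    # runs the coroutine to completion, and tears the loop down afterwards.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)
    # A loop is already running (e.g. in Jupyter): run_until_complete()
    # here would raise "This event loop is already running", which is why
    # notebooks often need nest_asyncio.apply() as a workaround.
    coro.close()  # avoid an "un-awaited coroutine" warning in this sketch
    raise RuntimeError("already inside an event loop; await the coroutine instead")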

Contributor Author

Yeah can confirm that it works in Jupyter (native + Ray).

self.loop stores either the currently available loop, or a newly created one if none exists.

+        except Exception as e:
+            import logging
+
+            logger = logging.getLogger(__name__)
+            logger.exception(e)
Contributor Author

Currently exceptions that are unserializable are not handled properly (#4881), so this is a workaround until it is fixed.
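
A sketch of why logging before re-raising helps here (illustrative, not from the PR; UnpicklableError is a made-up stand-in for the errors #4881 describes):

import logging
import pickle

logger = logging.getLogger(__name__)


class UnpicklableError(Exception):
    def __init__(self) -> None:
        super().__init__("request failed")
        self.resource = open("/dev/null")  # open file handles cannot be pickled


try:
    raise UnpicklableError()
except Exception as e:
    logger.exception(e)  # the traceback still lands in this process's log
    try:
        pickle.dumps(e)  # roughly what fails when the error crosses worker boundaries
    except TypeError as err:
        print(f"exception could not be serialized: {err}")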

+            raise e
         return outputs