Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions dspy/utils/asyncify.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import TYPE_CHECKING, Any, Awaitable, Callable
from typing import TYPE_CHECKING, Awaitable, Callable, ParamSpec, TypeVar, Union, overload

import asyncer
from anyio import CapacityLimiter

if TYPE_CHECKING:
from dspy.primitives.module import Module

P = ParamSpec("P")
T = TypeVar("T")

_limiter = None


Expand All @@ -27,22 +30,30 @@ def get_limiter():
return _limiter


def asyncify(program: "Module") -> Callable[[Any, Any], Awaitable[Any]]:
@overload
def asyncify(program: Callable[P, T]) -> Callable[P, Awaitable[T]]: ...


@overload
def asyncify(program: "Module") -> Callable[..., Awaitable[T]]: ...
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot how can we know T in this case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! T was unbound in the Module overload. Since Module doesn't have a generic return type, we can't statically determine T. Fixed to use Any instead: Callable[..., Awaitable[Any]].

For typed callables, the signature is preserved: Callable[P, T]Callable[P, Awaitable[T]].
For Modules, we use Any since the return type varies by implementation.

Changes in commit be2bd89.



def asyncify(program: Union[Callable[P, T], "Module"]) -> Callable[P, Awaitable[T]] | Callable[..., Awaitable[T]]:
"""
Wraps a DSPy program so that it can be called asynchronously. This is useful for running a
Wraps a DSPy program or callable so that it can be called asynchronously. This is useful for running a
program in parallel with another task (e.g., another DSPy program).

This implementation propagates the current thread's configuration context to the worker thread.

Args:
program: The DSPy program to be wrapped for asynchronous execution.
program: The DSPy program or callable to be wrapped for asynchronous execution.

Returns:
An async function: An async function that, when awaited, runs the program in a worker thread.
The current thread's configuration context is inherited for each call.
"""

async def async_program(*args, **kwargs) -> Any:
async def async_program(*args: P.args, **kwargs: P.kwargs) -> T:
# Capture the current overrides at call-time.
from dspy.dsp.utils.settings import thread_local_overrides

Expand All @@ -62,4 +73,4 @@ def wrapped_program(*a, **kw):
call_async = asyncer.asyncify(wrapped_program, abandon_on_cancel=True, limiter=get_limiter())
return await call_async(*args, **kwargs)

return async_program
return async_program # type: ignore[return-value]
21 changes: 21 additions & 0 deletions tests/utils/test_asyncify.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,24 @@ async def verify_asyncify(capacity: int, number_of_tasks: int, wait: float = 0.5
await verify_asyncify(4, 10)
await verify_asyncify(8, 15)
await verify_asyncify(8, 30)


@pytest.mark.anyio
async def test_asyncify_with_dspy_module():
"""Test that asyncify works with DSPy modules and can be type-checked."""

class SimpleModule(dspy.Module):
def forward(self, x: int) -> int:
return x * 2

module = SimpleModule()
async_module = dspy.asyncify(module)

# Test with positional argument
result = await async_module(5)
assert result == 10, "Asyncified module should return correct result"

# Test with keyword argument
result = await async_module(x=7)
assert result == 14, "Asyncified module should work with keyword arguments"