Skip to content

Commit

Permalink
refac
Browse files Browse the repository at this point in the history
  • Loading branch information
kramstrom committed Oct 8, 2024
1 parent 5b9a3a5 commit ae1d8f7
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 14 deletions.
13 changes: 1 addition & 12 deletions modal/_utils/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@
AsyncIterator,
Awaitable,
Callable,
Iterable,
Iterator,
List,
Optional,
Set,
TypeVar,
Union,
cast,
)

Expand Down Expand Up @@ -482,17 +480,8 @@ def run_generator_sync(

@asynccontextmanager
async def aclosing(
gen: Union[AsyncGenerator[T, None], Iterable[T]],
agen: AsyncGenerator[T, None],
) -> AsyncGenerator[AsyncGenerator[T, None], None]:
def iter_to_agen(gen: Iterable[T]) -> AsyncGenerator[T, None]:
async def agen():
for item in gen:
yield item

return agen()

agen = iter_to_agen(gen) if not isinstance(gen, AsyncGenerator) else gen

try:
yield agen
finally:
Expand Down
8 changes: 6 additions & 2 deletions modal/parallel_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,12 @@ async def _starmap_async(
async def feed_queue():
# This runs in a main thread event loop, so it doesn't block the synchronizer loop

async with aclosing(input_iterator) as stream:
async for args in stream:
if isinstance(input_iterator, typing.AsyncIterable):
async with aclosing(input_iterator) as stream:
async for args in stream:
await raw_input_queue.put.aio((args, kwargs))
else:
for args in input_iterator:
await raw_input_queue.put.aio((args, kwargs))

await raw_input_queue.put.aio(None) # end-of-input sentinel
Expand Down
194 changes: 194 additions & 0 deletions self_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import asyncio
import random
from contextlib import asynccontextmanager


async def async_generator(count, multiplier=1):
for i in range(count):
# sleep = random.random()
# print(f"async_generator: {i=} {sleep=}\n")
await asyncio.sleep(0.1)

res = (i + 1) * multiplier
yield res


async def async_generator_error(count, multiplier=1):
for i in range(count):
# sleep = random.random()
# print(f"async_generator: {i=} {sleep=}\n")
await asyncio.sleep(0.1)

res = (i + 1) * multiplier
if res == 30:
print("raising exception")
raise Exception("res == 30")
yield res


def func(x):
return f"func({x})"


async def async_func(x):
await asyncio.sleep(random.random())
return f"async_func({x})"


async def async_func_error(x):
if x == 5:
raise Exception("x == 5")
await asyncio.sleep(random.random())
return f"async_func({x})"


@asynccontextmanager
async def aclosing(agen):
try:
yield agen
finally:
await agen.aclose()


async def async_map(input, callable, concurrency=1):
input_queue = asyncio.Queue(maxsize=concurrency)
results_queue = asyncio.Queue()

new_result_event = asyncio.Event()

async def producer():
async for item in input:
await input_queue.put(item)

# as long as there are inputs
#

async def worker():
while True:
try:
item = await input_queue.get()

# check if callable is async
if asyncio.iscoroutinefunction(callable):
result = await callable(item)
else:
result = callable(item)

# result = await callable(item)
await results_queue.put(result)
new_result_event.set()
# input_queue.task_done()
except Exception as e:
await results_queue.put(e)
new_result_event.set()
finally:
input_queue.task_done()

producer_task = asyncio.create_task(producer())
worker_tasks = [asyncio.create_task(worker()) for _ in range(concurrency)]

wait_for_results_task = asyncio.create_task(new_result_event.wait())

async def complete_map():
await producer_task
await input_queue.join()

complete_map_task = asyncio.create_task(complete_map())

try:
while True:
await asyncio.wait(
[complete_map_task, producer_task, *worker_tasks, wait_for_results_task],
return_when=asyncio.FIRST_COMPLETED,
)

if complete_map_task.done():
while not results_queue.empty():
yield await results_queue.get()
break

if new_result_event.is_set():
while not results_queue.empty():
result = await results_queue.get()
if isinstance(result, Exception):
raise result
yield result
new_result_event.clear()

finally:
for task in [producer_task, complete_map_task, *worker_tasks]:
task.cancel()
await asyncio.gather(producer_task, complete_map_task, *worker_tasks, return_exceptions=True)


async def async_merge(input, *more_inputs):
queue = asyncio.Queue()
inputs = [input] + list(more_inputs)

async def producer(iterator):
async for item in iterator:
await queue.put(item)

tasks = [asyncio.create_task(producer(it)) for it in inputs]

async def complete_merge():
for task in tasks:
await task
await queue.join()

complete_merge_task = asyncio.create_task(complete_merge())

try:
while True:
await asyncio.wait([complete_merge_task, *tasks], return_when=asyncio.FIRST_COMPLETED)
if complete_merge_task.done():
break

while not queue.empty():
item = await queue.get()
# if isinstance(item, Exception):
# print("raising exception when getting item")
# raise item
yield item
queue.task_done()

finally:
for task in [complete_merge_task, *tasks]:
task.cancel()
await asyncio.gather(complete_merge_task, *tasks, return_exceptions=False)


async def main():
# start = time.time()
# res = [i async for i in async_map(async_generator(10), async_func, concurrency=1)]
# end = time.time()
# print(f"time: {end - start}\n")

async with aclosing(
async_merge(async_generator(5, 1), async_generator_error(5, 10), async_generator(5, 100))
) as stream:
async for i in stream:
print(i)

# import time
# start = time.time()
# res = []
# async with aclosing(async_map(async_generator(10), async_func_error, concurrency=10)) as stream:
# async for i in stream:
# res.append(i)

# # import aiostream
# # async with aiostream.stream.map(async_generator(10), async_func_error, task_limit=10).stream() as stream:
# # async for i in stream:
# # res.append(i)
# end = time.time()
# print(f"time: {end - start}\n")

# for r in res:
# print(r)


if __name__ == "__main__":
# print("Hello World")

asyncio.run(main())

0 comments on commit ae1d8f7

Please sign in to comment.