Skip to content

Commit 6430edb

Browse files
cosmicBboyotarabai
authored andcommitted
eager workflow: use event loop instead of asyncio.run (flyteorg#2737)
Signed-off-by: Niels Bantilan <[email protected]>
1 parent 91ba4e1 commit 6430edb

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

flytekit/bin/entrypoint.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import sys
1010
import tempfile
1111
import traceback
12+
import warnings
1213
from sys import exit
1314
from typing import Callable, List, Optional
1415

@@ -70,6 +71,23 @@ def _compute_array_job_index():
7071
return offset
7172

7273

74+
def _get_working_loop():
75+
"""Returns a running event loop."""
76+
try:
77+
return asyncio.get_running_loop()
78+
except RuntimeError:
79+
with warnings.catch_warnings():
80+
warnings.simplefilter("error", DeprecationWarning)
81+
try:
82+
return asyncio.get_event_loop_policy().get_event_loop()
83+
# Since version 3.12, DeprecationWarning is emitted if there is no
84+
# current event loop.
85+
except DeprecationWarning:
86+
loop = asyncio.get_event_loop_policy().new_event_loop()
87+
asyncio.set_event_loop(loop)
88+
return loop
89+
90+
7391
def _dispatch_execute(
7492
ctx: FlyteContext,
7593
load_task: Callable[[], PythonTask],
@@ -109,7 +127,7 @@ def _dispatch_execute(
109127
if inspect.iscoroutine(outputs):
110128
# Handle eager-mode (async) tasks
111129
logger.info("Output is a coroutine")
112-
outputs = asyncio.run(outputs)
130+
outputs = _get_working_loop().run_until_complete(outputs)
113131

114132
# Step3a
115133
if isinstance(outputs, VoidPromise):

tests/flytekit/unit/experimental/test_eager_workflows.py

+29
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import mock
23
import os
34
import sys
45
import typing
@@ -9,8 +10,13 @@
910
from hypothesis import given
1011

1112
from flytekit import dynamic, task, workflow
13+
14+
from flytekit.bin.entrypoint import _get_working_loop, _dispatch_execute
15+
from flytekit.core import context_manager
16+
from flytekit.core.promise import VoidPromise
1217
from flytekit.exceptions.user import FlyteValidationException
1318
from flytekit.experimental import EagerException, eager
19+
from flytekit.models import literals as _literal_models
1420
from flytekit.types.directory import FlyteDirectory
1521
from flytekit.types.file import FlyteFile
1622
from flytekit.types.structured import StructuredDataset
@@ -275,3 +281,26 @@ async def eager_wf_flyte_directory() -> str:
275281

276282
result = asyncio.run(eager_wf_flyte_directory())
277283
assert result == "some data"
284+
285+
286+
@mock.patch("flytekit.core.utils.load_proto_from_file")
287+
@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data")
288+
@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data")
289+
@mock.patch("flytekit.core.utils.write_proto_to_file")
290+
def test_eager_workflow_dispatch(mock_write_to_file, mock_put_data, mock_get_data, mock_load_proto, event_loop):
291+
"""Test that event loop is preserved after executing eager workflow via dispatch."""
292+
293+
@eager
294+
async def eager_wf():
295+
await asyncio.sleep(0.1)
296+
return
297+
298+
ctx = context_manager.FlyteContext.current_context()
299+
with context_manager.FlyteContextManager.with_context(
300+
ctx.with_execution_state(
301+
ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
302+
)
303+
) as ctx:
304+
_dispatch_execute(ctx, lambda: eager_wf, "inputs path", "outputs prefix")
305+
loop_after_execute = asyncio.get_event_loop_policy().get_event_loop()
306+
assert event_loop == loop_after_execute

0 commit comments

Comments
 (0)