|
1 | 1 | import asyncio
|
| 2 | +import mock |
2 | 3 | import os
|
3 | 4 | import sys
|
4 | 5 | import typing
|
|
9 | 10 | from hypothesis import given
|
10 | 11 |
|
11 | 12 | from flytekit import dynamic, task, workflow
|
| 13 | + |
| 14 | +from flytekit.bin.entrypoint import _get_working_loop, _dispatch_execute |
| 15 | +from flytekit.core import context_manager |
| 16 | +from flytekit.core.promise import VoidPromise |
12 | 17 | from flytekit.exceptions.user import FlyteValidationException
|
13 | 18 | from flytekit.experimental import EagerException, eager
|
| 19 | +from flytekit.models import literals as _literal_models |
14 | 20 | from flytekit.types.directory import FlyteDirectory
|
15 | 21 | from flytekit.types.file import FlyteFile
|
16 | 22 | from flytekit.types.structured import StructuredDataset
|
@@ -275,3 +281,26 @@ async def eager_wf_flyte_directory() -> str:
|
275 | 281 |
|
276 | 282 | result = asyncio.run(eager_wf_flyte_directory())
|
277 | 283 | assert result == "some data"
|
| 284 | + |
| 285 | + |
| 286 | +@mock.patch("flytekit.core.utils.load_proto_from_file") |
| 287 | +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") |
| 288 | +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") |
| 289 | +@mock.patch("flytekit.core.utils.write_proto_to_file") |
| 290 | +def test_eager_workflow_dispatch(mock_write_to_file, mock_put_data, mock_get_data, mock_load_proto, event_loop): |
| 291 | + """Test that event loop is preserved after executing eager workflow via dispatch.""" |
| 292 | + |
| 293 | + @eager |
| 294 | + async def eager_wf(): |
| 295 | + await asyncio.sleep(0.1) |
| 296 | + return |
| 297 | + |
| 298 | + ctx = context_manager.FlyteContext.current_context() |
| 299 | + with context_manager.FlyteContextManager.with_context( |
| 300 | + ctx.with_execution_state( |
| 301 | + ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) |
| 302 | + ) |
| 303 | + ) as ctx: |
| 304 | + _dispatch_execute(ctx, lambda: eager_wf, "inputs path", "outputs prefix") |
| 305 | + loop_after_execute = asyncio.get_event_loop_policy().get_event_loop() |
| 306 | + assert event_loop == loop_after_execute |
0 commit comments