Skip to content

Commit

Permalink
Make FlyteFile compatible with Annotated[..., HashMethod] (#1544)
Browse files Browse the repository at this point in the history
* fix: Make FlyteFile compatible with Annotated[..., HashMethod]

See issue #3424

Signed-off-by: Adrian Rumpold <[email protected]>

* tests: Add test case for FlyteFile with HashMethod annotation

Issue: #3424
Signed-off-by: Adrian Rumpold <[email protected]>

* fix: Use typing_extensions.Annotated for py3.8 compatibility

Issue: #3424
Signed-off-by: Adrian Rumpold <[email protected]>

* fix: Use `get_args` and `get_origin` from typing_extensions for py3.8 compatibility

Issue: #3424
Signed-off-by: Adrian Rumpold <[email protected]>

* fix(tests): Use fixture for local dummy file

Signed-off-by: Adrian Rumpold <[email protected]>

---------

Signed-off-by: Adrian Rumpold <[email protected]>
  • Loading branch information
AdrianoKF authored Mar 21, 2023
1 parent 98e74c2 commit a190431
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
5 changes: 5 additions & 0 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from dataclasses_json import config, dataclass_json
from marshmallow import fields
from typing_extensions import Annotated, get_args, get_origin

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
Expand Down Expand Up @@ -335,6 +336,10 @@ def to_literal(
if python_val is None:
raise TypeTransformerFailedError("None value cannot be converted to a file.")

# Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type
if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]

if not (python_type is os.PathLike or issubclass(python_type, FlyteFile)):
raise ValueError(f"Incorrect type {python_type}, must be either a FlyteFile or os.PathLike")

Expand Down
17 changes: 17 additions & 0 deletions tests/flytekit/unit/core/test_flyte_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from unittest.mock import MagicMock

import pytest
from typing_extensions import Annotated

import flytekit.configuration
from flytekit.configuration import Config, Image, ImageConfig
from flytekit.core.context_manager import ExecutionState, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider, flyte_tmp_dir
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.hash import HashMethod
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.task import task
from flytekit.core.type_engine import TypeEngine
Expand Down Expand Up @@ -433,6 +435,21 @@ def wf(path: str) -> os.PathLike:
assert flyte_tmp_dir in wf(path="s3://somewhere").path


def test_flyte_file_annotated_hashmethod(local_dummy_file):
def calc_hash(ff: FlyteFile) -> str:
return str(ff.path)

@task
def t1(path: str) -> Annotated[FlyteFile, HashMethod(calc_hash)]:
return FlyteFile(path)

@workflow
def wf(path: str) -> None:
t1(path=path)

wf(path=local_dummy_file)


@pytest.mark.sandbox_test
def test_file_open_things():
@task
Expand Down

0 comments on commit a190431

Please sign in to comment.