Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Permit superclass to subclass lazy typing #696

Merged
merged 13 commits into from
Feb 27, 2024
Merged
4 changes: 3 additions & 1 deletion pydra/engine/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@ def make_klass(spec):
**kwargs,
)
checker_label = f"'{name}' field of {spec.name}"
type_checker = TypeParser[newfield.type](newfield.type, label=checker_label)
type_checker = TypeParser[newfield.type](
newfield.type, label=checker_label, superclass_auto_cast=True
)
if newfield.type in (MultiInputObj, MultiInputFile):
converter = attr.converters.pipe(ensure_list, type_checker)
elif newfield.type in (MultiOutputObj, MultiOutputFile):
Expand Down
2 changes: 1 addition & 1 deletion pydra/engine/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def collect_additional_outputs(self, inputs, output_dir, outputs):
),
):
raise TypeError(
f"Support for {fld.type} type, required for {fld.name} in {self}, "
f"Support for {fld.type} type, required for '{fld.name}' in {self}, "
"has not been implemented in collect_additional_output"
)
# assuming that field should have either default or metadata, but not both
Expand Down
16 changes: 1 addition & 15 deletions pydra/engine/tests/test_node_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,21 +133,7 @@ def test_task_init_3a(


def test_task_init_4():
"""task with interface and inputs. splitter set using split method"""
nn = fun_addtwo(name="NA")
nn.split(splitter="a", a=[3, 5])
assert np.allclose(nn.inputs.a, [3, 5])

assert nn.state.splitter == "NA.a"
assert nn.state.splitter_rpn == ["NA.a"]

nn.state.prepare_states(nn.inputs)
assert nn.state.states_ind == [{"NA.a": 0}, {"NA.a": 1}]
assert nn.state.states_val == [{"NA.a": 3}, {"NA.a": 5}]


def test_task_init_4a():
"""task with a splitter and inputs set in the split method"""
"""task with interface splitter and inputs set in the split method"""
nn = fun_addtwo(name="NA")
nn.split(splitter="a", a=[3, 5])
assert np.allclose(nn.inputs.a, [3, 5])
Expand Down
160 changes: 148 additions & 12 deletions pydra/utils/tests/test_typing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import itertools
import sys
import typing as ty
from pathlib import Path
import tempfile
Expand All @@ -8,13 +9,16 @@
from ...engine.specs import File, LazyOutField
from ..typing import TypeParser
from pydra import Workflow
from fileformats.application import Json
from fileformats.application import Json, Yaml, Xml
from .utils import (
generic_func_task,
GenericShellTask,
specific_func_task,
SpecificShellTask,
other_specific_func_task,
OtherSpecificShellTask,
MyFormatX,
MyOtherFormatX,
MyHeader,
)

Expand Down Expand Up @@ -164,6 +168,18 @@ def test_type_check_nested8():
)(lz(ty.List[float]))


def test_type_check_permit_superclass():
# Typical case as Json is subclass of File
TypeParser(ty.List[File])(lz(ty.List[Json]))
djarecka marked this conversation as resolved.
Show resolved Hide resolved
# Permissive super class, as File is superclass of Json
TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[File]))
with pytest.raises(TypeError, match="Cannot coerce"):
TypeParser(ty.List[Json], superclass_auto_cast=False)(lz(ty.List[File]))
# Fails because Yaml is neither sub or super class of Json
with pytest.raises(TypeError, match="Cannot coerce"):
TypeParser(ty.List[Json], superclass_auto_cast=True)(lz(ty.List[Yaml]))


def test_type_check_fail1():
with pytest.raises(TypeError, match="Wrong number of type arguments in tuple"):
TypeParser(ty.Tuple[int, int, int])(lz(ty.Tuple[float, float, float, float]))
Expand Down Expand Up @@ -538,7 +554,17 @@ def specific_task(request):
assert False


def test_typing_cast(tmp_path, generic_task, specific_task):
@pytest.fixture(params=["func", "shell"])
def other_specific_task(request):
if request.param == "func":
return other_specific_func_task
elif request.param == "shell":
return OtherSpecificShellTask
else:
assert False


def test_typing_implicit_cast_from_super(tmp_path, generic_task, specific_task):
"""Check the casting of lazy fields and whether specific file-sets can be recovered
from generic `File` classes"""

Expand All @@ -562,33 +588,86 @@ def test_typing_cast(tmp_path, generic_task, specific_task):
)
)

wf.add(
specific_task(
in_file=wf.generic.lzout.out,
name="specific2",
)
)

wf.set_output(
[
("out_file", wf.specific2.lzout.out),
]
)

in_file = MyFormatX.sample()

result = wf(in_file=in_file, plugin="serial")

out_file: MyFormatX = result.output.out_file
assert type(out_file) is MyFormatX
assert out_file.parent != in_file.parent
assert type(out_file.header) is MyHeader
assert out_file.header.parent != in_file.header.parent


def test_typing_cast(tmp_path, specific_task, other_specific_task):
"""Check the casting of lazy fields and whether specific file-sets can be recovered
from generic `File` classes"""

wf = Workflow(
name="test",
input_spec={"in_file": MyFormatX},
output_spec={"out_file": MyFormatX},
)

wf.add(
specific_task(
in_file=wf.lzin.in_file,
name="entry",
)
)

with pytest.raises(TypeError, match="Cannot coerce"):
# No cast of generic task output to MyFormatX
wf.add( # Generic task
other_specific_task(
in_file=wf.entry.lzout.out,
name="inner",
)
)

wf.add( # Generic task
other_specific_task(
in_file=wf.entry.lzout.out.cast(MyOtherFormatX),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this cast a method on fileformats? Four attribute accesses feels like a lot to follow. I would find it more straightforward to have a function:

Suggested change
in_file=wf.entry.lzout.out.cast(MyOtherFormatX),
in_file=pydra.cast(wf.entry.lzout.out, MyOtherFormatX),

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't fileformats specific, it can be any class that is able to be "cast", i.e. newclass(oldinstance) is valid code such as Path("mystring"). I'm not averse to your suggested syntax, but since it would only be applicable to lazy fields I think I would find having it as a method more convenient

name="inner",
)
)

with pytest.raises(TypeError, match="Cannot coerce"):
# No cast of generic task output to MyFormatX
wf.add(
specific_task(
in_file=wf.generic.lzout.out,
name="specific2",
in_file=wf.inner.lzout.out,
name="exit",
)
)

wf.add(
specific_task(
in_file=wf.generic.lzout.out.cast(MyFormatX),
name="specific2",
in_file=wf.inner.lzout.out.cast(MyFormatX),
name="exit",
)
)

wf.set_output(
[
("out_file", wf.specific2.lzout.out),
("out_file", wf.exit.lzout.out),
]
)

my_fspath = tmp_path / "in_file.my"
hdr_fspath = tmp_path / "in_file.hdr"
my_fspath.write_text("my-format")
hdr_fspath.write_text("my-header")
in_file = MyFormatX([my_fspath, hdr_fspath])
in_file = MyFormatX.sample()

result = wf(in_file=in_file, plugin="serial")

Expand All @@ -611,6 +690,63 @@ def test_type_is_subclass3():
assert TypeParser.is_subclass(ty.Type[Json], ty.Type[File])


def test_union_is_subclass1():
assert TypeParser.is_subclass(ty.Union[Json, Yaml], ty.Union[Json, Yaml, Xml])


def test_union_is_subclass2():
assert not TypeParser.is_subclass(ty.Union[Json, Yaml, Xml], ty.Union[Json, Yaml])


def test_union_is_subclass3():
assert TypeParser.is_subclass(Json, ty.Union[Json, Yaml])


def test_union_is_subclass4():
assert not TypeParser.is_subclass(ty.Union[Json, Yaml], Json)


def test_generic_is_subclass1():
assert TypeParser.is_subclass(ty.List[int], list)


def test_generic_is_subclass2():
assert not TypeParser.is_subclass(list, ty.List[int])


def test_generic_is_subclass3():
assert not TypeParser.is_subclass(ty.List[float], ty.List[int])


@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Cannot subscript tuple in < Py3.9"
)
def test_generic_is_subclass4():
class MyTuple(tuple):
pass

class A:
pass

class B(A):
pass

assert TypeParser.is_subclass(MyTuple[A], ty.Tuple[A])
assert TypeParser.is_subclass(ty.Tuple[B], ty.Tuple[A])
assert TypeParser.is_subclass(MyTuple[B], ty.Tuple[A])
assert not TypeParser.is_subclass(ty.Tuple[A], ty.Tuple[B])
assert not TypeParser.is_subclass(ty.Tuple[A], MyTuple[A])
assert not TypeParser.is_subclass(MyTuple[A], ty.Tuple[B])
assert TypeParser.is_subclass(MyTuple[A, int], ty.Tuple[A, int])
assert TypeParser.is_subclass(ty.Tuple[B, int], ty.Tuple[A, int])
assert TypeParser.is_subclass(MyTuple[B, int], ty.Tuple[A, int])
assert TypeParser.is_subclass(MyTuple[int, B], ty.Tuple[int, A])
assert not TypeParser.is_subclass(MyTuple[B, int], ty.Tuple[int, A])
assert not TypeParser.is_subclass(MyTuple[int, B], ty.Tuple[A, int])
assert not TypeParser.is_subclass(MyTuple[B, int], ty.Tuple[A])
assert not TypeParser.is_subclass(MyTuple[B], ty.Tuple[A, int])


def test_type_is_instance1():
assert TypeParser.is_instance(File, ty.Type[File])

Expand Down
65 changes: 63 additions & 2 deletions pydra/utils/tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from fileformats.generic import File
from fileformats.core.mixin import WithSeparateHeader
from fileformats.core.mixin import WithSeparateHeader, WithMagicNumber
from pydra import mark
from pydra.engine.task import ShellCommandTask
from pydra.engine import specs


class MyFormat(File):
class MyFormat(WithMagicNumber, File):
ext = ".my"
magic_number = b"MYFORMAT"


class MyHeader(File):
Expand All @@ -17,6 +18,12 @@ class MyFormatX(WithSeparateHeader, MyFormat):
header_type = MyHeader


class MyOtherFormatX(WithMagicNumber, WithSeparateHeader, File):
djarecka marked this conversation as resolved.
Show resolved Hide resolved
magic_number = b"MYFORMAT"
ext = ".my"
header_type = MyHeader


@mark.task
def generic_func_task(in_file: File) -> File:
return in_file
Expand Down Expand Up @@ -118,3 +125,57 @@ class SpecificShellTask(ShellCommandTask):
input_spec = specific_shell_input_spec
output_spec = specific_shelloutput_spec
executable = "echo"


@mark.task
def other_specific_func_task(in_file: MyOtherFormatX) -> MyOtherFormatX:
return in_file


other_specific_shell_input_fields = [
(
"in_file",
MyOtherFormatX,
{
"help_string": "the input file",
"argstr": "",
"copyfile": "copy",
"sep": " ",
},
),
(
"out",
str,
{
"help_string": "output file name",
"argstr": "",
"position": -1,
"output_file_template": "{in_file}", # Pass through un-altered
},
),
]

other_specific_shell_input_spec = specs.SpecInfo(
name="Input", fields=other_specific_shell_input_fields, bases=(specs.ShellSpec,)
)

other_specific_shell_output_fields = [
(
"out",
MyOtherFormatX,
{
"help_string": "output file",
},
),
]
other_specific_shelloutput_spec = specs.SpecInfo(
name="Output",
fields=other_specific_shell_output_fields,
bases=(specs.ShellOutSpec,),
)


class OtherSpecificShellTask(ShellCommandTask):
input_spec = other_specific_shell_input_spec
output_spec = other_specific_shelloutput_spec
executable = "echo"
Loading
Loading