Skip to content

Commit

Permalink
Merge pull request #687 from tclose/typing-bugfixes
Browse files Browse the repository at this point in the history
Typing bugfixes
  • Loading branch information
effigies authored Sep 7, 2023
2 parents 428cf04 + 103cefc commit 31aea01
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 68 deletions.
9 changes: 7 additions & 2 deletions pydra/engine/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ def make_klass(spec):
type=tp,
**kwargs,
)
type_checker = TypeParser[newfield.type](newfield.type)
checker_label = f"'{name}' field of {spec.name}"
type_checker = TypeParser[newfield.type](newfield.type, label=checker_label)
if newfield.type in (MultiInputObj, MultiInputFile):
converter = attr.converters.pipe(ensure_list, type_checker)
elif newfield.type in (MultiOutputObj, MultiOutputFile):
Expand Down Expand Up @@ -652,7 +653,11 @@ def argstr_formatting(argstr, inputs, value_updates=None):
for fld in inp_fields:
fld_name = fld[1:-1] # extracting the name form {field_name}
fld_value = inputs_dict[fld_name]
if fld_value is attr.NOTHING:
fld_attr = getattr(attrs.fields(type(inputs)), fld_name)
if fld_value is attr.NOTHING or (
fld_value is False
and TypeParser.matches_type(fld_attr.type, ty.Union[Path, bool])
):
# if value is NOTHING, nothing should be added to the command
val_dict[fld_name] = ""
else:
Expand Down
61 changes: 34 additions & 27 deletions pydra/engine/helpers_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def template_update(inputs, output_dir, state_ind=None, map_copyfiles=None):
field
for field in attr_fields(inputs)
if field.metadata.get("output_file_template")
and getattr(inputs, field.name) is not False
and all(
getattr(inputs, required_field) is not attr.NOTHING
for required_field in field.metadata.get("requires", ())
Expand Down Expand Up @@ -150,25 +151,19 @@ def template_update_single(
# if input_dict_st with state specific value is not available,
# the dictionary will be created from inputs object
from ..utils.typing import TypeParser # noqa
from pydra.engine.specs import LazyField

VALID_TYPES = (str, ty.Union[str, bool], Path, ty.Union[Path, bool], LazyField)
from pydra.engine.specs import LazyField, OUTPUT_TEMPLATE_TYPES

if inputs_dict_st is None:
inputs_dict_st = attr.asdict(inputs, recurse=False)

if spec_type == "input":
inp_val_set = inputs_dict_st[field.name]
if inp_val_set is not attr.NOTHING and not TypeParser.is_instance(
inp_val_set, VALID_TYPES
):
raise TypeError(
f"'{field.name}' field has to be a Path instance or a bool, but {inp_val_set} set"
)
if isinstance(inp_val_set, bool) and field.type in (Path, str):
raise TypeError(
f"type of '{field.name}' is Path, consider using Union[Path, bool]"
)
if inp_val_set is not attr.NOTHING and not isinstance(inp_val_set, LazyField):
inp_val_set = TypeParser(ty.Union[OUTPUT_TEMPLATE_TYPES])(inp_val_set)
elif spec_type == "output":
if not TypeParser.contains_type(FileSet, field.type):
raise TypeError(
Expand All @@ -178,22 +173,23 @@ def template_update_single(
else:
raise TypeError(f"spec_type can be input or output, but {spec_type} provided")
# for inputs that the value is set (so the template is ignored)
if spec_type == "input" and isinstance(inputs_dict_st[field.name], (str, Path)):
return inputs_dict_st[field.name]
elif spec_type == "input" and inputs_dict_st[field.name] is False:
# if input fld is set to False, the fld shouldn't be used (setting NOTHING)
return attr.NOTHING
else: # inputs_dict[field.name] is True or spec_type is output
value = _template_formatting(field, inputs, inputs_dict_st)
# changing path so it is in the output_dir
if output_dir and value is not attr.NOTHING:
# should be converted to str, it is also used for input fields that should be str
if type(value) is list:
return [str(output_dir / Path(val).name) for val in value]
else:
return str(output_dir / Path(value).name)
else:
if spec_type == "input":
if isinstance(inp_val_set, (Path, list)):
return inp_val_set
if inp_val_set is False:
# if input fld is set to False, the fld shouldn't be used (setting NOTHING)
return attr.NOTHING
# inputs_dict[field.name] is True or spec_type is output
value = _template_formatting(field, inputs, inputs_dict_st)
# changing path so it is in the output_dir
if output_dir and value is not attr.NOTHING:
# should be converted to str, it is also used for input fields that should be str
if type(value) is list:
return [str(output_dir / Path(val).name) for val in value]
else:
return str(output_dir / Path(value).name)
else:
return attr.NOTHING


def _template_formatting(field, inputs, inputs_dict_st):
Expand All @@ -204,16 +200,27 @@ def _template_formatting(field, inputs, inputs_dict_st):
Allowing for multiple input values used in the template as longs as
there is no more than one file (i.e. File, PathLike or string with extensions)
"""
from .specs import MultiInputObj, MultiOutputFile

# if a template is a function it has to be run first with the inputs as the only arg
template = field.metadata["output_file_template"]
if callable(template):
template = template(inputs)

# as default, we assume that keep_extension is True
keep_extension = field.metadata.get("keep_extension", True)
if isinstance(template, (tuple, list)):
formatted = [
_string_template_formatting(field, t, inputs, inputs_dict_st)
for t in template
]
else:
assert isinstance(template, str)
formatted = _string_template_formatting(field, template, inputs, inputs_dict_st)
return formatted


def _string_template_formatting(field, template, inputs, inputs_dict_st):
from .specs import MultiInputObj, MultiOutputFile

keep_extension = field.metadata.get("keep_extension", True)
inp_fields = re.findall(r"{\w+}", template)
inp_fields_fl = re.findall(r"{\w+:[0-9.]+f}", template)
inp_fields += [re.sub(":[0-9.]+f", "", el) for el in inp_fields_fl]
Expand Down
27 changes: 18 additions & 9 deletions pydra/engine/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ class MultiOutputType:
MultiOutputObj = ty.Union[list, object, MultiOutputType]
MultiOutputFile = ty.Union[File, ty.List[File], MultiOutputType]

OUTPUT_TEMPLATE_TYPES = (
Path,
ty.List[Path],
ty.Union[Path, bool],
ty.Union[ty.List[Path], bool],
ty.List[ty.List[Path]],
)


@attr.s(auto_attribs=True, kw_only=True)
class SpecInfo:
Expand Down Expand Up @@ -343,6 +351,8 @@ def check_metadata(self):
Also sets the default values when available and needed.
"""
from ..utils.typing import TypeParser

supported_keys = {
"allowed_values",
"argstr",
Expand All @@ -361,6 +371,7 @@ def check_metadata(self):
"formatter",
"_output_type",
}

for fld in attr_fields(self, exclude_names=("_func", "_graph_checksums")):
mdata = fld.metadata
# checking keys from metadata
Expand All @@ -377,16 +388,13 @@ def check_metadata(self):
)
# assuming that fields with output_file_template shouldn't have default
if mdata.get("output_file_template"):
if fld.type not in (
Path,
ty.Union[Path, bool],
str,
ty.Union[str, bool],
if not any(
TypeParser.matches_type(fld.type, t) for t in OUTPUT_TEMPLATE_TYPES
):
raise TypeError(
f"Type of '{fld.name}' should be either pathlib.Path or "
f"typing.Union[pathlib.Path, bool] (not {fld.type}) because "
f"it has a value for output_file_template ({mdata['output_file_template']!r})"
f"Type of '{fld.name}' should be one of {OUTPUT_TEMPLATE_TYPES} "
f"(not {fld.type}) because it has a value for output_file_template "
f"({mdata['output_file_template']!r})"
)
if fld.default not in [attr.NOTHING, True, False]:
raise AttributeError(
Expand Down Expand Up @@ -443,7 +451,8 @@ def collect_additional_outputs(self, inputs, output_dir, outputs):
input_value = getattr(inputs, fld.name, attr.NOTHING)
if input_value is not attr.NOTHING:
if TypeParser.contains_type(FileSet, fld.type):
input_value = TypeParser(fld.type).coerce(input_value)
label = f"output field '{fld.name}' of {self}"
input_value = TypeParser(fld.type, label=label).coerce(input_value)
additional_out[fld.name] = input_value
elif (
fld.default is None or fld.default == attr.NOTHING
Expand Down
74 changes: 74 additions & 0 deletions pydra/engine/tests/test_helpers_file.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import typing as ty
import sys
from pathlib import Path
import attr
from unittest.mock import Mock
import pytest
from fileformats.generic import File
from ..specs import SpecInfo, ShellSpec
from ..task import ShellCommandTask
from ..helpers_file import (
ensure_list,
MountIndentifier,
copy_nested_files,
template_update_single,
)


Expand Down Expand Up @@ -343,3 +348,72 @@ def test_cifs_check():
with MountIndentifier.patch_table(fake_table):
for target, expected in cifs_targets:
assert MountIndentifier.on_cifs(target) is expected


def test_output_template(tmp_path):
filename = str(tmp_path / "file.txt")
with open(filename, "w") as f:
f.write("hello from pydra")
in_file = File(filename)

my_input_spec = SpecInfo(
name="Input",
fields=[
(
"in_file",
attr.ib(
type=File,
metadata={
"mandatory": True,
"position": 1,
"argstr": "",
"help_string": "input file",
},
),
),
(
"optional",
attr.ib(
type=ty.Union[Path, bool],
default=False,
metadata={
"position": 2,
"argstr": "--opt",
"output_file_template": "{in_file}.out",
"help_string": "optional file output",
},
),
),
],
bases=(ShellSpec,),
)

class MyCommand(ShellCommandTask):
executable = "my"
input_spec = my_input_spec

task = MyCommand(in_file=filename)
assert task.cmdline == f"my {filename}"
task.inputs.optional = True
assert task.cmdline == f"my {filename} --opt {task.output_dir / 'file.out'}"
task.inputs.optional = False
assert task.cmdline == f"my {filename}"
task.inputs.optional = "custom-file-out.txt"
assert task.cmdline == f"my {filename} --opt custom-file-out.txt"


def test_template_formatting(tmp_path):
field = Mock()
field.name = "grad"
field.argstr = "--grad"
field.metadata = {"output_file_template": ("{in_file}.bvec", "{in_file}.bval")}
inputs = Mock()
inputs_dict = {"in_file": "/a/b/c/file.txt", "grad": True}

assert template_update_single(
field,
inputs,
inputs_dict_st=inputs_dict,
output_dir=tmp_path,
spec_type="input",
) == [str(tmp_path / "file.bvec"), str(tmp_path / "file.bval")]
28 changes: 28 additions & 0 deletions pydra/utils/tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,3 +597,31 @@ def test_typing_cast(tmp_path, generic_task, specific_task):
assert out_file.parent != in_file.parent
assert type(out_file.header) is MyHeader
assert out_file.header.parent != in_file.header.parent


def test_type_is_subclass1():
assert TypeParser.is_subclass(ty.Type[File], type)


def test_type_is_subclass2():
assert not TypeParser.is_subclass(ty.Type[File], ty.Type[Json])


def test_type_is_subclass3():
assert TypeParser.is_subclass(ty.Type[Json], ty.Type[File])


def test_type_is_instance1():
assert TypeParser.is_instance(File, ty.Type[File])


def test_type_is_instance2():
assert not TypeParser.is_instance(File, ty.Type[Json])


def test_type_is_instance3():
assert TypeParser.is_instance(Json, ty.Type[File])


def test_type_is_instance4():
assert TypeParser.is_instance(Json, type)
Loading

0 comments on commit 31aea01

Please sign in to comment.