diff --git a/snakehelp/parameters.py b/snakehelp/parameters.py index fb27013..effcfc1 100644 --- a/snakehelp/parameters.py +++ b/snakehelp/parameters.py @@ -1,7 +1,8 @@ import os from collections import namedtuple from dataclasses import dataclass, fields -from typing import get_origin, Literal +from types import UnionType +from typing import get_origin, Literal, Union, get_args from snakehelp.snakehelp import classproperty, string_is_valid_type, type_to_regex @@ -15,21 +16,29 @@ def _field_names(cls): return [field.name for field in fields(cls)] @classmethod - def get_fields(cls): + def get_fields(cls, minimal=False, minimal_children=False): """ Returns a list of tuples (field_name, field_type) + + If minimal is True, Literal types with only one possible value are ignored, i.e. only + arguments that are necessary for uniquely representing the object are included. + + minimal_children specifies only whether children should be minimal. """ field_tuple = namedtuple("Field", ["name", "type"]) out = [] for field in fields(cls): + if minimal and get_origin(field.type) == Literal and len(get_args(field.type)) == 1: + continue + if field.type in (int, str, float): out.append(field_tuple(field.name, field.type)) - elif get_origin(field.type) == Literal: + elif get_origin(field.type) in (Literal, Union, UnionType): out.append(field_tuple(field.name, field.type)) else: assert hasattr(field.type, "get_fields"), "Field type %s is not valid. " \ "Must be a base type or a class decorated with @parameters" % field.type - out.extend(field.type.get_fields()) + out.extend(field.type.get_fields(minimal=minimal_children, minimal_children=minimal_children)) return out @@ -40,38 +49,73 @@ def parameters(cls): """ return [field.name for field in cls.get_fields()] - @classmethod - def as_input(cls, wildcards): + @classproperty + def minimal_parameters(cls): """ - Tries to return a valid snakemake input-file by using the given wildcards (as many as possible, starting from the first). + Returns a list of the minimum set of parameters needed to uniquely represent + this objeckt, meaning that Literal parameters with only one possible value are ignored. + """ + return [field.name for field in cls.get_fields(minimal=True)] - If fields are Union-types, there may be multible possible paths that can be created. This method checks that there - is no ambiguity, and raises an Exception if there is. + @classmethod + def as_input(cls): """ - assert hasattr(wildcards, - "items"), "As input can only be called with a dictlike object with an items() method" + Returns an input-function that can be used by Snakemake. + """ + + def func(wildcards): + assert hasattr(wildcards, + "items"), "As input can only be called with a dictlike object with an items() method" - fields = cls.get_fields() + fields = cls.get_fields() + # create a path from the wildcards and the parameters + # can maybe be done by just calling output with these wildcards. + return cls.as_output(**{name: t for name, t in wildcards.items() if name in cls.parameters}) - path = [] + path = [] - for i, (name, value) in enumerate(wildcards.items()): - assert name == fields[i].name - assert string_is_valid_type(value, fields[i].type) - path.append(value) + """ + for i, (name, value) in enumerate(wildcards.items()): + if i >= len(fields): + break + assert name == fields[i].name, f"Parsing {cls}. Invalid at {i}, name: {name}, expected {fields[i].name}" + assert string_is_valid_type(value, fields[i].type), f"{value} is not a valid as type {fields[i].type}" + path.append(value) - return os.path.sep.join(path) + return os.path.sep.join(path) + """ + + return func @classmethod - def as_output(cls): + def as_output(cls, **kwargs): """ Returns a valid Snakemake wildcard string with regex so force types + + Keyword arguments can be specified to fix certain variables to values. """ - names_with_regexes = ["{" + field.name + "," + type_to_regex(field.type) + "}" for field in - cls.get_fields()] - return os.path.sep.join(names_with_regexes) + names_with_regexes = [] + for name in kwargs: + assert name in cls.parameters, "Trying to force a field %s. Available fields are %s" % (name, cls.parameters) + for field in cls.get_fields(minimal_children=True): + if field.name in kwargs: + assert string_is_valid_type(kwargs[field.name], field.type), \ + f"Trying to set field {field.name} to value {kwargs[field.name]}, " \ + f"but this is not compatible with the field type {field.type}." + + names_with_regexes.append(str(kwargs[field.name])) + else: + if get_origin(field.type) == Literal and len(get_args(field.type)) == 1: + # literal types enforces a single value, should not be wildcards + names_with_regexes.append(get_args(field.type)[0]) + else: + names_with_regexes.append("{" + field.name + "," + type_to_regex(field.type) + "}") + + return os.path.sep.join(names_with_regexes) + Parameters.__name__ = base_class.__name__ Parameters.__qualname__ = base_class.__qualname__ + return Parameters diff --git a/snakehelp/snakehelp.py b/snakehelp/snakehelp.py index e2a407c..083ba74 100644 --- a/snakehelp/snakehelp.py +++ b/snakehelp/snakehelp.py @@ -1,7 +1,8 @@ import os from collections import namedtuple from dataclasses import dataclass, fields -from typing import Literal, get_origin, get_args +from types import UnionType +from typing import Literal, get_origin, get_args, Union import re @@ -15,16 +16,33 @@ def __get__(self, obj, owner): classproperty = ClassProperty +def is_parameter_type(object): + return hasattr(object, "get_fields") + + +def is_base_type(type): + return type in (str, float, bool, int) + def type_to_regex(type): if type == int: return "\\d+" - elif type == float: + if type == float: return "[+-]?([0-9]*[.])?[0-9]+" - elif type == str: + if type == str: return "\\w+" - elif get_origin(type) == Literal: + if get_origin(type) == Literal: return "|".join([re.escape(arg) for arg in get_args(type)]) + elif get_origin(type) in (Union, UnionType): + if all(is_base_type(t) for t in get_args(type)): + # all types are base type, we can give a regex for each + return "|".join([type_to_regex(t) for t in get_args(type)]) + else: + # There is one or more objects, we can have anything + return ".*" + elif is_parameter_type(type): + # normal parameter-objects are strings in the path + return "\w+" raise Exception("Invalid type %s" % type) @@ -33,7 +51,7 @@ def string_is_valid_type(string, type): if type == str: return True elif type == int: - return string.isdigit() + return isinstance(string, int) or string.isdigit() elif type == float: try: float(string) @@ -42,6 +60,10 @@ def string_is_valid_type(string, type): return False elif get_origin(type) == Literal: return string in get_args(type) + elif get_origin(type) in [Union, UnionType]: + return any((string_is_valid_type(string, t) for t in get_args(type))) + elif is_parameter_type(type): + return isinstance(string, str) else: raise Exception("Type %s not implemented" % type) diff --git a/test_pipeline/Snakefile b/test_pipeline/Snakefile new file mode 100644 index 0000000..38ec7e7 --- /dev/null +++ b/test_pipeline/Snakefile @@ -0,0 +1,96 @@ +configfile: "config/config.yaml" +from snakehelp import parameters +from dataclasses import dataclass +from typing import Literal, Union + + +@parameters +class ReferenceGenome: + genome_build: str + random_seed: int + dataset_size: Literal["small", "medium", "big"] + file: Literal["ref.fa"] + + +@parameters +class ChipSeqReads: + n_peaks: int + binding_strength: float + file: Literal["reads.fq.gz"] + +@parameters +class ReadErrorProfile: + substitution_rate: float + indel_rate: float + + +@parameters +class SingleEndReads: + n_reads: int + read_length: int + error_profile: ReadErrorProfile + ending: Literal["reads.fq.gz"] + + +@parameters +class PairedEndReads: + n_reads: int + fragment_length_mean: int + fragment_length_std: int + error_profile: ReadErrorProfile + + +@parameters +class SimulatedReads: + reference_genome: ReferenceGenome + read_config: Union[ChipSeqReads, PairedEndReads, SingleEndReads] + file: Literal["reads.fq.gz"] + + +@parameters +class MappedReads: + reads: SimulatedReads + method: str + n_threads: int + ending: Literal["mapped.bam"] + + +def test_input_function(wildcards): + print(wildcards, type(wildcards)) + return "test.txt" + + +rule test: + input: test_input_function + output: touch("data/{param1,\d+}/{param2,\w+}/file.txt") + + +rule map: + input: + reads=SimulatedReads.as_input(), + reference=ReferenceGenome.as_input() + output: + reads=touch(MappedReads.as_output(method='bwa')) + + +print(MappedReads.as_output(method='bwa')) + + +rule test_files: + output: + reads=touch("hg38/123/medium/some/read/config/reads.fq.gz"), + reference_genome=touch("hg38/123/medium/ref.fa") + + +rule all: + input: + "hg38/123/medium/some/read/config/bwa/8/mapped.bam" + +""" +rule map: + input: + ref = ReferenceGenome.as_input, + reads = SimulatedReads.as_input + output: + bam=MappedReads.as_output +""" diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 608d84b..1de565c 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -1,12 +1,157 @@ +import os +import pytest from snakehelp import parameters +from snakehelp.snakehelp import type_to_regex +from typing import Literal, Union + + +class WildcardMock: + """Behaves as a Snakemake wildcards object. Initialize with kwargs.""" + def __init__(self, *args, **kwargs): + assert len(args) == 0 + self._data = kwargs + + def __getattr__(self, item): + return self._data[item] + + def items(self): + return self._data.items() + + +@parameters +class MyParams: + seed: int + name: str + ratio: float + + +@parameters +class MyParams2: + param1: MyParams + some_other_param: str + + +@parameters +class MyParams3: + param1: Literal["test", "test2"] + param2: str @parameters -class MyParameters: +class MyParams4: seed: int name: str + file_ending: Literal["file.npz"] + + +def test_init_parameters(): + assert MyParams.parameters == ["seed", "name", "ratio"] + + +def test_init_hierarchical_parameters(): + assert MyParams2.parameters == ["seed", "name", "ratio", "some_other_param"] + + +def test_init_parameters_with_literal(): + assert MyParams3.parameters == ["param1", "param2"] + + +def test_type_to_regex(): + assert type_to_regex(Union[str, int]) == "\\w+|\\d+" + assert type_to_regex(Literal["test1", "test2"]) == "test1|test2" + assert type_to_regex(int) == "\\d+" + + +def test_as_output(): + assert MyParams4.as_output() == r"{seed,\d+}/{name,\w+}/file.npz" + + +def test_as_input(): + wildcards = WildcardMock(seed="1", name="test", file_ending="file.npz") + path = MyParams4.as_input()(wildcards) + assert path == os.path.sep.join(["1", "test", "file.npz"]) + + +def test_as_partial_input(): + # sometimes the input will only match some of the wildcards, but this should work + wildcards = WildcardMock(seed="1", name="test", file_ending="file.npz", b="test", c="test2", d="test3") + path = MyParams4.as_input()(wildcards) + assert path == os.path.sep.join(["1", "test", "file.npz"]) + + +def test_as_partial_input_end(): + # partial inputs where parameters match at the end + # not iplemented + pass + + +def test_as_input_hierarchical(): + wildcards = WildcardMock(seed="1", name="test", ratio="0.3", some_other_param="test2") + path = MyParams2.as_input()(wildcards) + assert path == os.path.sep.join(["1", "test", "0.3", "test2"]) + + +@parameters +class ParamsWithUnion: + param1: Union[int, str] + param2: str + + +def test_union_params(): + assert ParamsWithUnion.parameters == ["param1", "param2"] + assert ParamsWithUnion.as_output() == r"{param1,\d+|\w+}/{param2,\w+}" + + +@parameters +class ParamsA: + a: float + b: int + +@parameters +class ParamsB: + x: int + y: int + z: int + +@parameters +class ParamsWithHierarchcicalUnion: + name: str + config: Union[ParamsA, ParamsB] ending: str -def test_parameters_decorator(): - assert MyParameters.parameters == ["seed", "name", "ending"] +def test_union_and_hierarchical(): + assert ParamsWithHierarchcicalUnion.parameters == ["name", "config", "ending"] + assert ParamsWithHierarchcicalUnion.as_output() == r"{name,\w+}/{config,.*}/{ending,\w+}" + + +def test_as_output_with_arguments(): + assert ParamsB.as_output() == r"{x,\d+}/{y,\d+}/{z,\d+}" + assert ParamsB.as_output(y=10) == r"{x,\d+}/10/{z,\d+}" + + +def test_minimal_parameters(): + assert MyParams4.parameters == ["seed", "name", "file_ending"] + assert MyParams4.minimal_parameters == ["seed", "name"] + + +@parameters +class Child: + type: str + ending: Literal["file.txt"] + + +@parameters +class Parent: + param1: Child + param2: int + ending: Literal["results.txt"] + + +def test_children_with_literal_that_should_be_ignored(): + assert Parent.as_output() == r"{type,\w+}/{param2,\d+}/results.txt" + + +if __name__ == "__main__": + test_type_to_regex() + #test_union_params() diff --git a/tests/test_snakehelp.py b/tests/test_snakehelp.py deleted file mode 100644 index 610f610..0000000 --- a/tests/test_snakehelp.py +++ /dev/null @@ -1,75 +0,0 @@ -import os -import pytest -from snakehelp import parameters -from snakehelp.snakehelp import type_to_regex -from typing import Literal - - -class WildcardMock: - """Behaves as a Snakemake wildcards object. Initialize with kwargs.""" - def __init__(self, *args, **kwargs): - assert len(args) == 0 - self._data = kwargs - - def __getattr__(self, item): - return self._data[item] - - def items(self): - return self._data.items() - - -@parameters -class MyParams: - seed: int - name: str - ratio: float - - -@parameters -class MyParams2: - param1: MyParams - some_other_param: str - - -@parameters -class MyParams3: - param1: Literal["test", "test2"] - param2: str - - -@parameters -class MyParams4: - seed: int - name: str - file_ending: Literal[".npz"] - - -def test_init_parameters(): - assert MyParams.parameters == ["seed", "name", "ratio"] - - -def test_init_hierarchical_parameters(): - assert MyParams2.parameters == ["seed", "name", "ratio", "some_other_param"] - - -def test_init_parameters_with_literal(): - assert MyParams3.parameters == ["param1", "param2"] - - -def test_type_to_regex(): - assert type_to_regex(Literal["test1", "test2"]) == "test1|test2" - assert type_to_regex(int) == "\d+" - - -def test_as_output(): - assert MyParams4.as_output() == "{seed,\d+}/{name,\w+}/{file_ending,\.npz}" - - -def test_as_input(): - wildcards = WildcardMock(seed="1", name="test", file_ending=".npz") - path = MyParams4.as_input(wildcards) - assert path == os.path.sep.join(["1", "test", ".npz"]) - - -def test_as_input_hierarchical(): - pass