Skip to content

Commit

Permalink
working ish
Browse files Browse the repository at this point in the history
  • Loading branch information
ivargr committed Mar 22, 2023
1 parent 70d2277 commit 960d15e
Show file tree
Hide file tree
Showing 5 changed files with 337 additions and 105 deletions.
88 changes: 66 additions & 22 deletions snakehelp/parameters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
from collections import namedtuple
from dataclasses import dataclass, fields
from typing import get_origin, Literal
from types import UnionType
from typing import get_origin, Literal, Union, get_args
from snakehelp.snakehelp import classproperty, string_is_valid_type, type_to_regex


Expand All @@ -15,21 +16,29 @@ def _field_names(cls):
return [field.name for field in fields(cls)]

@classmethod
def get_fields(cls):
def get_fields(cls, minimal=False, minimal_children=False):
"""
Returns a list of tuples (field_name, field_type)
If minimal is True, Literal types with only one possible value are ignored, i.e. only
arguments that are necessary for uniquely representing the object are included.
minimal_children specifies only whether children should be minimal.
"""
field_tuple = namedtuple("Field", ["name", "type"])
out = []
for field in fields(cls):
if minimal and get_origin(field.type) == Literal and len(get_args(field.type)) == 1:
continue

if field.type in (int, str, float):
out.append(field_tuple(field.name, field.type))
elif get_origin(field.type) == Literal:
elif get_origin(field.type) in (Literal, Union, UnionType):
out.append(field_tuple(field.name, field.type))
else:
assert hasattr(field.type, "get_fields"), "Field type %s is not valid. " \
"Must be a base type or a class decorated with @parameters" % field.type
out.extend(field.type.get_fields())
out.extend(field.type.get_fields(minimal=minimal_children, minimal_children=minimal_children))

return out

Expand All @@ -40,38 +49,73 @@ def parameters(cls):
"""
return [field.name for field in cls.get_fields()]

@classmethod
def as_input(cls, wildcards):
@classproperty
def minimal_parameters(cls):
"""
Tries to return a valid snakemake input-file by using the given wildcards (as many as possible, starting from the first).
Returns a list of the minimum set of parameters needed to uniquely represent
this objeckt, meaning that Literal parameters with only one possible value are ignored.
"""
return [field.name for field in cls.get_fields(minimal=True)]

If fields are Union-types, there may be multible possible paths that can be created. This method checks that there
is no ambiguity, and raises an Exception if there is.
@classmethod
def as_input(cls):
"""
assert hasattr(wildcards,
"items"), "As input can only be called with a dictlike object with an items() method"
Returns an input-function that can be used by Snakemake.
"""

def func(wildcards):
assert hasattr(wildcards,
"items"), "As input can only be called with a dictlike object with an items() method"

fields = cls.get_fields()
fields = cls.get_fields()
# create a path from the wildcards and the parameters
# can maybe be done by just calling output with these wildcards.
return cls.as_output(**{name: t for name, t in wildcards.items() if name in cls.parameters})

path = []
path = []

for i, (name, value) in enumerate(wildcards.items()):
assert name == fields[i].name
assert string_is_valid_type(value, fields[i].type)
path.append(value)
"""
for i, (name, value) in enumerate(wildcards.items()):
if i >= len(fields):
break
assert name == fields[i].name, f"Parsing {cls}. Invalid at {i}, name: {name}, expected {fields[i].name}"
assert string_is_valid_type(value, fields[i].type), f"{value} is not a valid as type {fields[i].type}"
path.append(value)
return os.path.sep.join(path)
return os.path.sep.join(path)
"""

return func

@classmethod
def as_output(cls):
def as_output(cls, **kwargs):
"""
Returns a valid Snakemake wildcard string with regex so force types
Keyword arguments can be specified to fix certain variables to values.
"""
names_with_regexes = ["{" + field.name + "," + type_to_regex(field.type) + "}" for field in
cls.get_fields()]
return os.path.sep.join(names_with_regexes)
names_with_regexes = []
for name in kwargs:
assert name in cls.parameters, "Trying to force a field %s. Available fields are %s" % (name, cls.parameters)


for field in cls.get_fields(minimal_children=True):
if field.name in kwargs:
assert string_is_valid_type(kwargs[field.name], field.type), \
f"Trying to set field {field.name} to value {kwargs[field.name]}, " \
f"but this is not compatible with the field type {field.type}."

names_with_regexes.append(str(kwargs[field.name]))
else:
if get_origin(field.type) == Literal and len(get_args(field.type)) == 1:
# literal types enforces a single value, should not be wildcards
names_with_regexes.append(get_args(field.type)[0])
else:
names_with_regexes.append("{" + field.name + "," + type_to_regex(field.type) + "}")

return os.path.sep.join(names_with_regexes)

Parameters.__name__ = base_class.__name__
Parameters.__qualname__ = base_class.__qualname__

return Parameters
32 changes: 27 additions & 5 deletions snakehelp/snakehelp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
from collections import namedtuple
from dataclasses import dataclass, fields
from typing import Literal, get_origin, get_args
from types import UnionType
from typing import Literal, get_origin, get_args, Union
import re


Expand All @@ -15,16 +16,33 @@ def __get__(self, obj, owner):

classproperty = ClassProperty

def is_parameter_type(object):
return hasattr(object, "get_fields")


def is_base_type(type):
return type in (str, float, bool, int)


def type_to_regex(type):
if type == int:
return "\\d+"
elif type == float:
if type == float:
return "[+-]?([0-9]*[.])?[0-9]+"
elif type == str:
if type == str:
return "\\w+"
elif get_origin(type) == Literal:
if get_origin(type) == Literal:
return "|".join([re.escape(arg) for arg in get_args(type)])
elif get_origin(type) in (Union, UnionType):
if all(is_base_type(t) for t in get_args(type)):
# all types are base type, we can give a regex for each
return "|".join([type_to_regex(t) for t in get_args(type)])
else:
# There is one or more objects, we can have anything
return ".*"
elif is_parameter_type(type):
# normal parameter-objects are strings in the path
return "\w+"

raise Exception("Invalid type %s" % type)

Expand All @@ -33,7 +51,7 @@ def string_is_valid_type(string, type):
if type == str:
return True
elif type == int:
return string.isdigit()
return isinstance(string, int) or string.isdigit()
elif type == float:
try:
float(string)
Expand All @@ -42,6 +60,10 @@ def string_is_valid_type(string, type):
return False
elif get_origin(type) == Literal:
return string in get_args(type)
elif get_origin(type) in [Union, UnionType]:
return any((string_is_valid_type(string, t) for t in get_args(type)))
elif is_parameter_type(type):
return isinstance(string, str)
else:
raise Exception("Type %s not implemented" % type)

Expand Down
96 changes: 96 additions & 0 deletions test_pipeline/Snakefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
configfile: "config/config.yaml"
from snakehelp import parameters
from dataclasses import dataclass
from typing import Literal, Union


@parameters
class ReferenceGenome:
genome_build: str
random_seed: int
dataset_size: Literal["small", "medium", "big"]
file: Literal["ref.fa"]


@parameters
class ChipSeqReads:
n_peaks: int
binding_strength: float
file: Literal["reads.fq.gz"]

@parameters
class ReadErrorProfile:
substitution_rate: float
indel_rate: float


@parameters
class SingleEndReads:
n_reads: int
read_length: int
error_profile: ReadErrorProfile
ending: Literal["reads.fq.gz"]


@parameters
class PairedEndReads:
n_reads: int
fragment_length_mean: int
fragment_length_std: int
error_profile: ReadErrorProfile


@parameters
class SimulatedReads:
reference_genome: ReferenceGenome
read_config: Union[ChipSeqReads, PairedEndReads, SingleEndReads]
file: Literal["reads.fq.gz"]


@parameters
class MappedReads:
reads: SimulatedReads
method: str
n_threads: int
ending: Literal["mapped.bam"]


def test_input_function(wildcards):
print(wildcards, type(wildcards))
return "test.txt"


rule test:
input: test_input_function
output: touch("data/{param1,\d+}/{param2,\w+}/file.txt")


rule map:
input:
reads=SimulatedReads.as_input(),
reference=ReferenceGenome.as_input()
output:
reads=touch(MappedReads.as_output(method='bwa'))


print(MappedReads.as_output(method='bwa'))


rule test_files:
output:
reads=touch("hg38/123/medium/some/read/config/reads.fq.gz"),
reference_genome=touch("hg38/123/medium/ref.fa")


rule all:
input:
"hg38/123/medium/some/read/config/bwa/8/mapped.bam"

"""
rule map:
input:
ref = ReferenceGenome.as_input,
reads = SimulatedReads.as_input
output:
bam=MappedReads.as_output
"""
Loading

0 comments on commit 960d15e

Please sign in to comment.