Skip to content

Commit 9f9307b

Browse files
authored
Rework command line specification (#125)
This PR introduces the "{PORT::port_name}" placeholder in the shell task command and thus extends the port concept to all tasks. This allows arbitrary command lines to be specified without having to make assumptions when parsing them. The `port` keyword is now mandatory in the YAML config file, and inputs are now internally represented by a dictionary mapping each port name to the corresponding list of input data nodes. Some minor changes are also made to workgraph.py for readability.
1 parent 060e5f9 commit 9f9307b

File tree

18 files changed

+366
-390
lines changed

18 files changed

+366
-390
lines changed

src/sirocco/core/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ._tasks import IconTask, ShellTask
2-
from .graph_items import Cycle, Data, GraphItem, Task
2+
from .graph_items import AvailableData, Cycle, Data, GeneratedData, GraphItem, Task
33
from .workflow import Workflow
44

5-
__all__ = ["Workflow", "GraphItem", "Data", "Task", "Cycle", "ShellTask", "IconTask"]
5+
__all__ = ["Workflow", "GraphItem", "Data", "AvailableData", "GeneratedData", "Task", "Cycle", "ShellTask", "IconTask"]

src/sirocco/core/_tasks/icon_task.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,9 @@ def update_core_namelists_from_workflow(self):
5959
"experimentStopDate": self.cycle_point.stop_date.isoformat() + "Z",
6060
}
6161
)
62-
self.core_namelists["icon_master.namelist"]["master_nml"]["lrestart"] = any(
63-
# NOTE: in_data[0] contains the actual data node and in_data[1] the port name
64-
in_data[1] == "restart"
65-
for in_data in self.inputs
66-
)
62+
self.core_namelists["icon_master.namelist"]["master_nml"]["lrestart"] = bool(self.inputs["restart"])
6763

68-
def dump_core_namelists(self, folder=None):
64+
def dump_core_namelists(self, folder: str | Path | None = None):
6965
if folder is not None:
7066
folder = Path(folder)
7167
folder.mkdir(parents=True, exist_ok=True)

src/sirocco/core/graph_items.py

+19-14
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from dataclasses import dataclass, field
44
from itertools import chain, product
5-
from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeAlias, TypeVar, cast
5+
from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeVar, cast
66

77
from sirocco.parsing.target_cycle import DateList, LagList, NoTargetCycle
88
from sirocco.parsing.yaml_data_models import (
@@ -46,21 +46,23 @@ class Data(ConfigBaseDataSpecs, GraphItem):
4646

4747
color: ClassVar[Color] = field(default="light_blue", repr=False)
4848

49-
available: bool
50-
5149
@classmethod
52-
def from_config(cls, config: ConfigBaseData, coordinates: dict) -> Self:
53-
return cls(
50+
def from_config(cls, config: ConfigBaseData, coordinates: dict) -> AvailableData | GeneratedData:
51+
data_class = AvailableData if isinstance(config, ConfigAvailableData) else GeneratedData
52+
return data_class(
5453
name=config.name,
5554
type=config.type,
5655
src=config.src,
57-
available=isinstance(config, ConfigAvailableData),
5856
coordinates=coordinates,
5957
)
6058

6159

62-
# contains the input data and its potential associated port
63-
BoundData: TypeAlias = tuple[Data, str | None]
60+
class AvailableData(Data):
61+
pass
62+
63+
64+
class GeneratedData(Data):
65+
pass
6466

6567

6668
@dataclass(kw_only=True)
@@ -70,7 +72,7 @@ class Task(ConfigBaseTaskSpecs, GraphItem):
7072
plugin_classes: ClassVar[dict[str, type[Self]]] = field(default={}, repr=False)
7173
color: ClassVar[Color] = field(default="light_red", repr=False)
7274

73-
inputs: list[BoundData] = field(default_factory=list)
75+
inputs: dict[str, list[Data]] = field(default_factory=dict)
7476
outputs: list[Data] = field(default_factory=list)
7577
wait_on: list[Task] = field(default_factory=list)
7678
config_rootdir: Path
@@ -85,6 +87,9 @@ def __init_subclass__(cls, **kwargs):
8587
raise ValueError(msg)
8688
Task.plugin_classes[cls.plugin] = cls
8789

90+
def input_data_nodes(self) -> Iterator[Data]:
91+
yield from chain(*self.inputs.values())
92+
8893
@classmethod
8994
def from_config(
9095
cls: type[Self],
@@ -95,11 +100,11 @@ def from_config(
95100
datastore: Store,
96101
graph_spec: ConfigCycleTask,
97102
) -> Task:
98-
inputs = [
99-
(data_node, input_spec.port)
100-
for input_spec in graph_spec.inputs
101-
for data_node in datastore.iter_from_cycle_spec(input_spec, coordinates)
102-
]
103+
inputs: dict[str, list[Data]] = {}
104+
for input_spec in graph_spec.inputs:
105+
if input_spec.port not in inputs:
106+
inputs[input_spec.port] = []
107+
inputs[input_spec.port].extend(datastore.iter_from_cycle_spec(input_spec, coordinates))
103108
outputs = [datastore[output_spec.name, coordinates] for output_spec in graph_spec.outputs]
104109
if (plugin_cls := Task.plugin_classes.get(type(config).plugin, None)) is None:
105110
msg = f"Plugin {type(config).plugin!r} is not supported."

src/sirocco/core/workflow.py

+15-14
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ def __init__(
2929
self,
3030
name: str,
3131
config_rootdir: Path,
32-
cycles: list[ConfigCycle],
33-
tasks: list[ConfigTask],
34-
data: ConfigData,
32+
config_cycles: list[ConfigCycle],
33+
config_tasks: list[ConfigTask],
34+
config_data: ConfigData,
3535
parameters: dict[str, list],
3636
) -> None:
3737
self.name: str = name
@@ -41,8 +41,10 @@ def __init__(
4141
self.data: Store[Data] = Store()
4242
self.cycles: Store[Cycle] = Store()
4343

44-
data_dict: dict[str, ConfigBaseData] = {data.name: data for data in chain(data.available, data.generated)}
45-
task_dict: dict[str, ConfigTask] = {task.name: task for task in tasks}
44+
config_data_dict: dict[str, ConfigBaseData] = {
45+
data.name: data for data in chain(config_data.available, config_data.generated)
46+
}
47+
config_task_dict: dict[str, ConfigTask] = {task.name: task for task in config_tasks}
4648

4749
# Function to iterate over date and parameter combinations
4850
def iter_coordinates(cycle_point: CyclePoint, param_refs: list[str]) -> Iterator[dict]:
@@ -52,28 +54,27 @@ def iter_coordinates(cycle_point: CyclePoint, param_refs: list[str]) -> Iterator
5254
yield from (dict(zip(axes.keys(), x, strict=False)) for x in product(*axes.values()))
5355

5456
# 1 - create available data nodes
55-
for available_data_config in data.available:
57+
for available_data_config in config_data.available:
5658
for coordinates in iter_coordinates(OneOffPoint(), available_data_config.parameters):
5759
self.data.add(Data.from_config(config=available_data_config, coordinates=coordinates))
5860

5961
# 2 - create output data nodes
60-
for cycle_config in cycles:
62+
for cycle_config in config_cycles:
6163
for cycle_point in cycle_config.cycling.iter_cycle_points():
6264
for task_ref in cycle_config.tasks:
6365
for data_ref in task_ref.outputs:
64-
data_name = data_ref.name
65-
data_config = data_dict[data_name]
66+
data_config = config_data_dict[data_ref.name]
6667
for coordinates in iter_coordinates(cycle_point, data_config.parameters):
6768
self.data.add(Data.from_config(config=data_config, coordinates=coordinates))
6869

6970
# 3 - create cycles and tasks
70-
for cycle_config in cycles:
71+
for cycle_config in config_cycles:
7172
cycle_name = cycle_config.name
7273
for cycle_point in cycle_config.cycling.iter_cycle_points():
7374
cycle_tasks = []
7475
for task_graph_spec in cycle_config.tasks:
7576
task_name = task_graph_spec.name
76-
task_config = task_dict[task_name]
77+
task_config = config_task_dict[task_name]
7778
for coordinates in iter_coordinates(cycle_point, task_config.parameters):
7879
task = Task.from_config(
7980
config=task_config,
@@ -113,8 +114,8 @@ def from_config_workflow(cls: type[Self], config_workflow: ConfigWorkflow) -> Se
113114
return cls(
114115
name=config_workflow.name,
115116
config_rootdir=config_workflow.rootdir,
116-
cycles=config_workflow.cycles,
117-
tasks=config_workflow.tasks,
118-
data=config_workflow.data,
117+
config_cycles=config_workflow.cycles,
118+
config_tasks=config_workflow.tasks,
119+
config_data=config_workflow.data,
119120
parameters=config_workflow.parameters,
120121
)

src/sirocco/parsing/yaml_data_models.py

+62-99
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import enum
44
import itertools
5+
import re
56
import time
67
import typing
78
from dataclasses import dataclass, field
@@ -160,7 +161,7 @@ class TargetNodesBaseModel(_NamedBaseModel):
160161

161162

162163
class ConfigCycleTaskInput(TargetNodesBaseModel):
163-
port: str | None = None
164+
port: str
164165

165166

166167
class ConfigCycleTaskWaitOn(TargetNodesBaseModel):
@@ -273,58 +274,66 @@ class ConfigRootTask(ConfigBaseTask):
273274
plugin: ClassVar[Literal["_root"]] = "_root"
274275

275276

276-
# By using a frozen class we only need to validate on initialization
277-
@dataclass(frozen=True)
278-
class ShellCliArgument:
279-
"""A holder for a CLI argument to simplify access.
280-
281-
Stores CLI arguments of the form "file", "--init", "{file}" or "{--init file}". These examples translate into
282-
ShellCliArguments ShellCliArgument(name="file", references_data_item=False, cli_option_of_data_item=None),
283-
ShellCliArgument(name="--init", references_data_item=False, cli_option_of_data_item=None),
284-
ShellCliArgument(name="file", references_data_item=True, cli_option_of_data_item=None),
285-
ShellCliArgument(name="file", references_data_item=True, cli_option_of_data_item="--init")
277+
@dataclass(kw_only=True)
278+
class ConfigShellTaskSpecs:
279+
plugin: ClassVar[Literal["shell"]] = "shell"
280+
port_pattern: ClassVar[re.Pattern] = field(default=re.compile(r"{PORT(\[sep=.+\])?::(.+?)}"), repr=False)
281+
sep_pattern: ClassVar[re.Pattern] = field(default=re.compile(r"\[sep=(.+)\]"), repr=False)
282+
src: str | None = None
283+
command: str
284+
env_source_files: list[str] = field(default_factory=list)
286285

287-
Attributes:
288-
name: Name of the argument. For the examples it is "file", "--init", "file" and "file"
289-
references_data_item: Specifies if the argument references a data item signified by enclosing it by curly
290-
brackets.
291-
cli_option_of_data_item: The CLI option associated to the data item.
292-
"""
286+
def resolve_ports(self, input_labels: dict[str, list[str]]) -> str:
287+
"""Replace port placeholders in command string with provided input labels.
293288
294-
name: str
295-
references_data_item: bool
296-
cli_option_of_data_item: str | None = None
289+
Returns a string corresponding to self.command with "{PORT::port_name}"
290+
placeholders replaced by the content provided in the input_labels dict.
291+
When multiple input nodes are linked to a single port (e.g. with
292+
parameterized data or if the `when` keyword specifies a list of lags or
293+
dates), the provided input labels are inserted with a separator
294+
defaulting to a " ". Specifying an alternative separator, e.g. a comma,
295+
is done via "{PORT[sep=,]::port_name}"
297296
298-
def __post_init__(self):
299-
if self.cli_option_of_data_item is not None and not self.references_data_item:
300-
msg = "data_item_option cannot be not None if cli_option_of_data_item is False"
301-
raise ValueError(msg)
297+
Examples:
302298
303-
@classmethod
304-
def from_cli_argument(cls, arg: str) -> ShellCliArgument:
305-
len_arg_with_option = 2
306-
len_arg_no_option = 1
307-
references_data_item = arg.startswith("{") and arg.endswith("}")
308-
# remove curly brackets "{--init file}" -> "--init file"
309-
arg_unwrapped = arg[1:-1] if arg.startswith("{") and arg.endswith("}") else arg
310-
311-
# "--init file" -> ["--init", "file"]
312-
input_arg = arg_unwrapped.split()
313-
if len(input_arg) != len_arg_with_option and len(input_arg) != len_arg_no_option:
314-
msg = f"Expected argument of format {{data}} or {{option data}} but found {arg}"
315-
raise ValueError(msg)
316-
name = input_arg[0] if len(input_arg) == len_arg_no_option else input_arg[1]
317-
cli_option_of_data_item = input_arg[0] if len(input_arg) == len_arg_with_option else None
318-
return cls(name, references_data_item, cli_option_of_data_item)
299+
>>> task_specs = ConfigShellTaskSpecs(
300+
... command="./my_script {PORT::positionals} -l -c --verbose 2 --arg {PORT::my_arg}"
301+
... )
302+
>>> task_specs.resolve_ports(
303+
... {"positionals": ["input_1", "input_2"], "my_arg": ["input_3"]}
304+
... )
305+
'./my_script input_1 input_2 -l -c --verbose 2 --arg input_3'
319306
307+
>>> task_specs = ConfigShellTaskSpecs(
308+
... command="./my_script {PORT::positionals} --multi_arg {PORT[sep=,]::multi_arg}"
309+
... )
310+
>>> task_specs.resolve_ports(
311+
... {"positionals": ["input_1", "input_2"], "multi_arg": ["input_3", "input_4"]}
312+
... )
313+
'./my_script input_1 input_2 --multi_arg input_3,input_4'
320314
321-
@dataclass(kw_only=True)
322-
class ConfigShellTaskSpecs:
323-
plugin: ClassVar[Literal["shell"]] = "shell"
324-
command: str = ""
325-
cli_arguments: list[ShellCliArgument] = field(default_factory=list)
326-
env_source_files: list[str] = field(default_factory=list)
327-
src: str | None = None
315+
>>> task_specs = ConfigShellTaskSpecs(
316+
... command="./my_script --input {PORT[sep= --input ]::repeat_input}"
317+
... )
318+
>>> task_specs.resolve_ports({"repeat_input": ["input_1", "input_2", "input_3"]})
319+
'./my_script --input input_1 --input input_2 --input input_3'
320+
"""
321+
cmd = self.command
322+
for port_match in self.port_pattern.finditer(cmd):
323+
if (port_name := port_match.group(2)) is None:
324+
msg = f"Wrong port specification: {port_match.group(0)}"
325+
raise ValueError(msg)
326+
if (sep := port_match.group(1)) is None:
327+
arg_sep = " "
328+
else:
329+
if (sep_match := self.sep_pattern.match(sep)) is None:
330+
msg = "Wrong separator specification: sep"
331+
raise ValueError(msg)
332+
if (arg_sep := sep_match.group(1)) is None:
333+
msg = "Wrong separator specification: sep"
334+
raise ValueError(msg)
335+
cmd = cmd.replace(port_match.group(0), arg_sep.join(input_labels[port_name]))
336+
return cmd
328337

329338

330339
class ConfigShellTask(ConfigBaseTask, ConfigShellTaskSpecs):
@@ -340,75 +349,26 @@ class ConfigShellTask(ConfigBaseTask, ConfigShellTaskSpecs):
340349
... '''
341350
... my_task:
342351
... plugin: shell
343-
... command: my_script.sh
344-
... src: post_run_scripts
345-
... cli_arguments: "-n 1024 {current_sim_output}"
352+
... command: "my_script.sh -n 1024 {PORT::current_sim_output}"
353+
... src: post_run_scripts/my_script.sh
346354
... env_source_files: "env.sh"
347355
... walltime: 00:01:00
348356
... '''
349357
... ),
350358
... )
351-
>>> my_task.cli_arguments[0]
352-
ShellCliArgument(name='-n', references_data_item=False, cli_option_of_data_item=None)
353-
>>> my_task.cli_arguments[1]
354-
ShellCliArgument(name='1024', references_data_item=False, cli_option_of_data_item=None)
355-
>>> my_task.cli_arguments[2]
356-
ShellCliArgument(name='current_sim_output', references_data_item=True, cli_option_of_data_item=None)
357359
>>> my_task.env_source_files
358360
['env.sh']
359361
>>> my_task.walltime.tm_min
360362
1
361363
"""
362364

363-
command: str = ""
364-
cli_arguments: list[ShellCliArgument] = Field(default_factory=list)
365365
env_source_files: list[str] = Field(default_factory=list)
366366

367-
@field_validator("cli_arguments", mode="before")
368-
@classmethod
369-
def validate_cli_arguments(cls, value: str) -> list[ShellCliArgument]:
370-
return cls.parse_cli_arguments(value)
371-
372367
@field_validator("env_source_files", mode="before")
373368
@classmethod
374369
def validate_env_source_files(cls, value: str | list[str]) -> list[str]:
375370
return [value] if isinstance(value, str) else value
376371

377-
@staticmethod
378-
def split_cli_arguments(cli_arguments: str) -> list[str]:
379-
"""Splits the CLI arguments into a list of separate entities.
380-
381-
Splits the CLI arguments by whitespaces except if the whitespace is contained within curly brackets. For example
382-
the string
383-
"-D --CMAKE_CXX_COMPILER=${CXX_COMPILER} {--init file}"
384-
will be splitted into the list
385-
["-D", "--CMAKE_CXX_COMPILER=${CXX_COMPILER}", "{--init file}"]
386-
"""
387-
388-
nb_open_curly_brackets = 0
389-
last_split_idx = 0
390-
splits = []
391-
for i, char in enumerate(cli_arguments):
392-
if char == " " and not nb_open_curly_brackets:
393-
# we ommit the space in the splitting therefore we only store up to i but move the last_split_idx to i+1
394-
splits.append(cli_arguments[last_split_idx:i])
395-
last_split_idx = i + 1
396-
elif char == "{":
397-
nb_open_curly_brackets += 1
398-
elif char == "}":
399-
if nb_open_curly_brackets == 0:
400-
msg = f"Invalid input for cli_arguments. Found a closing curly bracket before an opening in {cli_arguments!r}"
401-
raise ValueError(msg)
402-
nb_open_curly_brackets -= 1
403-
404-
if last_split_idx != len(cli_arguments):
405-
splits.append(cli_arguments[last_split_idx : len(cli_arguments)])
406-
return splits
407-
408-
@staticmethod
409-
def parse_cli_arguments(cli_arguments: str) -> list[ShellCliArgument]:
410-
return [ShellCliArgument.from_cli_argument(arg) for arg in ConfigShellTask.split_cli_arguments(cli_arguments)]
411-
412372

413373
@dataclass(kw_only=True)
414374
class NamelistSpec:
@@ -662,6 +622,7 @@ class ConfigWorkflow(BaseModel):
662622
... tasks:
663623
... - task_a:
664624
... plugin: shell
625+
... command: "some_command"
665626
... data:
666627
... available:
667628
... - foo:
@@ -681,7 +642,9 @@ class ConfigWorkflow(BaseModel):
681642
... name="minimal",
682643
... rootdir=Path("/location/of/config/file"),
683644
... cycles=[ConfigCycle(minimal_cycle={"tasks": [ConfigCycleTask(task_a={})]})],
684-
... tasks=[ConfigShellTask(task_a={"plugin": "shell"})],
645+
... tasks=[
646+
... ConfigShellTask(task_a={"plugin": "shell", "command": "some_command"})
647+
... ],
685648
... data=ConfigData(
686649
... available=[
687650
... ConfigAvailableData(name="foo", type=DataType.FILE, src="foo.txt")

0 commit comments

Comments
 (0)