Skip to content

Commit

Permalink
working prototype of experiment sequence
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhylkaaa committed Nov 10, 2022
1 parent c35c334 commit 421293e
Show file tree
Hide file tree
Showing 33 changed files with 802 additions and 161 deletions.
8 changes: 6 additions & 2 deletions hydra/_internal/core_plugins/basic_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, Union

from omegaconf import DictConfig, open_dict

Expand All @@ -14,6 +14,7 @@
run_job,
setup_globals,
)
from hydra.plugins.experiment_sequence import ExperimentSequence
from hydra.plugins.launcher import Launcher
from hydra.types import HydraContext, TaskFunction

Expand Down Expand Up @@ -49,7 +50,7 @@ def setup(
self.task_function = task_function

def launch(
self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int
self, job_overrides: Union[Sequence[Sequence[str]], ExperimentSequence], initial_job_idx: int
) -> Sequence[JobReturn]:
setup_globals()
assert self.hydra_context is not None
Expand All @@ -65,6 +66,7 @@ def launch(
idx = initial_job_idx + idx
lst = " ".join(filter_overrides(overrides))
log.info(f"\t#{idx} : {lst}")
print(overrides)
sweep_config = self.hydra_context.config_loader.load_sweep_config(
self.config, list(overrides)
)
Expand All @@ -79,5 +81,7 @@ def launch(
job_subdir_key="hydra.sweep.subdir",
)
runs.append(ret)
if isinstance(job_overrides, ExperimentSequence):
job_overrides.update_sequence((overrides, ret))
configure_log(self.config.hydra.hydra_logging, self.config.hydra.verbose)
return runs
38 changes: 38 additions & 0 deletions hydra/plugins/experiment_sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
import typing

from collections.abc import Iterator
from typing import Any, Sequence, Tuple


class ExperimentSequence(Iterator):
@abstractmethod
def __next__(self):
"""Return tuple of experiment id, optional trial object and experiment overrides."""
raise NotImplementedError()

def __iter__(self) -> typing.Iterator[Sequence[str]]:
return self

@abstractmethod
def update_sequence(self, experiment_result: Tuple[Sequence[str], Any]):
"""Update experiment generator(study) with experiment results"""
raise NotImplementedError()

def __len__(self):
"""Return maximum number of experiments sequence can produce"""
raise NotImplementedError()
20 changes: 17 additions & 3 deletions hydra/plugins/launcher.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Launcher plugin interface
"""
from abc import abstractmethod
from typing import Sequence
from typing import Sequence, Union

from omegaconf import DictConfig

from hydra.core.utils import JobReturn

from hydra.plugins.experiment_sequence import ExperimentSequence
from hydra.types import TaskFunction, HydraContext

from .plugin import Plugin
Expand All @@ -30,7 +44,7 @@ def setup(

@abstractmethod
def launch(
self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int
self, job_overrides: Union[Sequence[Sequence[str]], ExperimentSequence], initial_job_idx: int
) -> Sequence[JobReturn]:
"""
:param job_overrides: a batch of job arguments
Expand Down
16 changes: 15 additions & 1 deletion hydra/plugins/sweeper.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,23 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Sweeper plugin interface
"""
from abc import abstractmethod
from typing import Any, List, Sequence, Optional
from typing import Any, List, Sequence, Optional, Dict, Tuple

from hydra.types import TaskFunction
from omegaconf import DictConfig
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from pathlib import Path
from typing import Any, Dict, List, Sequence
from typing import Any, Dict, Union, List, Sequence

from hydra.core.hydra_config import HydraConfig
from hydra.core.singleton import Singleton
Expand All @@ -12,9 +12,11 @@
run_job,
setup_globals,
)
from hydra.plugins.experiment_sequence import ExperimentSequence
from hydra.types import HydraContext, TaskFunction
from joblib import Parallel, delayed # type: ignore
from omegaconf import DictConfig, open_dict
import multiprocessing as mp

from .joblib_launcher import JoblibLauncher

Expand Down Expand Up @@ -63,13 +65,22 @@ def process_joblib_cfg(joblib_cfg: Dict[str, Any]) -> None:
pass


def _batch_sequence(sequence, batch_size=1):
while True:
overrides = [experiment_config for _, experiment_config in zip(range(batch_size), sequence)]
if overrides:
yield overrides
if len(overrides) != batch_size:
raise StopIteration


def launch(
launcher: JoblibLauncher,
job_overrides: Sequence[Sequence[str]],
job_overrides: Union[Sequence[Sequence[str]], ExperimentSequence],
initial_job_idx: int,
) -> Sequence[JobReturn]:
"""
:param job_overrides: a List of List<String>, where each inner list is the arguments for one job run.
:param job_overrides: an Iterable of List<String>, where each inner list is the arguments for one job run.
:param initial_job_idx: Initial job idx in batch.
:return: an array of return values from run_job with indexes corresponding to the input list indexes.
"""
Expand All @@ -87,30 +98,54 @@ def launch(
joblib_cfg = launcher.joblib
joblib_cfg["backend"] = "loky"
process_joblib_cfg(joblib_cfg)

log.info(
"Joblib.Parallel({}) is launching {} jobs".format(
",".join([f"{k}={v}" for k, v in joblib_cfg.items()]),
len(job_overrides),
)
)
log.info("Launching jobs, sweep output dir : {}".format(sweep_dir))
for idx, overrides in enumerate(job_overrides):
log.info("\t#{} : {}".format(idx, " ".join(filter_overrides(overrides))))

singleton_state = Singleton.get_state()

runs = Parallel(**joblib_cfg)(
delayed(execute_job)(
initial_job_idx + idx,
overrides,
launcher.hydra_context,
launcher.config,
launcher.task_function,
singleton_state,
if isinstance(job_overrides, ExperimentSequence):
log.info(
"Joblib.Parallel({}) is launching {} jobs".format(
",".join([f"{k}={v}" for k, v in joblib_cfg.items()]),
'generator of',
)
)
batch_size = v if (v := joblib_cfg['n_jobs']) != -1 else mp.cpu_count()
runs = []
overrides = []
for idx, overrides in enumerate(_batch_sequence(job_overrides, batch_size)):
results = Parallel(**joblib_cfg)(
delayed(execute_job)(
initial_job_idx + idx,
override,
launcher.hydra_context,
launcher.config,
launcher.task_function,
singleton_state,
)
for override in overrides
)
for experiment_result in zip(overrides, results):
job_overrides.update_sequence(experiment_result)
else:
log.info(
"Joblib.Parallel({}) is launching {} jobs".format(
",".join([f"{k}={v}" for k, v in joblib_cfg.items()]),
len(job_overrides),
)
)
log.info("Launching jobs, sweep output dir : {}".format(sweep_dir))
for idx, overrides in enumerate(job_overrides):
log.info("\t#{} : {}".format(idx, " ".join(filter_overrides(overrides))))

runs = Parallel(**joblib_cfg)(
delayed(execute_job)(
initial_job_idx + idx,
overrides,
launcher.hydra_context,
launcher.config,
launcher.task_function,
singleton_state,
)
for idx, overrides in enumerate(job_overrides)
)
for idx, overrides in enumerate(job_overrides)
)

assert isinstance(runs, List)
for run in runs:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
from typing import Any, Optional, Sequence
from typing import Any, Optional, Sequence, Union

from hydra.core.utils import JobReturn
from hydra.plugins.launcher import Launcher
from hydra.plugins.experiment_sequence import ExperimentSequence
from hydra.types import HydraContext, TaskFunction
from omegaconf import DictConfig

Expand Down Expand Up @@ -38,7 +39,7 @@ def setup(
self.hydra_context = hydra_context

def launch(
self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int
self, job_overrides: Union[Sequence[Sequence[str]], ExperimentSequence], initial_job_idx: int
) -> Sequence[JobReturn]:
from . import _core

Expand Down
3 changes: 3 additions & 0 deletions plugins/hydra_loky_launcher/MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
global-exclude *.pyc
global-exclude __pycache__
recursive-include hydra_plugins/* *.yaml py.typed
1 change: 1 addition & 0 deletions plugins/hydra_loky_launcher/NEWS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

3 changes: 3 additions & 0 deletions plugins/hydra_loky_launcher/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Hydra loky Launcher
Provides a [loky](link) based Hydra Launcher supporting parallel worker pool execution.

9 changes: 9 additions & 0 deletions plugins/hydra_loky_launcher/example/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
defaults:
- override hydra/launcher: loky

task: 1

hydra:
launcher:
# override the number of jobs for loky
max_workers: 10
20 changes: 20 additions & 0 deletions plugins/hydra_loky_launcher/example/my_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import os
import time

import hydra
from omegaconf import DictConfig

log = logging.getLogger(__name__)


@hydra.main(config_name="config")
def my_app(cfg: DictConfig) -> None:
log.info(f"Process ID {os.getpid()} executing task {cfg.task} ...")

time.sleep(1)


if __name__ == "__main__":
my_app()
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

__version__ = "1.2.0"
Loading

0 comments on commit 421293e

Please sign in to comment.