add PySparkOvertimeMonitor to avoid exceeding time budget #923

Merged Feb 24, 2023 · 21 commits · Changes from 17 commits
2 changes: 1 addition & 1 deletion .gitignore
@@ -1,4 +1,4 @@
# Project
# Project
/.vs
.vscode

10 changes: 10 additions & 0 deletions flaml/automl/automl.py
@@ -673,6 +673,8 @@ def custom_metric(
on disk when deleting automl. By default the checkpoint is preserved.
early_stop: boolean, default=False | Whether to stop early if the
search is considered to converge.
force_cancel: boolean, default=False | Whether to forcibly cancel Spark jobs if the
search time exceeds the time budget.
append_log: boolean, default=False | Whether to directly append the log
records to the input log file if it exists.
auto_augment: boolean, default=True | Whether to automatically
@@ -782,6 +784,7 @@ def custom_metric(
settings["keep_search_state"] = settings.get("keep_search_state", False)
settings["preserve_checkpoint"] = settings.get("preserve_checkpoint", True)
settings["early_stop"] = settings.get("early_stop", False)
settings["force_cancel"] = settings.get("force_cancel", False)
settings["append_log"] = settings.get("append_log", False)
settings["min_sample_size"] = settings.get("min_sample_size", MIN_SAMPLE_TRAIN)
settings["use_ray"] = settings.get("use_ray", False)
@@ -2204,6 +2207,7 @@ def fit(
keep_search_state=None,
preserve_checkpoint=True,
early_stop=None,
force_cancel=None,
append_log=None,
auto_augment=None,
min_sample_size=None,
@@ -2393,6 +2397,7 @@ def custom_metric(
on disk when deleting automl. By default the checkpoint is preserved.
early_stop: boolean, default=False | Whether to stop early if the
search is considered to converge.
force_cancel: boolean, default=False | Whether to forcibly cancel the PySpark job if it runs over the time budget.
append_log: boolean, default=False | Whether to directly append the log
records to the input log file if it exists.
auto_augment: boolean, default=True | Whether to automatically
@@ -2595,6 +2600,9 @@ def cv_score_agg_func(val_loss_folds, log_metrics_folds):
early_stop = (
self._settings.get("early_stop") if early_stop is None else early_stop
)
force_cancel = (
self._settings.get("force_cancel") if force_cancel is None else force_cancel
)
# no search budget is provided?
no_budget = time_budget < 0 and max_iter is None and not early_stop
append_log = (
@@ -2645,6 +2653,7 @@ def cv_score_agg_func(val_loss_folds, log_metrics_folds):
self._n_concurrent_trials = n_concurrent_trials
self._early_stop = early_stop
self._use_spark = use_spark
self._force_cancel = force_cancel
self._use_ray = use_ray
# use the following condition if we have an estimation of average_trial_time and average_trial_overhead
# self._use_ray = use_ray or n_concurrent_trials > ( average_trial_time + average_trial_overhead) / (average_trial_time)
@@ -3171,6 +3180,7 @@ def _search_parallel(self):
verbose=max(self.verbose - 2, 0),
use_ray=False,
use_spark=True,
force_cancel=self._force_cancel,
# raise_on_failed_trial=False,
# keep_checkpoints_num=1,
# checkpoint_score_attr="min-val_loss",
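For illustration, a minimal sketch of how the new flag is intended to be used through the `AutoML` API (the dataset, budget, and settings here are illustrative and mirror the test added in this PR):

```python
from sklearn.datasets import load_iris

from flaml import AutoML

X, y = load_iris(return_X_y=True)
automl = AutoML()
automl.fit(
    X,
    y,
    task="classification",
    estimator_list=["lgbm"],
    time_budget=15,  # seconds
    n_concurrent_trials=2,
    use_spark=True,  # run trials as Spark jobs
    force_cancel=True,  # new: cancel Spark jobs that outlive the budget
)
```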
124 changes: 121 additions & 3 deletions flaml/tune/spark/utils.py
@@ -1,23 +1,28 @@
import os
import logging
from functools import partial, lru_cache
import os
import textwrap
import threading
import time
from functools import lru_cache, partial


logger = logging.getLogger(__name__)
logger_formatter = logging.Formatter(
"[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s", "%m-%d %H:%M:%S"
)

try:
import pyspark
from pyspark.sql import SparkSession
from pyspark.util import VersionUtils
import pyspark
import py4j

_have_spark = True
_spark_major_minor_version = VersionUtils.majorMinorVersion(pyspark.__version__)
except ImportError as e:
logger.debug("Could not import pyspark: %s", e)
_have_spark = False
py4j = None
_spark_major_minor_version = (0, 0)


@@ -187,3 +192,116 @@ def get_broadcast_data(broadcast_data):
if _have_spark and isinstance(broadcast_data, pyspark.broadcast.Broadcast):
broadcast_data = broadcast_data.value
return broadcast_data


class PySparkOvertimeMonitor:
"""A context manager class to monitor if the PySpark job is overtime.
Example:

```python
with PySparkOvertimeMonitor(time_start, time_budget_s, force_cancel, parallel=parallel):
results = parallel(
delayed(evaluation_function)(trial_to_run.config)
for trial_to_run in trials_to_run
)
```

"""

def __init__(
self,
start_time,
time_budget_s,
force_cancel=False,
cancel_func=None,
parallel=None,
sc=None,
):
"""Constructor.

Specify the time budget and start time of the PySpark job, and specify how to cancel them.

Args:
Args related to monitoring:
start_time: float | The start time of the PySpark job.
time_budget_s: float | The time budget of the PySpark job in seconds.
force_cancel: boolean, default=False | Whether to forcibly cancel the PySpark job if it runs over the time budget.

Args related to how to cancel the PySpark job:
(Only one of the following args takes effect, with priority from top to bottom.)
cancel_func: function | A function to cancel the PySpark job.
parallel: joblib.parallel.Parallel | Specify this if using joblib_spark as a parallel backend. It will call parallel._backend.terminate() to cancel the jobs.
sc: pyspark.SparkContext object | You can pass a specific SparkContext.

If all three args are None, the monitor calls pyspark.SparkContext.getOrCreate().cancelAllJobs() to cancel the jobs.


"""
self._time_budget_s = time_budget_s
self._start_time = start_time
self._force_cancel = force_cancel
# TODO: add support for non-spark scenario
if self._force_cancel and _have_spark:
self._monitor_daemon = None
self._finished_flag = False
self._cancel_flag = False
self.sc = None
if cancel_func:
self.__cancel_func = cancel_func
elif parallel:
self.__cancel_func = parallel._backend.terminate
elif sc:
self.sc = sc
self.__cancel_func = self.sc.cancelAllJobs
else:
self.__cancel_func = pyspark.SparkContext.getOrCreate().cancelAllJobs
# logger.info(self.__cancel_func)

def _monitor_overtime(self):
"""The lifecycle function for monitor thread."""
if self._time_budget_s is None:
self.__cancel_func()
self._cancel_flag = True
return
while time.time() - self._start_time <= self._time_budget_s:
time.sleep(0.01)
if self._finished_flag:
return
self.__cancel_func()
self._cancel_flag = True
return

def _setLogLevel(self, level):
"""Set the log level of the Spark context.
Setting the level to OFF suppresses Spark's warning messages."""
if self.sc:
self.sc.setLogLevel(level)
else:
pyspark.SparkContext.getOrCreate().setLogLevel(level)

def __enter__(self):
"""Enter the context manager.
This starts a monitor thread if Spark is available and force_cancel is True."""
if self._force_cancel and _have_spark:
self._monitor_daemon = threading.Thread(target=self._monitor_overtime)
# logger.setLevel("INFO")
logger.info("monitor started")
self._setLogLevel("OFF")
self._monitor_daemon.start()

def __exit__(self, exc_type, exc_value, exc_traceback):
"""Exit the context manager.
This waits for the monitor thread to exit gracefully."""
if self._force_cancel and _have_spark:
self._finished_flag = True
self._monitor_daemon.join()
if self._cancel_flag:
print()
logger.warning("Time exceeded, canceled jobs")
# self._setLogLevel("WARN")
if not exc_type:
return True
elif exc_type == py4j.protocol.Py4JJavaError:
return True
else:
return False
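
For reference, a minimal standalone sketch of the new context manager outside of `tune.run` (the SparkSession setup and the deliberately slow job are illustrative, not part of this PR; passing `sc` makes the monitor cancel via `sc.cancelAllJobs()`):

```python
import time

from pyspark.sql import SparkSession

from flaml.tune.spark.utils import PySparkOvertimeMonitor

spark = SparkSession.builder.getOrCreate()
start = time.time()
# If the 10-second budget is exceeded, the monitor thread cancels all Spark
# jobs; the resulting Py4JJavaError is swallowed by __exit__, so control
# returns here instead of the exception propagating.
with PySparkOvertimeMonitor(start, 10, force_cancel=True, sc=spark.sparkContext):
    spark.range(10**9).toPandas()  # stand-in for a long-running Spark job
```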
18 changes: 13 additions & 5 deletions flaml/tune/tune.py
@@ -24,6 +24,7 @@
from .trial import Trial
from .result import DEFAULT_METRIC
import logging
from flaml.tune.spark.utils import PySparkOvertimeMonitor

logger = logging.getLogger(__name__)
logger.propagate = False
@@ -246,6 +247,7 @@ def run(
use_incumbent_result_in_evaluation: Optional[bool] = None,
log_file_name: Optional[str] = None,
lexico_objectives: Optional[dict] = None,
force_cancel: Optional[bool] = False,
**ray_args,
):
"""The trigger for HPO.
@@ -730,10 +732,14 @@ def easy_objective(config):
logger.debug(
f"Configs of Trials to run: {[trial_to_run.config for trial_to_run in trials_to_run]}"
)
results = parallel(
delayed(evaluation_function)(trial_to_run.config)
for trial_to_run in trials_to_run
)
results = None
with PySparkOvertimeMonitor(
time_start, time_budget_s, force_cancel, parallel=parallel
):
results = parallel(
delayed(evaluation_function)(trial_to_run.config)
for trial_to_run in trials_to_run
)
# results = [evaluation_function(trial_to_run.config) for trial_to_run in trials_to_run]
while results:
result = results.pop(0)
@@ -803,7 +809,9 @@ def easy_objective(config):
num_trials += 1
if verbose:
logger.info(f"trial {num_trials} config: {trial_to_run.config}")
result = evaluation_function(trial_to_run.config)
result = None
with PySparkOvertimeMonitor(time_start, time_budget_s, force_cancel):
result = evaluation_function(trial_to_run.config)
if result is not None:
if isinstance(result, dict):
if result:
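As a usage sketch at the `tune.run` level, the new `force_cancel` flag plugs in roughly like this (the toy objective and search space are illustrative):

```python
from flaml import tune


def easy_objective(config):
    # Toy objective; a real one would train and evaluate a model.
    return {"score": (config["x"] - 0.5) ** 2}


analysis = tune.run(
    easy_objective,
    config={"x": tune.uniform(0, 1)},
    metric="score",
    mode="min",
    num_samples=100,  # upper bound on trials; the time budget usually binds first
    time_budget_s=30,
    use_spark=True,  # parallelize trials via joblib-spark
    force_cancel=True,  # new: cancel running Spark jobs at the deadline
)
```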
32 changes: 32 additions & 0 deletions test/spark/custom_mylearner.py
@@ -2,6 +2,7 @@

custom_code = """
from flaml import tune
import time
from flaml.automl.model import LGBMEstimator, XGBoostSklearnEstimator, SKLearnEstimator
from flaml.automl.data import CLASSIFICATION, get_output_from_log

@@ -91,6 +92,7 @@ def search_space(cls, **params):
}



def custom_metric(
X_val,
y_val,
@@ -119,6 +121,36 @@ def custom_metric(
"train_loss": train_loss,
"pred_time": pred_time,
}

def lazy_metric(
X_val,
y_val,
estimator,
labels,
X_train,
y_train,
weight_val=None,
weight_train=None,
config=None,
groups_val=None,
groups_train=None,
):
from sklearn.metrics import log_loss


time.sleep(2)
start = time.time()
y_pred = estimator.predict_proba(X_val)
pred_time = (time.time() - start) / len(X_val)
val_loss = log_loss(y_val, y_pred, labels=labels, sample_weight=weight_val)
y_pred = estimator.predict_proba(X_train)
train_loss = log_loss(y_train, y_pred, labels=labels, sample_weight=weight_train)
alpha = 0.5
return val_loss * (1 + alpha) - alpha * train_loss, {
"val_loss": val_loss,
"train_loss": train_loss,
"pred_time": pred_time,
}
"""

_ = broadcast_code(custom_code=custom_code)
70 changes: 70 additions & 0 deletions test/spark/test_overtime.py
@@ -0,0 +1,70 @@
import os
import time

import numpy as np
import pyspark
import pytest
from sklearn.datasets import load_iris

from flaml import AutoML
from flaml.tune.spark.utils import check_spark

try:
from test.spark.custom_mylearner import *
except ImportError:
from custom_mylearner import *

from flaml.tune.spark.mylearner import lazy_metric

os.environ["FLAML_MAX_CONCURRENT"] = "10"

spark = pyspark.sql.SparkSession.builder.appName("App4OvertimeTest").getOrCreate()
spark_available, _ = check_spark()
skip_spark = not spark_available

pytestmark = pytest.mark.skipif(
skip_spark, reason="Spark is not installed. Skip all spark tests."
)


def test_overtime():
time_budget = 15
df, y = load_iris(return_X_y=True, as_frame=True)
df["label"] = y
automl_experiment = AutoML()
automl_settings = {
"dataframe": df,
"label": "label",
"time_budget": time_budget,
"eval_method": "cv",
"metric": lazy_metric,
"task": "classification",
"log_file_name": "test/iris_custom.log",
"log_training_metric": True,
"log_type": "all",
"n_jobs": 1,
"model_history": True,
"sample_weight": np.ones(len(y)),
"pred_time_limit": 1e-5,
"estimator_list": ["lgbm"],
"n_concurrent_trials": 2,
"use_spark": True,
"force_cancel": True,
}
start_time = time.time()
automl_experiment.fit(**automl_settings)
elapsed_time = time.time() - start_time
print(
"time budget: {:.2f}s, actual elapsed time: {:.2f}s".format(
time_budget, elapsed_time
)
)
assert abs(elapsed_time - time_budget) < 2
print(automl_experiment.predict(df))
print(automl_experiment.model)
print(automl_experiment.best_iteration)
print(automl_experiment.best_estimator)


if __name__ == "__main__":
test_overtime()