Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
controller_metrics:
- name: training_loss_window
class: HistoryBasedMetric
arguments:
window_size: 1
controllers:
- name: epoch_level_stop_on_training_loss_below_threshold
triggers:
- on_step_end
rule: len(training_loss_window["training_loss"]["loss"]) == training_loss_window["window_size"] and training_loss_window["training_loss"]["loss"][0] < 2.2 and training_loss_window["training_loss"]["epoch"][0] > 2
config:
trigger_log_level: warning
operations:
- hfcontrols.should_training_stop
31 changes: 25 additions & 6 deletions tuning/trainercontroller/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,14 @@
CONTROLLER_CONFIG_KEY = "config"
CONTROLLER_PATIENCE_CONFIG_KEY = "patience"
CONTROLLER_TRIGGERS_KEY = "triggers"
CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL = "trigger_log_level"
CONTROLLER_OPERATIONS_KEY = OPERATIONS_KEY

# Default operations / metrics to register
# Default values
DEFAULT_OPERATIONS = {"operations": [{"name": "hfcontrols", "class": "HFControls"}]}
DEFAULT_METRICS = {}
DEFAULT_CONFIG = {}
DEFAULT_TRIGGER_LOG_LEVEL = "debug"

# pylint: disable=too-many-instance-attributes
class TrainerControllerCallback(TrainerCallback):
Expand Down Expand Up @@ -250,15 +253,14 @@ def _take_control_actions(self, event_name: str, **kwargs):
continue
if rule_succeeded:
for operation_action in control_action.operation_actions:
logger.info(
"Taking [%s] action in controller [%s]",
operation_action.action,
control_action.name,
)
operation_action.instance.act(
action=operation_action.action,
event_name=event_name,
tc_metrics=self.metrics,
control_name=control_action.name,
log_level=control_action.config[
CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL
],
**kwargs,
)

Expand Down Expand Up @@ -303,6 +305,7 @@ def on_init_end(
kwargs["state"] = state
kwargs["control"] = control

log_levels = logging.get_log_levels_dict()
# Check if there any metrics listed in the configuration
if (
CONTROLLER_METRICS_KEY not in self.trainer_controller_config
Expand Down Expand Up @@ -399,8 +402,24 @@ def on_init_end(
),
operation_actions=[],
)
config_log_level_str = DEFAULT_TRIGGER_LOG_LEVEL
if CONTROLLER_CONFIG_KEY in controller:
control.config = controller[CONTROLLER_CONFIG_KEY]
config_log_level_str = control.config.get(
CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL, config_log_level_str
)
if config_log_level_str not in log_levels:
logger.warning(
"Incorrect trigger log-level [%s] specified in the config."
" Defaulting to 'debug' level",
config_log_level_str,
)
config_log_level_str = DEFAULT_TRIGGER_LOG_LEVEL
else:
control.config = DEFAULT_CONFIG
control.config[CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL] = log_levels[
config_log_level_str
]
if CONTROLLER_PATIENCE_CONFIG_KEY in controller:
control.patience = PatienceControl(
**controller[CONTROLLER_PATIENCE_CONFIG_KEY]
Expand Down
19 changes: 18 additions & 1 deletion tuning/trainercontroller/operations/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
import inspect
import re

# Third Party
from transformers.utils import logging

logger = logging.get_logger(__name__)


class Operation(metaclass=abc.ABCMeta):
"""Base class for operations"""
Expand Down Expand Up @@ -32,15 +37,27 @@ def validate(self, action: str) -> bool:
"""
return action in self.valid_actions

def act(self, action: str, **kwargs):
def act(
self, action: str, event_name: str, control_name: str, log_level: int, **kwargs
):
"""Validates the action and invokes it.

Args:
action: str. String depicting the action.
event_name: str. Event name triggering the act.
control_name: str. Name of the controller defining the act.
log_level: int. Log level for triggering the log.
kwargs: List of arguments (key, value)-pairs.
"""
if not self.validate(action):
raise ValueError(f"Invalid operation {action}")
logger.log(
log_level,
"Taking [%s] action in controller [%s] triggered at event [%s]",
action,
control_name,
event_name,
)
self.valid_actions[action](**kwargs)

def get_actions(self) -> list[str]:
Expand Down