Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/trainercontroller_configs/loss.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller
triggers:
- on_log
rule: loss < 1.0
rule: training_loss["loss"] < 1.0
operations:
- hfcontrols.should_training_stop
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_custom_operation.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
operations:
- name: custom_operation
Expand All @@ -8,6 +8,6 @@ controllers:
- name: loss_controller_custom_operation
triggers:
- on_log
rule: loss < 1.0
rule: training_loss['loss'] < 1.0
operations:
- custom_operation.should_perform_action_xyz
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
operations:
- name: custom_operation
Expand All @@ -8,6 +8,6 @@ controllers:
- name: loss_controller_custom_operation_invalid_action
triggers:
- on_log
rule: loss < 1.0
rule: training_loss["loss"] < 1.0
operations:
- custom_operation.should_
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_invalid_metric.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
controller_metrics:
- name: loss
- name: training_loss
class: MissingMetricClass
controllers:
- name: loss_controller_invalid_metric
triggers:
- on_log
rule: loss < 1.0
rule: training_loss['loss'] < 1.0
operations:
- hfcontrols.should_training_stop
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_invalid_operation.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller_invalid_operation
triggers:
- on_log
rule: loss < 1.0
rule: training_loss['loss'] < 1.0
operations:
- missingop.should_training_stop
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller_invalid_operation_action
triggers:
- on_log
rule: loss < 1.0
rule: training_loss['loss'] < 1.0
operations:
- hfcontrols.missingaction
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_invalid_trigger.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller_invalid_trigger
triggers:
- log_it_all_incorrect_trigger_name
rule: loss < 1.0
rule: training_loss['loss'] < 1.0
operations:
- hfcontrols.should_training_stop
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_on_threshold.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller
triggers:
- on_log
rule: loss < 1.0
rule: training_loss['loss'] < 1.0
operations:
- hfcontrols.should_training_stop
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
controller_metrics:
- name: state
class: TrainingState
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller
triggers:
- on_log
rule: loss < 2 and state["epoch"] >= 0.5
rule: training_loss['loss'] < 2 and state["epoch"] >= 0.5
operations:
- hfcontrols.should_training_stop
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_unavailable_metric.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller_unavailable_metric
triggers:
- on_step_end
rule: loss < 1.0
rule: training_loss['loss'] < 1.0
operations:
- hfcontrols.should_training_stop
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller_wrong_input_rule
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller_wrong_os_rule
Expand Down
7 changes: 0 additions & 7 deletions tests/trainercontroller/custom_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,6 @@
class CustomOperation(Operation):
"""Implements a custom operation for testing"""

def __init__(self, **_):
"""Initializes the custom operation class.
Args:
kwargs: List of arguments (key, value)-pairs
"""
super().__init__()

def should_perform_action_xyz(self, control: TrainerControl, **_):
"""This method performs a set training stop flag action.

Expand Down
7 changes: 0 additions & 7 deletions tests/trainercontroller/custom_operation_invalid_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,6 @@
class CustomOperationInvalidAction(Operation):
"""Implements a custom operation for testing"""

def __init__(self, **_):
"""Initializes the custom operation class.
Args:
kwargs: List of arguments (key, value)-pairs
"""
super().__init__()

def should_(self, control: TrainerControl, **_):
"""This method defines an action with an invalid name.

Expand Down
2 changes: 1 addition & 1 deletion tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def train(
trainer_controller_args.trainer_controller_config_file is not None
):
tc_callback = TrainerControllerCallback(
trainer_controller_args.trainer_controller_config_file
trainer_controller_args.trainer_controller_config_file,
)
trainer_callbacks.append(tc_callback)

Expand Down
1 change: 1 addition & 0 deletions tuning/trainercontroller/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def _take_control_actions(self, event_name: str, **kwargs):
operation_action.instance.act(
action=operation_action.action,
event_name=event_name,
tc_metrics=self.metrics,
**kwargs,
)

Expand Down
2 changes: 1 addition & 1 deletion tuning/trainercontroller/controllermetrics/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@ def compute(self, state: TrainerState = None, **kwargs) -> Any:
log = state.log_history[i]
if "loss" not in log:
continue
return float(log["loss"])
return log
2 changes: 1 addition & 1 deletion tuning/trainercontroller/operations/hfcontrols.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, **kwargs):
for control_field in fields(TrainerControl):
if re.search(r"^should_.+", control_field.name) is not None:
setattr(self, control_field.name, self.control_action)
super().__init__()
super().__init__(**kwargs)

def control_action(self, control: TrainerControl, **kwargs):
"""This method peeks into the stack-frame of the caller to get the action that triggered
Expand Down
4 changes: 3 additions & 1 deletion tuning/trainercontroller/operations/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
class Operation(metaclass=abc.ABCMeta):
"""Base class for operations"""

def __init__(self):
def __init__(self, name: str, **kwargs):
"""Initializes the operation. In this init, we follow the convention that
every action name should begin with the prefix `should_`. If so, it is treated as a valid
action.
"""
self.valid_actions = {}
self.name = name
self.kwargs = kwargs
for action_name, action_method in inspect.getmembers(
self, predicate=inspect.ismethod
):
Expand Down