diff --git a/tuning/trainercontroller/controllermetrics/__init__.py b/tuning/trainercontroller/controllermetrics/__init__.py index 1f9f76705..6a8165852 100644 --- a/tuning/trainercontroller/controllermetrics/__init__.py +++ b/tuning/trainercontroller/controllermetrics/__init__.py @@ -23,6 +23,7 @@ from .history_based_metrics import HistoryBasedMetric from .loss import Loss from .trainingstate import TrainingState +from tuning.trainercontroller.controllermetrics.per_process_state import PerProcessState # List of metric handlers handlers = [] @@ -39,6 +40,7 @@ def register(cl: Type): # Register the default metric handlers in this package here register(TrainingState) +register(PerProcessState) register(EvalMetrics) register(Loss) register(HistoryBasedMetric) diff --git a/tuning/trainercontroller/controllermetrics/metrics.yaml b/tuning/trainercontroller/controllermetrics/metrics.yaml new file mode 100644 index 000000000..d3a8a32de --- /dev/null +++ b/tuning/trainercontroller/controllermetrics/metrics.yaml @@ -0,0 +1,9 @@ +controller-metrics: + - name: loss + class: Loss + - name: state + class: TrainingState + - name: eval_metrics + class: EvalMetrics + - name: per_process_state + class: PerProcessState diff --git a/tuning/trainercontroller/controllermetrics/per_process_state.py b/tuning/trainercontroller/controllermetrics/per_process_state.py new file mode 100644 index 000000000..58a96de37 --- /dev/null +++ b/tuning/trainercontroller/controllermetrics/per_process_state.py @@ -0,0 +1,77 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +from typing import Any + +# Third Party +from transformers import TrainerState +import torch + +# Local +from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler + + +class PerProcessState(MetricHandler): + """Implements the controller metric which exposes the per process state""" + + def __init__(self, **kwargs): + """Initializes the metric handler, by registering the event \ + list and arguments with base handler. + + Args: + kwargs: List of arguments (key, value)-pairs + """ + super().__init__( + events=[ + "on_init_end", + "on_step_end", + "on_epoch_begin", + "on_epoch_end", + "on_prediction_step", + "on_predict", + "on_log", + "on_train_end", + "on_train_begin", + "on_evaluate", + "on_save", + ], + **kwargs, + ) + + def validate(self) -> bool: + """Validate the training arguments (e.g logging_steps) are \ + compatible with the computation of this metric. + + Returns: + bool + """ + return True + + def compute(self, _: TrainerState = None, **kwargs) -> Any: + """Exposes the trainer state. + + Args: + state: TrainerState object + kwargs: Remaining event arguments + + Returns: + dict. Trainer state as a dictionary + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return {"rank": torch.distributed.get_rank()} + return {"rank": None}