Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add first version of a Comet integration #1

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[settings]
profile=black
known_third_party=wandb
known_third_party=wandb,comet_ml
3 changes: 3 additions & 0 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
Expand Down Expand Up @@ -421,6 +422,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):

setup_mlflow_env_vars(cfg)

setup_comet_env_vars(cfg)

return cfg


Expand Down
12 changes: 12 additions & 0 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
EvalFirstStepCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoCometCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
SaveModelCallback,
Expand Down Expand Up @@ -1104,6 +1105,10 @@ def get_callbacks(self) -> List[TrainerCallback]:
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_comet:
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)

return callbacks

Expand Down Expand Up @@ -1172,6 +1177,11 @@ def get_post_trainer_create_callbacks(self, trainer):
trainer, self.tokenizer, "mlflow"
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.use_comet and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "comet_ml"
)
callbacks.append(LogPredictionCallback(self.cfg))

if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
Expand Down Expand Up @@ -1423,6 +1433,8 @@ def build(self, total_num_steps):
report_to.append("mlflow")
if self.cfg.use_tensorboard:
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")

training_arguments_kwargs["report_to"] = report_to
training_arguments_kwargs["run_name"] = (
Expand Down
37 changes: 37 additions & 0 deletions src/axolotl/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Dict, List

import comet_ml
import evaluate
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -747,6 +748,13 @@ def log_table_from_dataloader(name: str, table_dataloader):
artifact_file="PredictionsVsGroundTruth.json",
tracking_uri=tracking_uri,
)
elif logger == "comet_ml":
experiment = comet_ml.get_running_experiment()
if experiment:
experiment.log_table(
f"{name} - Predictions vs Ground Truth.csv",
pd.DataFrame(table_data),
)

if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader)
Expand Down Expand Up @@ -790,6 +798,35 @@ def on_train_begin(
return control


class SaveAxolotlConfigtoCometCallback(TrainerCallback):
"""Callback to save axolotl config to comet"""

def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path

def on_train_begin(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
comet_experiment = comet_ml.start(source="axolotl")
comet_experiment._log_asset(
self.axolotl_config_path,
file_name="axolotl-config",
framework="axolotl",
)
LOG.info(
"The Axolotl config has been saved to the Comet Experiment under assets."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to Comet: {err}")
return control


class SaveModelCallback(TrainerCallback):
"""Callback to save model on train end"""

Expand Down
93 changes: 93 additions & 0 deletions src/axolotl/utils/comet_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Module for wandb utilities"""

import logging
import os

from axolotl.utils.dict import DictDefault

LOG = logging.getLogger("axolotl.utils.comet_")

COMET_ENV_MAPPING_OVERRIDE = {
"comet_mode": "COMET_START_MODE",
"comet_online": "COMET_START_ONLINE",
}
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = {
"auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS",
"auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE",
"auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS",
"auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD",
"auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS",
"auto_log_co2": "COMET_AUTO_LOG_CO2",
"auto_metric_logging": "COMET_AUTO_LOG_METRICS",
"auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE",
"auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER",
"auto_param_logging": "COMET_AUTO_LOG_PARAMETERS",
"comet_disabled": "COMET_AUTO_LOG_DISABLE",
"display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL",
"distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER",
"log_code": "COMET_AUTO_LOG_CODE",
"log_env_cpu": "COMET_AUTO_LOG_ENV_CPU",
"log_env_details": "COMET_AUTO_LOG_ENV_DETAILS",
"log_env_disk": "COMET_AUTO_LOG_ENV_DISK",
"log_env_gpu": "COMET_AUTO_LOG_ENV_GPU",
"log_env_host": "COMET_AUTO_LOG_ENV_HOST",
"log_env_network": "COMET_AUTO_LOG_ENV_NETWORK",
"log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA",
"log_git_patch": "COMET_AUTO_LOG_GIT_PATCH",
"log_graph": "COMET_AUTO_LOG_GRAPH",
"name": "COMET_START_EXPERIMENT_NAME",
"offline_directory": "COMET_OFFLINE_DIRECTORY",
"parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS",
"tags": "COMET_START_EXPERIMENT_TAGS",
}


def python_value_to_environ_value(python_value):
if isinstance(python_value, bool):
if python_value is True:
return "true"

return "false"

if isinstance(python_value, int):
return str(python_value)

if isinstance(python_value, list): # Comet only have one list of string parameter
return ",".join(map(str, python_value))

return python_value


def setup_comet_env_vars(cfg: DictDefault):
# TODO, we need to convert Axolotl configuration to environment variables
# as Transformers integration are call first and would create an
# Experiment first

for key in cfg.keys():
if key.startswith("comet_") and key != "comet_experiment_config":
value = cfg.get(key, "")

if value is not None and value != "":
env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper())
final_value = python_value_to_environ_value(value)
os.environ[env_variable_name] = final_value

if cfg.comet_experiment_config:
for key, value in cfg.comet_experiment_config.items():
if value is not None and value != "":
config_env_variable_name = (
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key)
)

if config_env_variable_name is None:
LOG.warning(
f"Unknown Comet Experiment Config name {key}, ignoring it"
)
continue

final_value = python_value_to_environ_value(value)
os.environ[config_env_variable_name] = final_value

# Enable comet if project name is present
if cfg.comet_project_name and len(cfg.comet_project_name) > 0:
cfg.use_comet = True
14 changes: 14 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,19 @@ def check_wandb_run(cls, data):
return data


class CometConfig(BaseModel):
"""Comet configuration subset"""

use_comet: Optional[bool] = None
comet_api_key: Optional[str] = None
comet_workspace: Optional[str] = None
comet_project_name: Optional[str] = None
comet_experiment_key: Optional[str] = None
comet_mode: Optional[str] = None
comet_online: Optional[bool] = None
comet_experiment_config: Optional[Dict[str, Any]] = None


class GradioConfig(BaseModel):
"""Gradio configuration subset"""

Expand All @@ -504,6 +517,7 @@ class AxolotlInputConfig(
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
CometConfig,
LISAConfig,
GradioConfig,
RemappedParameters,
Expand Down
100 changes: 100 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest
from pydantic import ValidationError

from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import validate_config
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
from axolotl.utils.dict import DictDefault
Expand Down Expand Up @@ -1329,3 +1330,102 @@ def test_wandb_set_disabled(self, minimal_cfg):

os.environ.pop("WANDB_PROJECT", None)
os.environ.pop("WANDB_DISABLED", None)


class TestValidationComet(BaseValidation):
"""
Validation test for comet
"""

def test_comet_sets_env(self, minimal_cfg):
comet_config = {
"comet_api_key": "foo",
"comet_workspace": "some_workspace",
"comet_project_name": "some_project",
"comet_experiment_key": "some_experiment_key",
"comet_mode": "get_or_create",
"comet_online": False,
"comet_experiment_config": {
"auto_histogram_activation_logging": False,
"auto_histogram_epoch_rate": 2,
"auto_histogram_gradient_logging": True,
"auto_histogram_tensorboard_logging": False,
"auto_histogram_weight_logging": True,
"auto_log_co2": False,
"auto_metric_logging": True,
"auto_metric_step_rate": 15,
"auto_output_logging": False,
"auto_param_logging": True,
"comet_disabled": False,
"display_summary_level": 2,
"distributed_node_identifier": "some_distributed_node_identifier",
"log_code": True,
"log_env_cpu": False,
"log_env_details": True,
"log_env_disk": False,
"log_env_gpu": True,
"log_env_host": False,
"log_env_network": True,
"log_git_metadata": False,
"log_git_patch": True,
"log_graph": False,
"name": "some_name",
"offline_directory": "some_offline_directory",
"parse_args": True,
"tags": ["tag1", "tag2"],
},
}

cfg = DictDefault(comet_config) | minimal_cfg

new_cfg = validate_config(cfg)

setup_comet_env_vars(new_cfg)

comet_env = {
key: value for key, value in os.environ.items() if key.startswith("COMET_")
}

assert (
len(comet_env)
== len(comet_config) + len(comet_config["comet_experiment_config"]) - 1
)

assert comet_env == {
"COMET_API_KEY": "foo",
"COMET_AUTO_LOG_CLI_ARGUMENTS": "true",
"COMET_AUTO_LOG_CO2": "false",
"COMET_AUTO_LOG_CODE": "true",
"COMET_AUTO_LOG_DISABLE": "false",
"COMET_AUTO_LOG_ENV_CPU": "false",
"COMET_AUTO_LOG_ENV_DETAILS": "true",
"COMET_AUTO_LOG_ENV_DISK": "false",
"COMET_AUTO_LOG_ENV_GPU": "true",
"COMET_AUTO_LOG_ENV_HOST": "false",
"COMET_AUTO_LOG_ENV_NETWORK": "true",
"COMET_AUTO_LOG_GIT_METADATA": "false",
"COMET_AUTO_LOG_GIT_PATCH": "true",
"COMET_AUTO_LOG_GRAPH": "false",
"COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS": "false",
"COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE": "2",
"COMET_AUTO_LOG_HISTOGRAM_GRADIENTS": "true",
"COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD": "false",
"COMET_AUTO_LOG_HISTOGRAM_WEIGHTS": "true",
"COMET_AUTO_LOG_METRIC_STEP_RATE": "15",
"COMET_AUTO_LOG_METRICS": "true",
"COMET_AUTO_LOG_OUTPUT_LOGGER": "false",
"COMET_AUTO_LOG_PARAMETERS": "true",
"COMET_DISPLAY_SUMMARY_LEVEL": "2",
"COMET_DISTRIBUTED_NODE_IDENTIFIER": "some_distributed_node_identifier",
"COMET_EXPERIMENT_KEY": "some_experiment_key",
"COMET_OFFLINE_DIRECTORY": "some_offline_directory",
"COMET_PROJECT_NAME": "some_project",
"COMET_START_EXPERIMENT_NAME": "some_name",
"COMET_START_EXPERIMENT_TAGS": "tag1,tag2",
"COMET_START_MODE": "get_or_create",
"COMET_START_ONLINE": "false",
"COMET_WORKSPACE": "some_workspace",
}

for key in comet_env.keys():
os.environ.pop(key, None)