diff --git a/.isort.cfg b/.isort.cfg index 79067a7c9..e48779732 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,3 +1,3 @@ [settings] profile=black -known_third_party=wandb +known_third_party=wandb,comet_ml diff --git a/README.md b/README.md index c84f1cb8c..f6f4e4e80 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Features: - Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking - Works with single GPU or multiple GPUs via FSDP or Deepspeed - Easily run with Docker locally or on the cloud -- Log results and optionally checkpoints to wandb or mlflow +- Log results and optionally checkpoints to wandb, mlflow or Comet - And more! @@ -515,6 +515,22 @@ wandb_name: wandb_log_model: ``` +##### Comet Logging + +Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`. + +- wandb options +```yaml +use_comet: +comet_api_key: +comet_workspace: +comet_project_name: +comet_experiment_key: +comet_mode: +comet_online: +comet_experiment_config: +``` + ##### Special Tokens It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this: diff --git a/docs/config.qmd b/docs/config.qmd index e85999978..99a69a097 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -267,6 +267,18 @@ mlflow_tracking_uri: # URI to mlflow mlflow_experiment_name: # Your experiment name hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry +# Comet configuration if you're using it +# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`. +# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start +use_comet: # Enable or disable Comet integration. +comet_api_key: # API key for Comet. Recommended to set via `comet login`. +comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace. +comet_project_name: # Project name in Comet. Defaults to Uncategorized. +comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key. +comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration. +comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True. +comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details. + # Where to save the full-finetuned model to output_dir: ./completed-model diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 13c5b4ab5..32a8b8c41 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -31,6 +31,7 @@ from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.comet_ import setup_comet_env_vars from axolotl.utils.config import ( normalize_cfg_datasets, normalize_config, @@ -421,6 +422,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): setup_mlflow_env_vars(cfg) + setup_comet_env_vars(cfg) + return cfg diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 249398f85..79f453982 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -48,7 +48,7 @@ from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler -from axolotl.utils import is_mlflow_available +from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.callbacks import ( EvalFirstStepCallback, GPUStatsCallback, @@ -1104,6 +1104,12 @@ def get_callbacks(self) -> List[TrainerCallback]: callbacks.append( SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path) ) + if self.cfg.use_comet and is_comet_available(): + from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback + + callbacks.append( + SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path) + ) return callbacks @@ -1172,6 +1178,11 @@ def get_post_trainer_create_callbacks(self, trainer): trainer, self.tokenizer, "mlflow" ) callbacks.append(LogPredictionCallback(self.cfg)) + if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0: + LogPredictionCallback = log_prediction_callback_factory( + trainer, self.tokenizer, "comet_ml" + ) + callbacks.append(LogPredictionCallback(self.cfg)) if self.cfg.do_bench_eval: callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer)) @@ -1423,6 +1434,8 @@ def build(self, total_num_steps): report_to.append("mlflow") if self.cfg.use_tensorboard: report_to.append("tensorboard") + if self.cfg.use_comet: + report_to.append("comet_ml") training_arguments_kwargs["report_to"] = report_to training_arguments_kwargs["run_name"] = ( diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py index 99dec79f1..91545009a 100644 --- a/src/axolotl/utils/__init__.py +++ b/src/axolotl/utils/__init__.py @@ -1,8 +1,12 @@ """ Basic utils for Axolotl """ -import importlib +import importlib.util def is_mlflow_available(): return importlib.util.find_spec("mlflow") is not None + + +def is_comet_available(): + return importlib.util.find_spec("comet_ml") is not None diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 73715b06a..5ec7a5566 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -29,7 +29,7 @@ ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy -from axolotl.utils import is_mlflow_available +from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.callbacks.perplexity import Perplexity from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig @@ -747,6 +747,15 @@ def log_table_from_dataloader(name: str, table_dataloader): artifact_file="PredictionsVsGroundTruth.json", tracking_uri=tracking_uri, ) + elif logger == "comet_ml" and is_comet_available(): + import comet_ml + + experiment = comet_ml.get_running_experiment() + if experiment: + experiment.log_table( + f"{name} - Predictions vs Ground Truth.csv", + pd.DataFrame(table_data), + ) if is_main_process(): log_table_from_dataloader("Eval", eval_dataloader) diff --git a/src/axolotl/utils/callbacks/comet_.py b/src/axolotl/utils/callbacks/comet_.py new file mode 100644 index 000000000..b29f997a8 --- /dev/null +++ b/src/axolotl/utils/callbacks/comet_.py @@ -0,0 +1,43 @@ +"""Comet module for trainer callbacks""" + +import logging +from typing import TYPE_CHECKING + +import comet_ml +from transformers import TrainerCallback, TrainerControl, TrainerState + +from axolotl.utils.distributed import is_main_process + +if TYPE_CHECKING: + from axolotl.core.trainer_builder import AxolotlTrainingArguments + +LOG = logging.getLogger("axolotl.callbacks") + + +class SaveAxolotlConfigtoCometCallback(TrainerCallback): + """Callback to save axolotl config to comet""" + + def __init__(self, axolotl_config_path): + self.axolotl_config_path = axolotl_config_path + + def on_train_begin( + self, + args: "AxolotlTrainingArguments", # pylint: disable=unused-argument + state: TrainerState, # pylint: disable=unused-argument + control: TrainerControl, + **kwargs, # pylint: disable=unused-argument + ): + if is_main_process(): + try: + comet_experiment = comet_ml.start(source="axolotl") + comet_experiment.log_other("Created from", "axolotl") + comet_experiment.log_asset( + self.axolotl_config_path, + file_name="axolotl-config", + ) + LOG.info( + "The Axolotl config has been saved to the Comet Experiment under assets." + ) + except (FileNotFoundError, ConnectionError) as err: + LOG.warning(f"Error while saving Axolotl config to Comet: {err}") + return control diff --git a/src/axolotl/utils/comet_.py b/src/axolotl/utils/comet_.py new file mode 100644 index 000000000..b4ecc80ad --- /dev/null +++ b/src/axolotl/utils/comet_.py @@ -0,0 +1,93 @@ +"""Module for wandb utilities""" + +import logging +import os + +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger("axolotl.utils.comet_") + +COMET_ENV_MAPPING_OVERRIDE = { + "comet_mode": "COMET_START_MODE", + "comet_online": "COMET_START_ONLINE", +} +COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = { + "auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS", + "auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE", + "auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS", + "auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD", + "auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS", + "auto_log_co2": "COMET_AUTO_LOG_CO2", + "auto_metric_logging": "COMET_AUTO_LOG_METRICS", + "auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE", + "auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER", + "auto_param_logging": "COMET_AUTO_LOG_PARAMETERS", + "comet_disabled": "COMET_AUTO_LOG_DISABLE", + "display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL", + "distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER", + "log_code": "COMET_AUTO_LOG_CODE", + "log_env_cpu": "COMET_AUTO_LOG_ENV_CPU", + "log_env_details": "COMET_AUTO_LOG_ENV_DETAILS", + "log_env_disk": "COMET_AUTO_LOG_ENV_DISK", + "log_env_gpu": "COMET_AUTO_LOG_ENV_GPU", + "log_env_host": "COMET_AUTO_LOG_ENV_HOST", + "log_env_network": "COMET_AUTO_LOG_ENV_NETWORK", + "log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA", + "log_git_patch": "COMET_AUTO_LOG_GIT_PATCH", + "log_graph": "COMET_AUTO_LOG_GRAPH", + "name": "COMET_START_EXPERIMENT_NAME", + "offline_directory": "COMET_OFFLINE_DIRECTORY", + "parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS", + "tags": "COMET_START_EXPERIMENT_TAGS", +} + + +def python_value_to_environ_value(python_value): + if isinstance(python_value, bool): + if python_value is True: + return "true" + + return "false" + + if isinstance(python_value, int): + return str(python_value) + + if isinstance(python_value, list): # Comet only have one list of string parameter + return ",".join(map(str, python_value)) + + return python_value + + +def setup_comet_env_vars(cfg: DictDefault): + # TODO, we need to convert Axolotl configuration to environment variables + # as Transformers integration are call first and would create an + # Experiment first + + for key in cfg.keys(): + if key.startswith("comet_") and key != "comet_experiment_config": + value = cfg.get(key, "") + + if value is not None and value != "": + env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper()) + final_value = python_value_to_environ_value(value) + os.environ[env_variable_name] = final_value + + if cfg.comet_experiment_config: + for key, value in cfg.comet_experiment_config.items(): + if value is not None and value != "": + config_env_variable_name = ( + COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key) + ) + + if config_env_variable_name is None: + LOG.warning( + f"Unknown Comet Experiment Config name {key}, ignoring it" + ) + continue + + final_value = python_value_to_environ_value(value) + os.environ[config_env_variable_name] = final_value + + # Enable comet if project name is present + if cfg.comet_project_name and len(cfg.comet_project_name) > 0: + cfg.use_comet = True diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 4e07c9260..e8c6ce762 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -484,6 +484,19 @@ def check_wandb_run(cls, data): return data +class CometConfig(BaseModel): + """Comet configuration subset""" + + use_comet: Optional[bool] = None + comet_api_key: Optional[str] = None + comet_workspace: Optional[str] = None + comet_project_name: Optional[str] = None + comet_experiment_key: Optional[str] = None + comet_mode: Optional[str] = None + comet_online: Optional[bool] = None + comet_experiment_config: Optional[Dict[str, Any]] = None + + class GradioConfig(BaseModel): """Gradio configuration subset""" @@ -504,6 +517,7 @@ class AxolotlInputConfig( HyperparametersConfig, WandbConfig, MLFlowConfig, + CometConfig, LISAConfig, GradioConfig, RemappedParameters, diff --git a/tests/test_validation.py b/tests/test_validation.py index 35d0e265e..6e0d0ad2a 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -9,6 +9,7 @@ import pytest from pydantic import ValidationError +from axolotl.utils import is_comet_available from axolotl.utils.config import validate_config from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities from axolotl.utils.dict import DictDefault @@ -1329,3 +1330,105 @@ def test_wandb_set_disabled(self, minimal_cfg): os.environ.pop("WANDB_PROJECT", None) os.environ.pop("WANDB_DISABLED", None) + + +@pytest.mark.skipif(is_comet_available() is False, reason="comet_ml is not installed") +class TestValidationComet(BaseValidation): + """ + Validation test for comet + """ + + def test_comet_sets_env(self, minimal_cfg): + from axolotl.utils.comet_ import setup_comet_env_vars + + comet_config = { + "comet_api_key": "foo", + "comet_workspace": "some_workspace", + "comet_project_name": "some_project", + "comet_experiment_key": "some_experiment_key", + "comet_mode": "get_or_create", + "comet_online": False, + "comet_experiment_config": { + "auto_histogram_activation_logging": False, + "auto_histogram_epoch_rate": 2, + "auto_histogram_gradient_logging": True, + "auto_histogram_tensorboard_logging": False, + "auto_histogram_weight_logging": True, + "auto_log_co2": False, + "auto_metric_logging": True, + "auto_metric_step_rate": 15, + "auto_output_logging": False, + "auto_param_logging": True, + "comet_disabled": False, + "display_summary_level": 2, + "distributed_node_identifier": "some_distributed_node_identifier", + "log_code": True, + "log_env_cpu": False, + "log_env_details": True, + "log_env_disk": False, + "log_env_gpu": True, + "log_env_host": False, + "log_env_network": True, + "log_git_metadata": False, + "log_git_patch": True, + "log_graph": False, + "name": "some_name", + "offline_directory": "some_offline_directory", + "parse_args": True, + "tags": ["tag1", "tag2"], + }, + } + + cfg = DictDefault(comet_config) | minimal_cfg + + new_cfg = validate_config(cfg) + + setup_comet_env_vars(new_cfg) + + comet_env = { + key: value for key, value in os.environ.items() if key.startswith("COMET_") + } + + assert ( + len(comet_env) + == len(comet_config) + len(comet_config["comet_experiment_config"]) - 1 + ) + + assert comet_env == { + "COMET_API_KEY": "foo", + "COMET_AUTO_LOG_CLI_ARGUMENTS": "true", + "COMET_AUTO_LOG_CO2": "false", + "COMET_AUTO_LOG_CODE": "true", + "COMET_AUTO_LOG_DISABLE": "false", + "COMET_AUTO_LOG_ENV_CPU": "false", + "COMET_AUTO_LOG_ENV_DETAILS": "true", + "COMET_AUTO_LOG_ENV_DISK": "false", + "COMET_AUTO_LOG_ENV_GPU": "true", + "COMET_AUTO_LOG_ENV_HOST": "false", + "COMET_AUTO_LOG_ENV_NETWORK": "true", + "COMET_AUTO_LOG_GIT_METADATA": "false", + "COMET_AUTO_LOG_GIT_PATCH": "true", + "COMET_AUTO_LOG_GRAPH": "false", + "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS": "false", + "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE": "2", + "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS": "true", + "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD": "false", + "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS": "true", + "COMET_AUTO_LOG_METRIC_STEP_RATE": "15", + "COMET_AUTO_LOG_METRICS": "true", + "COMET_AUTO_LOG_OUTPUT_LOGGER": "false", + "COMET_AUTO_LOG_PARAMETERS": "true", + "COMET_DISPLAY_SUMMARY_LEVEL": "2", + "COMET_DISTRIBUTED_NODE_IDENTIFIER": "some_distributed_node_identifier", + "COMET_EXPERIMENT_KEY": "some_experiment_key", + "COMET_OFFLINE_DIRECTORY": "some_offline_directory", + "COMET_PROJECT_NAME": "some_project", + "COMET_START_EXPERIMENT_NAME": "some_name", + "COMET_START_EXPERIMENT_TAGS": "tag1,tag2", + "COMET_START_MODE": "get_or_create", + "COMET_START_ONLINE": "false", + "COMET_WORKSPACE": "some_workspace", + } + + for key in comet_env.keys(): + os.environ.pop(key, None)