diff --git a/examples/llama-3/lora-1b.yml b/examples/llama-3/lora-1b.yml index c31a9f39a4..acc17e21f2 100644 --- a/examples/llama-3/lora-1b.yml +++ b/examples/llama-3/lora-1b.yml @@ -5,7 +5,7 @@ base_model: NousResearch/Llama-3.2-1B datasets: - path: teknium/GPT4-LLM-Cleaned type: alpaca -dataset_prepared_path: last_run_prepared + val_set_size: 0.1 output_dir: ./outputs/lora-out @@ -38,6 +38,7 @@ wandb_log_model: gradient_accumulation_steps: 2 micro_batch_size: 2 num_epochs: 1 + optimizer: adamw_8bit lr_scheduler: cosine learning_rate: 0.0002 diff --git a/src/axolotl/cli/checks.py b/src/axolotl/cli/checks.py index 47348240eb..10086c2a4c 100644 --- a/src/axolotl/cli/checks.py +++ b/src/axolotl/cli/checks.py @@ -1,6 +1,5 @@ """Various checks for Axolotl CLI.""" -import logging import os from pathlib import Path @@ -8,7 +7,9 @@ from huggingface_hub import HfApi from huggingface_hub.utils import LocalTokenNotFoundError -LOG = logging.getLogger(__name__) +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) def check_accelerate_default_config() -> None: diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 8f1fe7185d..d55448da4d 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -1,7 +1,6 @@ """Configuration loading and processing.""" import json -import logging import os import tempfile from pathlib import Path @@ -22,11 +21,12 @@ validate_config, ) from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env from axolotl.utils.wandb_ import setup_wandb_env_vars -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__, use_environ=True) def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: @@ -119,12 +119,12 @@ def choose_config(path: Path) -> str: ) if len(yaml_files) == 1: - print(f"Using default YAML file '{yaml_files[0]}'") + LOG.info(f"Using default YAML file '{yaml_files[0]}'") return str(yaml_files[0]) - print("Choose a YAML file:") + LOG.info("Choose a YAML file:") for idx, file in enumerate(yaml_files): - print(f"{idx + 1}. {file}") + LOG.info(f"{idx + 1}. {file}") chosen_file = None while chosen_file is None: @@ -133,9 +133,9 @@ def choose_config(path: Path) -> str: if 1 <= choice <= len(yaml_files): chosen_file = str(yaml_files[choice - 1]) else: - print("Invalid choice. Please choose a number from the list.") + LOG.info("Invalid choice. Please choose a number from the list.") except ValueError: - print("Invalid input. Please enter a number.") + LOG.info("Invalid input. Please enter a number.") return chosen_file diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py index e52da66b72..f131f70830 100644 --- a/src/axolotl/cli/evaluate.py +++ b/src/axolotl/cli/evaluate.py @@ -1,6 +1,5 @@ """CLI to run evaluation on a model.""" -import logging import os from pathlib import Path from typing import Union @@ -17,8 +16,9 @@ from axolotl.evaluate import evaluate from axolotl.utils import patch_optimized_env from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index a4906bbf36..b5bc158fa1 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -1,7 +1,6 @@ """CLI to run inference on a trained model.""" import importlib -import logging import sys from pathlib import Path from threading import Thread @@ -22,8 +21,9 @@ get_chat_template_from_config, ) from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) def get_multi_line_input() -> str: diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index e61dad5d63..3dafa552bd 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -2,7 +2,6 @@ # pylint: disable=redefined-outer-name -import logging import os import subprocess # nosec B404 import tempfile @@ -31,8 +30,11 @@ ) from axolotl.integrations.lm_eval.cli import lm_eval from axolotl.utils import patch_optimized_env +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.config import AxolotlInputConfig +LOG = get_logger(__name__) + @click.group() @click.version_option(version=axolotl.__version__, prog_name="axolotl") @@ -177,7 +179,7 @@ def iter_configs(): do_cli(config=cfg_file, **kwargs) except subprocess.CalledProcessError as exc: - logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}") + LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}") if not sweep: raise exc diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 5c8802dd11..2e59d25374 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -1,6 +1,5 @@ """CLI to merge a trained LoRA into a base model.""" -import logging from pathlib import Path from typing import Union @@ -13,8 +12,9 @@ from axolotl.cli.config import load_cfg from axolotl.cli.utils import load_model_and_tokenizer from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) def do_merge_lora(*, cfg: DictDefault) -> None: diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py index d4b36d92c6..297d7946e4 100644 --- a/src/axolotl/cli/merge_sharded_fsdp_weights.py +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py @@ -1,7 +1,6 @@ """CLI to merge sharded FSDP model checkpoints into a single combined checkpoint.""" import json -import logging import os import shutil from pathlib import Path @@ -27,8 +26,9 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.config import load_cfg +from axolotl.utils.logging import get_logger -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) class BFloat16CastPlanner(_EmptyStateDictLoadPlanner): diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index 2a4dcd2886..9f96f5cc17 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -1,6 +1,5 @@ """CLI to run preprocessing of a dataset.""" -import logging import warnings from pathlib import Path from typing import Union @@ -20,9 +19,10 @@ from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.integrations.base import PluginManager from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from axolotl.utils.trainer import disable_datasets_caching -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: diff --git a/src/axolotl/cli/quantize.py b/src/axolotl/cli/quantize.py index 2036fddeac..63d51fadf0 100644 --- a/src/axolotl/cli/quantize.py +++ b/src/axolotl/cli/quantize.py @@ -2,7 +2,6 @@ CLI to post-training quantize a model using torchao """ -import logging from pathlib import Path from typing import Union @@ -11,9 +10,10 @@ from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.config import load_cfg from axolotl.loaders import load_tokenizer +from axolotl.utils.logging import get_logger from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) def do_quantize( diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 777d848853..fef80fdbaf 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -1,7 +1,6 @@ """CLI to run training on a model.""" import gc -import logging import os from pathlib import Path from typing import Union @@ -22,8 +21,6 @@ from axolotl.utils.config import normalize_config, resolve_dtype from axolotl.utils.dict import DictDefault -LOG = logging.getLogger(__name__) - def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): """ diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py index e681589f3b..d287953613 100644 --- a/src/axolotl/cli/utils.py +++ b/src/axolotl/cli/utils.py @@ -4,7 +4,6 @@ import dataclasses import hashlib import json -import logging from functools import wraps from pathlib import Path from types import NoneType @@ -23,8 +22,9 @@ from axolotl.loaders import load_processor, load_tokenizer from axolotl.loaders.model import ModelLoader from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) def strip_optional_type(field_type: type | str | None): diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index e3ffb7ae9a..d9c3841124 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -1,6 +1,5 @@ """Dataset loading utilities.""" -import logging import math import random from dataclasses import dataclass @@ -14,10 +13,11 @@ from axolotl.utils.data import prepare_dataset from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import RLType from axolotl.utils.tokenization import check_dataset_labels -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) @dataclass diff --git a/src/axolotl/core/chat/messages.py b/src/axolotl/core/chat/messages.py index 88ff2b7ad0..923b177c1f 100644 --- a/src/axolotl/core/chat/messages.py +++ b/src/axolotl/core/chat/messages.py @@ -156,7 +156,6 @@ def tokenized( len(input_ids) : len(input_ids) + len(pending_input_ids) ] if new_pending_inputs != pending_input_ids: - # logging.warning("tokenization mismatch from concatenation.") pending_input_ids = new_pending_inputs input_ids.extend(pending_input_ids) if pending_weight: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 08759d9f9f..46ec12ccb6 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -19,7 +19,6 @@ import importlib import importlib.util import inspect -import logging import math import os import sys @@ -88,6 +87,7 @@ V2BatchSamplerDataCollatorForSeq2Seq, ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType try: @@ -95,7 +95,7 @@ except ImportError: pass -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) class TrainerBuilderBase(abc.ABC): diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index d5cfc23df0..25e9f9f0ae 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -4,7 +4,6 @@ from __future__ import annotations -import logging import os from collections import defaultdict from functools import wraps @@ -34,9 +33,10 @@ sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_tagging, ) +from axolotl.utils.logging import get_logger from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index f4685893b0..196cdb56a5 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -2,7 +2,6 @@ import importlib import inspect -import logging from typing import Any from trl.trainer.grpo_trainer import RewardFunc @@ -13,9 +12,10 @@ AxolotlGRPOTrainer, ) from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.trl import TRLConfig -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) class GRPOStrategy: diff --git a/src/axolotl/core/trainers/mixins/optimizer.py b/src/axolotl/core/trainers/mixins/optimizer.py index bde58aa1d5..abb662706a 100644 --- a/src/axolotl/core/trainers/mixins/optimizer.py +++ b/src/axolotl/core/trainers/mixins/optimizer.py @@ -1,18 +1,17 @@ """Module for Axolotl trainer optimizer mixin""" -import logging - from peft.optimizers import create_loraplus_optimizer from torch import nn from transformers.trainer import Trainer from transformers.utils import is_sagemaker_mp_enabled from axolotl.integrations.base import BaseOptimizerFactory +from axolotl.utils.logging import get_logger if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) class OptimizerMixin(Trainer): diff --git a/src/axolotl/core/trainers/mixins/rng_state_loader.py b/src/axolotl/core/trainers/mixins/rng_state_loader.py index 0e101dabb9..f248394b2e 100644 --- a/src/axolotl/core/trainers/mixins/rng_state_loader.py +++ b/src/axolotl/core/trainers/mixins/rng_state_loader.py @@ -6,7 +6,6 @@ TODO: Remove when upstream added PR to release """ -import logging import os import random @@ -17,7 +16,9 @@ from transformers.trainer_pt_utils import set_rng_state_for_device from transformers.training_args import ParallelMode -LOG = logging.getLogger(__name__) +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) class RngLoaderMixin(Trainer): diff --git a/src/axolotl/core/trainers/mixins/scheduler.py b/src/axolotl/core/trainers/mixins/scheduler.py index 0c36f9f95d..90070ab78a 100644 --- a/src/axolotl/core/trainers/mixins/scheduler.py +++ b/src/axolotl/core/trainers/mixins/scheduler.py @@ -1,12 +1,11 @@ """Module for Axolotl trainer scheduler mixin""" -import logging - import torch from torch.optim.lr_scheduler import LRScheduler, OneCycleLR from transformers.trainer import Trainer from axolotl.integrations.base import PluginManager +from axolotl.utils.logging import get_logger from axolotl.utils.schedulers import ( RexLR, get_cosine_schedule_with_min_lr, @@ -14,7 +13,7 @@ get_cosine_schedule_with_warmup_decay_constant, ) -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) class SchedulerMixin(Trainer): @@ -80,13 +79,15 @@ def create_scheduler( self.lr_scheduler = RexLR( optimizer=optimizer, max_lr=self.args.learning_rate, - min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio), + min_lr=0 if not use_cosine_min_lr else ( + self.args.learning_rate * self.args.cosine_min_lr_ratio), total_steps=num_training_steps, num_warmup_steps=self.args.get_warmup_steps(num_training_steps), ) elif use_cosine_quadratic: if use_cosine_min_lr: - LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.") + LOG.warning( + "Both cosine quadratic warmup and min lr detected. Using quadratic warmup.") self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init optimizer, @@ -115,9 +116,11 @@ def create_scheduler( return super().create_scheduler(num_training_steps, optimizer=optimizer) else: if use_cosine_quadratic: - LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") + LOG.warning( + "axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") if use_cosine_min_lr: - LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") + LOG.warning( + "axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") return self.lr_scheduler # type: ignore diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 143928019b..9f1d9500d6 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -1,12 +1,13 @@ """Module containing Dataset functionality""" -import logging import os from typing import List, Optional, Union import torch from datasets import Dataset, IterableDataset +from axolotl.utils.logging import get_logger + from .prompt_tokenizers import PromptTokenizingStrategy # We want this to be a wrapper for an existing dataset that we have loaded @@ -15,7 +16,7 @@ # let's check to ensure we don't truncate an item in the middle, we'll use # the collators later on to pad the datasets -LOG = logging.getLogger("axolotl") +LOG = get_logger(__name__) class TokenizedPromptDataset(Dataset): diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index eb2b29cbe7..11d85f8f81 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -22,7 +22,6 @@ import collections import importlib -import logging from typing import TYPE_CHECKING, Callable, OrderedDict, Union from peft import PeftModel @@ -31,6 +30,9 @@ from transformers import PreTrainedModel, Trainer from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) if TYPE_CHECKING: from axolotl.common.datasets import TrainDatasetMeta @@ -331,12 +333,12 @@ def register(self, plugin_name: str): ImportError: If the plugin module cannot be imported. """ try: - logging.info(f"Attempting to load plugin: {plugin_name}") + LOG.info(f"Attempting to load plugin: {plugin_name}") plugin = load_plugin(plugin_name) self.plugins[plugin_name] = plugin - logging.info(f"Plugin loaded successfully: {plugin_name}") + LOG.info(f"Plugin loaded successfully: {plugin_name}") except ImportError: - logging.error(f"Failed to load plugin: {plugin_name}") + LOG.error(f"Failed to load plugin: {plugin_name}") def get_input_args(self) -> list[str]: """Returns a list of Pydantic classes for all registered plugins' input arguments.' diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index 7420674fa9..a7e94e3637 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -19,17 +19,16 @@ from Apple's ML team. """ import importlib -import logging import torch from axolotl.integrations.base import BasePlugin from axolotl.utils import get_pytorch_version -from axolotl.utils.distributed import is_main_process +from axolotl.utils.logging import get_logger from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401 -LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy") +LOG = get_logger(__name__, use_environ=True) _CCE_INSTALL_MESSAGE = ( "Please install cut_cross_entropy with transformers support using " @@ -76,10 +75,9 @@ def pre_model_load(self, cfg): cce_patch, ) - if is_main_process(use_environ=True): - LOG.info( - f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}" - ) + LOG.info( + f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}" + ) # The patch checks model_type internally cce_patch(cfg.model_config_type) diff --git a/src/axolotl/integrations/cut_cross_entropy/args.py b/src/axolotl/integrations/cut_cross_entropy/args.py index da1db73976..2729ebe2e3 100644 --- a/src/axolotl/integrations/cut_cross_entropy/args.py +++ b/src/axolotl/integrations/cut_cross_entropy/args.py @@ -15,12 +15,13 @@ """ Module for handling Cut Cross Entropy input arguments. """ -import logging from typing import Optional from pydantic import BaseModel, model_validator -LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy.args") +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) class CutCrossEntropyArgs(BaseModel): diff --git a/src/axolotl/integrations/grokfast/__init__.py b/src/axolotl/integrations/grokfast/__init__.py index c8c352bbe0..234d27226a 100644 --- a/src/axolotl/integrations/grokfast/__init__.py +++ b/src/axolotl/integrations/grokfast/__init__.py @@ -2,15 +2,15 @@ Grokfast plugin for Axolotl """ -import logging - from transformers.trainer_callback import TrainerCallback +from axolotl.utils.logging import get_logger + from ..base import BasePlugin from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401 from .optimizer import gradfilter_ema -LOG = logging.getLogger("axolotl.integrations.grokfast") +LOG = get_logger(__name__) class GrokfastCallbackHandler(TrainerCallback): diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index c7ac423729..1c17ab2b59 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -19,16 +19,15 @@ It is designed to be performant, correct, and light-weight. """ import inspect -import logging import sys from axolotl.integrations.base import BasePlugin -from axolotl.utils.distributed import is_main_process +from axolotl.utils.logging import get_logger from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 from .utils import patch_with_compile_disable -LOG = logging.getLogger("axolotl.integrations.liger") +LOG = get_logger(__name__, use_environ=True) class LigerPlugin(BasePlugin): @@ -85,10 +84,7 @@ def pre_model_load(self, cfg): kwargs["geglu"] = cfg.liger_glu_activation elif "swiglu" in liger_fn_sig.parameters: kwargs["swiglu"] = cfg.liger_glu_activation - if is_main_process(use_environ=True): - LOG.info( - f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}" - ) + LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}") apply_liger_fn(**kwargs) elif cfg.model_config_type == "jamba": from transformers.models.jamba import modeling_jamba @@ -124,9 +120,9 @@ def pre_model_load(self, cfg): if cfg.liger_rope: # The DeepseekV2 version of RoPE is different than upstream LLaMA. # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528 - logging.warning("Fused liger_rope is not supported for DeepseekV2.") + LOG.warning("Fused liger_rope is not supported for DeepseekV2.") if cfg.liger_glu_activation: - logging.warning("liger_glu_activation is not supported for DeepseekV2.") + LOG.warning("liger_glu_activation is not supported for DeepseekV2.") if cfg.liger_rms_norm: modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm if cfg.liger_glu_activation: @@ -186,6 +182,6 @@ def pre_model_load(self, cfg): swiglu=cfg.liger_glu_activation, ) else: - logging.warning( + LOG.warning( f"Unsupported model config type: {cfg.model_config_type}. Liger not applied." ) diff --git a/src/axolotl/integrations/liger/args.py b/src/axolotl/integrations/liger/args.py index 02ece31432..7c9eb23d56 100644 --- a/src/axolotl/integrations/liger/args.py +++ b/src/axolotl/integrations/liger/args.py @@ -15,12 +15,13 @@ """ Module for handling LIGER input arguments. """ -import logging from typing import Optional from pydantic import BaseModel, model_validator -LOG = logging.getLogger("axolotl.integrations.liger.args") +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) class LigerArgs(BaseModel): diff --git a/src/axolotl/integrations/llm_compressor/plugin.py b/src/axolotl/integrations/llm_compressor/plugin.py index d986d51f47..57d506a573 100644 --- a/src/axolotl/integrations/llm_compressor/plugin.py +++ b/src/axolotl/integrations/llm_compressor/plugin.py @@ -3,7 +3,6 @@ by maintaining masks for zero weights during training. """ -import logging from functools import wraps from typing import Any, Callable, Concatenate, ParamSpec, TypeVar @@ -16,11 +15,12 @@ from transformers.training_args import TrainingArguments from axolotl.integrations.base import BasePlugin +from axolotl.utils.logging import get_logger P = ParamSpec("P") # Params for generic function signatures R = TypeVar("R") # Return type for generic function signatures -LOG = logging.getLogger("axolotl.integrations.llm_compressor") +LOG = get_logger(__name__) class LLMCompressorCallbackHandler(TrainerCallback): diff --git a/src/axolotl/integrations/spectrum/__init__.py b/src/axolotl/integrations/spectrum/__init__.py index 6059e7951c..9f66aef97f 100644 --- a/src/axolotl/integrations/spectrum/__init__.py +++ b/src/axolotl/integrations/spectrum/__init__.py @@ -17,14 +17,16 @@ """ import json -import logging import requests from axolotl.integrations.base import BasePlugin +from axolotl.utils.logging import get_logger from .args import SpectrumArgs # pylint: disable=unused-import. # noqa: F401 +LOG = get_logger(__name__) + def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5): unfrozen_parameters = {} @@ -83,17 +85,17 @@ def pre_model_load(self, cfg): except FileNotFoundError: pass except Exception as exc: # pylint: disable=broad-exception-caught - logging.warning(f"Failed to read SNR data from {snr_path}: {exc}") + LOG.warning(f"Failed to read SNR data from {snr_path}: {exc}") if not snr_data: try: snr_data = requests.get(snr_url, timeout=60).json() except requests.exceptions.RequestException as exc: - logging.warning(f"Failed to fetch SNR data from {snr_url}: {exc}") + LOG.warning(f"Failed to fetch SNR data from {snr_url}: {exc}") return # also catch json parsing errors except json.JSONDecodeError as exc: - logging.warning(f"Failed to parse SNR data from {snr_url}: {exc}") + LOG.warning(f"Failed to parse SNR data from {snr_url}: {exc}") return unfrozen_parameters = _generate_unfrozen_params_yaml( diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py index f7a484e9bc..16d8daac8e 100644 --- a/src/axolotl/loaders/adapter.py +++ b/src/axolotl/loaders/adapter.py @@ -1,6 +1,5 @@ """Adapter loading functionality, including LoRA / QLoRA and associated utils""" -import logging import os import types from typing import Any @@ -21,8 +20,9 @@ from axolotl.loaders.utils import get_linear_embedding_layers from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) def setup_quantized_meta_for_peft(model: torch.nn.Module): diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 8d8f927a78..681e5d3352 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -3,7 +3,6 @@ """ import gc -import logging import math import os from functools import cached_property @@ -47,10 +46,11 @@ get_device_count, get_device_type, ) +from axolotl.utils.logging import get_logger from axolotl.utils.model_shard_quant import load_sharded_model_quant from axolotl.utils.schemas.enums import RLType -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) PLUGIN_MANAGER = PluginManager.get_instance() diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 36813bafd8..ce1f5cf709 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -4,7 +4,6 @@ """ import importlib.util -import logging from functools import cached_property import addict @@ -17,8 +16,9 @@ patch_for_multipack, ) from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) PLUGIN_MANAGER = PluginManager.get_instance() diff --git a/src/axolotl/loaders/processor.py b/src/axolotl/loaders/processor.py index 57394bc670..2e3ec8d7fe 100644 --- a/src/axolotl/loaders/processor.py +++ b/src/axolotl/loaders/processor.py @@ -1,6 +1,5 @@ """Processor loading functionality for multi-modal models""" -import logging from typing import Any import transformers @@ -10,8 +9,9 @@ ) from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py index ec9d69e8a1..c311d52472 100644 --- a/src/axolotl/loaders/tokenizer.py +++ b/src/axolotl/loaders/tokenizer.py @@ -1,7 +1,6 @@ """Tokenizer loading functionality and associated utils""" import json -import logging import os import transformers @@ -19,8 +18,9 @@ is_local_main_process, is_main_process, ) +from axolotl.utils.logging import get_logger -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) PLUGIN_MANAGER = PluginManager.get_instance() diff --git a/src/axolotl/loaders/utils.py b/src/axolotl/loaders/utils.py index 1aae4834d1..28c9350853 100644 --- a/src/axolotl/loaders/utils.py +++ b/src/axolotl/loaders/utils.py @@ -1,7 +1,6 @@ """Utilities for axolotl.loaders module""" import contextlib -import logging from typing import Type import addict @@ -9,8 +8,9 @@ from transformers import AutoConfig, PretrainedConfig, PreTrainedModel from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) def get_module_class_from_name( diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py index ffde17aebd..6a7d48236c 100644 --- a/src/axolotl/monkeypatch/accelerate/fsdp2.py +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -2,12 +2,13 @@ monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation, and saving full state dicts """ -import logging import sys import torch -LOG = logging.getLogger(__name__) +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict): diff --git a/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py b/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py index 1275906804..589980c8b9 100644 --- a/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py @@ -3,7 +3,6 @@ """ import importlib -import logging from typing import Optional, Tuple import torch @@ -11,7 +10,9 @@ from flash_attn.flash_attn_interface import flash_attn_func from transformers import AutoConfig, AutoModelForCausalLM -LOG = logging.getLogger("axolotl") +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"): diff --git a/src/axolotl/monkeypatch/gradient_checkpointing/offload_disk.py b/src/axolotl/monkeypatch/gradient_checkpointing/offload_disk.py index 90e70f504a..792d3c6efc 100644 --- a/src/axolotl/monkeypatch/gradient_checkpointing/offload_disk.py +++ b/src/axolotl/monkeypatch/gradient_checkpointing/offload_disk.py @@ -18,7 +18,6 @@ import atexit import concurrent.futures -import logging import os import queue import shutil @@ -32,11 +31,13 @@ import torch +from axolotl.utils.logging import get_logger + torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda") torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda") # Setup logger -logger = logging.getLogger(__name__) +logger = get_logger(__name__) class DiskOffloadManager: diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 998a810279..70e36714c8 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -2,7 +2,6 @@ # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py -import logging import warnings from typing import List, Optional, Tuple, Union @@ -25,6 +24,7 @@ ) from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name +from axolotl.utils.logging import get_logger try: from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports @@ -41,7 +41,7 @@ ) -LOG = logging.getLogger("axolotl") +LOG = get_logger(__name__) def is_xformers_available() -> bool: @@ -612,9 +612,10 @@ def generate_qkv( q, query_padding_mask ) - output_pad_fn = lambda output_unpad: pad_input( # noqa: E731 - output_unpad, indices_q, batch_size, seqlen_q - ) + def output_pad_fn(output_unpad): + return pad_input( # noqa: E731 + output_unpad, indices_q, batch_size, seqlen_q + ) else: q_unpad = rearrange(q, "b s h d -> (b s) h d") @@ -627,9 +628,10 @@ def generate_qkv( ) max_seqlen_q = seqlen_q - output_pad_fn = lambda output_unpad: rearrange( # noqa: E731 - output_unpad, "(b s) h d -> b s h d", b=batch_size - ) + def output_pad_fn(output_unpad): + return rearrange( # noqa: E731 + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) if key_padding_mask is not None: k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index 0c1a4e8224..28223eee36 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -2,7 +2,6 @@ Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments """ -import logging import warnings from typing import Optional, Tuple @@ -11,10 +10,14 @@ import transformers.models.llama.modeling_llama from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + try: import xformers.ops except ImportError: - logging.error("xformers not found! Please install it before trying to use it.") + LOG.error("xformers not found! Please install it before trying to use it.") def hijack_llama_attention(): diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index 6c920dcc86..11e0989cf5 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -7,7 +7,6 @@ from typing import Generator, Tuple, Type import torch -from accelerate.logging import get_logger from peft import PeftModelForCausalLM from torch import nn from transformers import AutoConfig @@ -20,6 +19,7 @@ ) from axolotl.monkeypatch.utils import detab_code from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py index ac9815fce2..3fc22917fb 100644 --- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py @@ -2,7 +2,6 @@ # pylint: disable=duplicate-code -import logging from functools import partial from typing import List, Optional, Tuple, Union @@ -28,8 +27,9 @@ ) from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl.monkeypatch.mistral") +LOG = get_logger(__name__) def replace_mistral_attn_with_flash_attn( @@ -359,9 +359,10 @@ def generate_qkv( q, query_padding_mask ) - output_pad_fn = lambda output_unpad: pad_input( # noqa: E731 - output_unpad, indices_q, batch_size, seqlen_q - ) + def output_pad_fn(output_unpad): + return pad_input( # noqa: E731 + output_unpad, indices_q, batch_size, seqlen_q + ) else: q_unpad = rearrange(q, "b s h d -> (b s) h d") @@ -374,9 +375,10 @@ def generate_qkv( ) max_seqlen_q = seqlen_q - output_pad_fn = lambda output_unpad: rearrange( # noqa: E731 - output_unpad, "(b s) h d -> b s h d", b=batch_size - ) + def output_pad_fn(output_unpad): + return rearrange( # noqa: E731 + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) if key_padding_mask is not None: k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) diff --git a/src/axolotl/monkeypatch/peft/utils.py b/src/axolotl/monkeypatch/peft/utils.py index fdc49c5f6a..0c571fbd2f 100644 --- a/src/axolotl/monkeypatch/peft/utils.py +++ b/src/axolotl/monkeypatch/peft/utils.py @@ -3,14 +3,14 @@ """ import inspect -import logging import peft import axolotl from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) ORIGINAL_PREPARE_CODE = """ for param in model.parameters(): diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py index 4a27dde81f..5b7418e39d 100644 --- a/src/axolotl/monkeypatch/relora.py +++ b/src/axolotl/monkeypatch/relora.py @@ -2,7 +2,6 @@ import glob import json -import logging import os.path import shutil from functools import partial @@ -27,8 +26,9 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import barrier, is_main_process +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl.relora") +LOG = get_logger(__name__) @torch.no_grad() diff --git a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py index c603021116..85454fe2e3 100644 --- a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py @@ -32,11 +32,11 @@ from torch import nn from transformers import AutoConfig, AutoModelForCausalLM from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.utils import logging from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids +from axolotl.utils.logging import get_logger -logger = logging.get_logger(__name__) +logger = get_logger(__name__) def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"): diff --git a/src/axolotl/monkeypatch/trainer/lr.py b/src/axolotl/monkeypatch/trainer/lr.py index 0176093d6a..9afc23c466 100644 --- a/src/axolotl/monkeypatch/trainer/lr.py +++ b/src/axolotl/monkeypatch/trainer/lr.py @@ -2,11 +2,11 @@ monkeypatch for Trainer _get_learning_rate method """ -import logging - import torch -LOG = logging.getLogger(__name__) +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) # TODO remove this patch once https://github.com/huggingface/transformers/pull/37881 is included in a release diff --git a/src/axolotl/monkeypatch/trainer_accelerator_args.py b/src/axolotl/monkeypatch/trainer_accelerator_args.py index d87812c9f8..0a5b27c13e 100644 --- a/src/axolotl/monkeypatch/trainer_accelerator_args.py +++ b/src/axolotl/monkeypatch/trainer_accelerator_args.py @@ -3,13 +3,13 @@ """ import inspect -import logging from transformers import Trainer from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) ORIGINAL_TRAINER_CODE = """ # create accelerator object diff --git a/src/axolotl/monkeypatch/trainer_eval_guard.py b/src/axolotl/monkeypatch/trainer_eval_guard.py index e929ac766a..8488a16df9 100644 --- a/src/axolotl/monkeypatch/trainer_eval_guard.py +++ b/src/axolotl/monkeypatch/trainer_eval_guard.py @@ -3,13 +3,13 @@ """ import inspect -import logging from transformers import Trainer from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) ORIGINAL_TRAINER_CODE = """ model.eval() diff --git a/src/axolotl/monkeypatch/trainer_fsdp_optim.py b/src/axolotl/monkeypatch/trainer_fsdp_optim.py index 1cbfefa5b1..4ce5b8ecd3 100644 --- a/src/axolotl/monkeypatch/trainer_fsdp_optim.py +++ b/src/axolotl/monkeypatch/trainer_fsdp_optim.py @@ -3,13 +3,13 @@ """ import inspect -import logging from transformers import Trainer from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save") +LOG = get_logger(__name__) ORIGINAL_TRAINER_CODE = """ diff --git a/src/axolotl/monkeypatch/transformers_fa_utils.py b/src/axolotl/monkeypatch/transformers_fa_utils.py index f34ecb8c07..e372dc3f85 100644 --- a/src/axolotl/monkeypatch/transformers_fa_utils.py +++ b/src/axolotl/monkeypatch/transformers_fa_utils.py @@ -2,13 +2,14 @@ see https://github.com/huggingface/transformers/pull/35834 """ -import logging from functools import partial from typing import Optional import torch -logger = logging.getLogger(__name__) +from axolotl.utils.logging import get_logger + +logger = get_logger(__name__) def fixed_fa_peft_integration_check( diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py index c81bacbfcb..61f4eeea03 100644 --- a/src/axolotl/monkeypatch/unsloth_.py +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -11,7 +11,7 @@ from axolotl.monkeypatch.utils import detab_code -LOG = get_logger("axolotl.monkeypatch.unsloth") +LOG = get_logger(__name__) ORIGINAL_QKV_CODE = """ query_states = self.q_proj(hidden_states) @@ -133,7 +133,7 @@ def patch_self_attn_lora(): ) exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 self_attn_lora_patched = True - LOG.info("patching unsloth attn lora", main_process_only=True) + LOG.info("patching unsloth attn lora") LlamaFlashAttention2.forward = ( unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821 ) @@ -153,7 +153,7 @@ def apply_rotary_pos_emb( # pylint: disable=unused-argument ): return fast_rope_embedding(q, k, cos, sin) - LOG.info("patching unsloth RoPE embeddings", main_process_only=True) + LOG.info("patching unsloth RoPE embeddings") transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb @@ -189,7 +189,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM): if is_mlp_lora and mlp_no_bias and mlp_not_dora: layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp) else: - LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx) + LOG.warning(f"unable to apply unsloth lora mlp patch to layer {idx}") def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): @@ -215,7 +215,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): layer.self_attn.apply_qkv = apply_lora_qkv else: layer.self_attn.apply_qkv = original_apply_qkv - LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx) + LOG.warning(f"unable to apply unsloth lora qkv patch to layer {idx}") if cfg.unsloth_lora_o: layer_modules = [ getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"] @@ -234,9 +234,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): layer.self_attn.apply_o = apply_lora_o else: layer.self_attn.apply_o = original_apply_o - LOG.warning( - "unable to apply unsloth lora o_proj patch to layer %d", idx - ) + LOG.warning(f"unable to apply unsloth lora o_proj patch to layer {idx}") def patch_unsloth_layernorm(): diff --git a/src/axolotl/processing_strategies.py b/src/axolotl/processing_strategies.py index 1cb6ed0648..ce9b6a838d 100644 --- a/src/axolotl/processing_strategies.py +++ b/src/axolotl/processing_strategies.py @@ -1,6 +1,5 @@ """Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types""" -import logging from copy import deepcopy from typing import Optional @@ -10,7 +9,9 @@ from transformers import ProcessorMixin from transformers.image_utils import load_image -LOG = logging.getLogger(__name__) +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) class ProcessingStrategy: diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py index ba0dad0533..3cdbbb6f33 100644 --- a/src/axolotl/prompt_strategies/__init__.py +++ b/src/axolotl/prompt_strategies/__init__.py @@ -2,11 +2,11 @@ import importlib import inspect -import logging from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl.prompt_strategies") +LOG = get_logger(__name__) def load(strategy, tokenizer, cfg, ds_cfg, processor=None): diff --git a/src/axolotl/prompt_strategies/base.py b/src/axolotl/prompt_strategies/base.py index c146133fbd..370a51a95a 100644 --- a/src/axolotl/prompt_strategies/base.py +++ b/src/axolotl/prompt_strategies/base.py @@ -3,9 +3,10 @@ """ import importlib -import logging -LOG = logging.getLogger("axolotl") +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) def load(strategy, cfg, module_base=None, **kwargs): diff --git a/src/axolotl/prompt_strategies/bradley_terry/__init__.py b/src/axolotl/prompt_strategies/bradley_terry/__init__.py index 4457c50be5..7530aee192 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/__init__.py +++ b/src/axolotl/prompt_strategies/bradley_terry/__init__.py @@ -2,11 +2,11 @@ import importlib import inspect -import logging from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry") +LOG = get_logger(__name__) def load(strategy, tokenizer, cfg, ds_cfg): diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py index 67319f5b41..e655f85a1f 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py +++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py @@ -2,7 +2,6 @@ Bradley-Terry model with chat template prompt strategy. """ -import logging from typing import Any, Dict, Optional from axolotl.prompt_strategies.chat_template import ( @@ -10,10 +9,11 @@ ChatTemplateStrategy, ) from axolotl.utils.chat_templates import get_chat_template_from_config +from axolotl.utils.logging import get_logger # Configure the logger -LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template") -LOG.setLevel(logging.INFO) +LOG = get_logger(__name__) +LOG.setLevel("INFO") class BTChatTemplateStrategy(ChatTemplateStrategy): @@ -44,7 +44,7 @@ def _tokenize_single_prompt(self, prompt): if len(chosen_tokenized["input_ids"]) > max_length: LOG.warning( - f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}", + f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}" ) chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length] @@ -62,7 +62,7 @@ def _tokenize_single_prompt(self, prompt): if len(rejected_tokenized["input_ids"]) > max_length: LOG.warning( - f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}", + f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}" ) rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][ diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 047a66e947..ebb151876c 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -2,7 +2,6 @@ HF Chat Templates prompt strategy """ -import logging from collections import defaultdict from typing import Any, Dict, List, Set, Union @@ -13,11 +12,12 @@ from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompters import IGNORE_TOKEN_ID, Prompter from axolotl.utils.chat_templates import get_chat_template_from_config +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.datasets import DatasetConfig # Configure the logger -LOG = logging.getLogger("axolotl") -LOG.setLevel(logging.INFO) +LOG = get_logger(__name__) +LOG.setLevel("INFO") class ChatTemplatePrompter(Prompter): @@ -378,7 +378,9 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]: add_generation_prompt=True, images=images, ) - tokenized_res = self.prompter.build_prompt(turns, images=images) # type: ignore + tokenized_res = self.prompter.build_prompt( + turns, images=images + ) # type: ignore tokenized_prompt = {} if isinstance(tokenized_res, list): input_ids = prompt_ids + tokenized_res[len(prompt_ids) :] @@ -555,8 +557,8 @@ def find_turn(self, turns: list[dict], turn_idx: int): and turns[0].get("role") == "system" and ( "mistral" in self.tokenizer.name_or_path.lower() - # gemma3 uses gemma tokenizer - or "gemma" in self.tokenizer.name_or_path.lower() + or "gemma" + in self.tokenizer.name_or_path.lower() # gemma3 uses gemma tokenizer ) ): return -1, -1 diff --git a/src/axolotl/prompt_strategies/llama2_chat.py b/src/axolotl/prompt_strategies/llama2_chat.py index 29e091bfd0..eef2e1d4d3 100644 --- a/src/axolotl/prompt_strategies/llama2_chat.py +++ b/src/axolotl/prompt_strategies/llama2_chat.py @@ -24,12 +24,14 @@ Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing! """ -import logging from dataclasses import dataclass, field from typing import Generator, List, Sequence from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) @dataclass @@ -129,7 +131,7 @@ def tokenize_prompt(self, prompt): if cur_len < self.sequence_len: if cur_len != total_len: target[:] = IGNORE_TOKEN_ID - logging.warning( + LOG.warning( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)" ) diff --git a/src/axolotl/prompt_strategies/messages/__init__.py b/src/axolotl/prompt_strategies/messages/__init__.py index d014d93a6b..cc7b84da18 100644 --- a/src/axolotl/prompt_strategies/messages/__init__.py +++ b/src/axolotl/prompt_strategies/messages/__init__.py @@ -2,9 +2,10 @@ import importlib import inspect -import logging -LOG = logging.getLogger("axolotl.prompt_strategies.messages") +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) def load(tokenizer, cfg, ds_cfg, processor=None): diff --git a/src/axolotl/prompt_strategies/metharme.py b/src/axolotl/prompt_strategies/metharme.py index 52d77c00cf..66da723893 100644 --- a/src/axolotl/prompt_strategies/metharme.py +++ b/src/axolotl/prompt_strategies/metharme.py @@ -1,12 +1,12 @@ """Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class""" -import logging from typing import Tuple from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy from axolotl.prompters import AlpacaPrompter +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl") +LOG = get_logger(__name__) IGNORE_TOKEN_ID = -100 diff --git a/src/axolotl/prompt_strategies/pygmalion.py b/src/axolotl/prompt_strategies/pygmalion.py index 88208f6ec4..51f92f3970 100644 --- a/src/axolotl/prompt_strategies/pygmalion.py +++ b/src/axolotl/prompt_strategies/pygmalion.py @@ -1,7 +1,6 @@ """Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class""" import copy -import logging from collections import defaultdict from typing import Generator, List, Tuple @@ -10,8 +9,9 @@ parse_tokenized_to_result, tokenize_prompt_default, ) +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl") +LOG = get_logger(__name__) IGNORE_TOKEN_ID = -100 diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index c29fd05a4c..cb1a1ba4ed 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -1,14 +1,14 @@ """Module containing PromptTokenizingStrategy and Prompter classes""" import abc -import logging from typing import Callable, Dict, List, Optional, Tuple, Union from transformers import BatchEncoding, PreTrainedTokenizer from axolotl.prompters import Prompter +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl") +LOG = get_logger(__name__) IGNORE_INDEX = -100 LLAMA_DEFAULT_PAD_TOKEN = "" # nosec diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index ec680702dc..d29da075e0 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -1,12 +1,13 @@ """Module containing prompters""" -import logging from enum import Enum from typing import Generator, Optional, Union from colorama import Fore -LOG = logging.getLogger("axolotl") +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) IGNORE_TOKEN_ID = -100 REPR_TEMPLATE = "\n\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n\n" diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 8a4c0040d9..68ba3a124e 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -2,7 +2,6 @@ import importlib import inspect -import logging import os import signal import sys @@ -37,6 +36,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.freeze import freeze_layers_except +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import RLType from axolotl.utils.trainer import setup_trainer @@ -45,7 +45,7 @@ except ImportError: BetterTransformer = None -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) def setup_model_and_tokenizer( @@ -64,9 +64,7 @@ def setup_model_and_tokenizer( `None`), and processor (if multimodal, else `None`). """ # Load tokenizer - LOG.debug( - f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", - ) + LOG.debug(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") tokenizer = load_tokenizer(cfg) # Load processor for multimodal models if needed diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 0e7b060939..d94f4be74d 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -4,7 +4,6 @@ import gc import json -import logging import os import traceback from shutil import copyfile @@ -43,6 +42,7 @@ is_main_process, zero_first, ) +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.config import AxolotlInputConfig if TYPE_CHECKING: @@ -50,7 +50,7 @@ IGNORE_INDEX = -100 -LOG = logging.getLogger("axolotl.callbacks") +LOG = get_logger(__name__) class EvalFirstStepCallback( @@ -753,7 +753,14 @@ def log_table_from_dataloader(name: str, table_dataloader): ].append(pred_step_text) row_index += 1 if logger == "wandb": - wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined] + # type: ignore[attr-defined] + wandb.run.log( + { + f"{name} - Predictions vs Ground Truth": pd.DataFrame( + table_data + ) + } + ) elif logger == "mlflow" and is_mlflow_available(): import mlflow diff --git a/src/axolotl/utils/callbacks/comet_.py b/src/axolotl/utils/callbacks/comet_.py index b29f997a86..b7e9034b0e 100644 --- a/src/axolotl/utils/callbacks/comet_.py +++ b/src/axolotl/utils/callbacks/comet_.py @@ -1,17 +1,17 @@ """Comet module for trainer callbacks""" -import logging from typing import TYPE_CHECKING import comet_ml from transformers import TrainerCallback, TrainerControl, TrainerState from axolotl.utils.distributed import is_main_process +from axolotl.utils.logging import get_logger if TYPE_CHECKING: from axolotl.core.trainer_builder import AxolotlTrainingArguments -LOG = logging.getLogger("axolotl.callbacks") +LOG = get_logger(__name__) class SaveAxolotlConfigtoCometCallback(TrainerCallback): diff --git a/src/axolotl/utils/callbacks/lisa.py b/src/axolotl/utils/callbacks/lisa.py index e226471b1a..ad7e23144a 100644 --- a/src/axolotl/utils/callbacks/lisa.py +++ b/src/axolotl/utils/callbacks/lisa.py @@ -6,17 +6,18 @@ License: Apache 2.0 """ -import logging from functools import reduce from typing import TYPE_CHECKING import numpy as np from transformers import TrainerCallback +from axolotl.utils.logging import get_logger + if TYPE_CHECKING: from axolotl.core.trainer_builder import AxolotlTrainer -LOG = logging.getLogger("axolotl.callbacks.lisa") +LOG = get_logger(__name__) def lisa_callback_factory(trainer: "AxolotlTrainer"): diff --git a/src/axolotl/utils/callbacks/mlflow_.py b/src/axolotl/utils/callbacks/mlflow_.py index 15ca1ca475..15f8ef0697 100644 --- a/src/axolotl/utils/callbacks/mlflow_.py +++ b/src/axolotl/utils/callbacks/mlflow_.py @@ -1,6 +1,5 @@ """MLFlow module for trainer callbacks""" -import logging import os from shutil import copyfile from tempfile import NamedTemporaryFile @@ -10,11 +9,12 @@ from transformers import TrainerCallback, TrainerControl, TrainerState from axolotl.utils.distributed import is_main_process +from axolotl.utils.logging import get_logger if TYPE_CHECKING: from axolotl.core.trainer_builder import AxolotlTrainingArguments -LOG = logging.getLogger("axolotl.callbacks") +LOG = get_logger(__name__) def should_log_artifacts() -> bool: diff --git a/src/axolotl/utils/callbacks/qat.py b/src/axolotl/utils/callbacks/qat.py index da4f2612be..cf4d9a9373 100644 --- a/src/axolotl/utils/callbacks/qat.py +++ b/src/axolotl/utils/callbacks/qat.py @@ -1,6 +1,5 @@ """QAT Callback for HF Causal Trainer""" -import logging from functools import partial from torch import nn @@ -8,9 +7,10 @@ from torchao.quantization.qat.linear import FakeQuantizedLinear from transformers import TrainerCallback +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.quantization import QATConfig -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) def toggle_fake_quant(mod: nn.Module, enable: bool): diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 72ebffbcdb..bf496d2c51 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -3,13 +3,14 @@ These templates are used for formatting messages in a conversation. """ -import logging from typing import TYPE_CHECKING, Any, Dict, Optional +from axolotl.utils.logging import get_logger + if TYPE_CHECKING: from transformers import PreTrainedTokenizerBase -LOG = logging.getLogger("axolotl.utils.chat_templates") +LOG = get_logger("axolotl.utils.chat_templates") _JINJA_TEMPALTE_CHOICE = "jinja" _DEFAULT_TEMPLATE_CHOICE = "tokenizer_default" @@ -40,9 +41,9 @@ "metharme": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' %}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>' + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %}", "pixtral": '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n {%- endif %}\n {%- if message["role"] == "user" %}\n {%- if loop.last and system_message is defined %}\n {{- "[INST]" + system_message + "\n\n" }}\n {%- else %}\n {{- "[INST]" }}\n {%- endif %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["text"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message["content"] }}\n {%- endif %}\n {{- "[/INST]" }}\n {%- elif message["role"] == "assistant" %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["text"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n{%- endif %}\n{%- endfor %}\n{{- eos_token }}\n{%- else %}\n{{- message["content"] + eos_token }}\n{%- endif %}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}', "qwen2_vl": "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}", - "command_a": "{{ bos_token }}{% if documents %}\n{% set tools = [] %}\n{%- macro document_turn(documents) -%}\n{# format documents into chat turn #}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[\n {\"tool_call_id\": \"0\", \"tool_name\": \"direct-injected-document\", \"parameters\": {}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n {\n \"tool_call_id\": \"0\",\n \"results\": {\n{% for doc in documents %}\n \"{{ loop.index0 }}\": {{doc|tojson}}{% if not loop.last %},\n {% endif %}\n{% endfor %}\n\n },\n \"is_error\": null\n }\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}\n{%- macro tool_call_id_to_int(messages, tool_call_id) %}\n{%- set counter = namespace(value=0) %}\n{%- set tool_call_id_seen = namespace(value=false) %}\n{%- for msg in messages %}\n {%- if msg.tool_calls %}\n {%- for tool_call in msg.tool_calls %}\n {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}\n {{ counter.value }}\n {%- set tool_call_id_seen.value = true %}\n {%- endif %}\n {%- set counter.value = counter.value + 1 %}\n {%- endfor %}\n {%- endif %}\n{%- endfor %}\n{%- endmacro %}\n{%- macro format_tool_message(messages, tool_msg) -%}\n{# format tool message #}\n {\n \"tool_call_id\": \"{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}\",\n \"results\": {\n \"0\": {{ tool_msg.content|tojson }}\n },\n \"is_error\": null\n }\n{%- endmacro -%}\n{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}\n{%- set tool_idx = namespace(value=0) %}\n{%- set tool_ids_seen = namespace(value=[]) %}\n{%- set sent_documents = namespace(value=false) %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n{% if tools or documents %}\n\nYou have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests.\n\n## Tool Use\nThink about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.\n\n0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.\n NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools.\n\nThen carry out your plan by repeatedly executing the following steps.\n1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing \"tool_name\" and \"parameters\" fields.\n When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.\n2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.\n Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its \"tool_call_id\".\n3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.\n NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.\n\nYou can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user.\n\n4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.\n{% if enable_citations %}\n\n## Grounding\nImportantly, note that \"Reflection\" and \"Response\" above can be grounded.\nGrounding means you associate pieces of texts (called \"spans\") with those specific tool results that support them (called \"sources\"). And you use a pair of tags \"\" and \"\" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as \"{tool_call_id}:[{list of result indices}]\", before they are joined together by \",\". E.g., \"span\" means that \"span\" is supported by result 1 and 2 from \"tool_call_id=0\" as well as result 0 from \"tool_call_id=1\".\n{% endif %}\n\n## Available Tools\nHere is the list of tools that you have available to you.\nYou can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.\nEach tool is represented as a JSON object with fields like \"name\", \"description\", \"parameters\" (per JSON Schema), and optionally, \"responses\" (per JSON Schema).\n\n```json\n[\n{% if documents %}\n {\"name\": \"direct-injected-document\", \"description\": \"This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!\", \"parameters\": {\"type\": \"object\", \"properties\": {}, \"required\": []}, \"responses\": {\"200\": {\"description\": \"Successfully returned a list of chunked text snippets from the directly uploaded documents.\", \"content\": {\"application/json\": {\"schema\": {\"type\": \"array\", \"items\": {\"type\": \"object\", \"required\": [\"url\", \"snippet\"], \"properties\": {\"url\": {\"type\": \"string\", \"description\": \"The url of the uploaded document.\"}, \"snippet\": {\"type\": \"string\", \"description\": \"The text snippet for the returned document chunk.\"}}}}}}}}}{%- if tools %},{% endif %}\n\n{% endif %}\n{% for tool in tools %}\n {\"name\": \"{{ tool['function']['name'] }}\", \"description\": \"{{tool['function']['description']}}\", \"parameters\": {{ tool['function']['parameters']|tojson }}, \"responses\": null}{%- if not loop.last %},{% endif %}\n\n{% endfor %}\n]\n```\n\n{% endif %}\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n {%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == 'user' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}\n {%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[\n {% for tc in message.tool_calls %}\n {\"tool_call_id\": \"{{ tool_idx.value }}\", \"tool_name\": \"{{ tc['function']['name'] }}\", \"parameters\": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %}\n\n {% set tool_idx.value = tool_idx.value + 1 %}\n {% endfor %}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}\n {% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n{{ format_tool_message(messages, message) }}\n {%- set stopped = namespace(value=false) %}\n {%- for msg in messages[loop.index0 + 1:] %}\n {%- if not stopped.value and msg.role|lower == 'tool' %},\n{{ format_tool_message(messages, msg) }}\n {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}\n {%- else %}\n {%- set stopped.value = true %}\n {%- endif %}\n {%- endfor %}\n\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>\n {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>\n{%- else -%}\n{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\n{% if safety_mode|upper == 'STRICT' -%}\nYou are in strict safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will reject requests to generate content related to violence, hate, misinformation or sex to any amount. You will avoid using profanity. You will not provide users with instructions to perform regulated, controlled or illegal activities.\n{%- else -%}\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n{%- endif %}\n\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n {%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == 'user' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>\n {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{%- if add_generation_prompt -%}<|START_RESPONSE|>{%- endif %}\n{% endif %}", - "command_a_tool_use": "{{ bos_token }}{%- macro document_turn(documents) -%}\n{# format documents into chat turn #}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[\n {\"tool_call_id\": \"0\", \"tool_name\": \"direct-injected-document\", \"parameters\": {}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n {\n \"tool_call_id\": \"0\",\n \"results\": {\n{% for doc in documents %}\n \"{{ loop.index0 }}\": {{doc|tojson}}{% if not loop.last %},\n {% endif %}\n{% endfor %}\n\n },\n \"is_error\": null\n }\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}\n{%- macro tool_call_id_to_int(messages, tool_call_id) %}\n{%- set counter = namespace(value=0) %}\n{%- set tool_call_id_seen = namespace(value=false) %}\n{%- for msg in messages %}\n {%- if msg.tool_calls %}\n {%- for tool_call in msg.tool_calls %}\n {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}\n {{ counter.value }}\n {%- set tool_call_id_seen.value = true %}\n {%- endif %}\n {%- set counter.value = counter.value + 1 %}\n {%- endfor %}\n {%- endif %}\n{%- endfor %}\n{%- endmacro %}\n{%- macro format_tool_message(messages, tool_msg) -%}\n{# format tool message #}\n {\n \"tool_call_id\": \"{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}\",\n \"results\": {\n \"0\": {{ tool_msg.content|tojson }}\n },\n \"is_error\": null\n }\n{%- endmacro -%}\n{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}\n{%- set tool_idx = namespace(value=0) %}\n{%- set tool_ids_seen = namespace(value=[]) %}\n{%- set sent_documents = namespace(value=false) %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n{% if tools or documents %}\n\nYou have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests.\n\n## Tool Use\nThink about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.\n\n0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.\n NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools.\n\nThen carry out your plan by repeatedly executing the following steps.\n1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing \"tool_name\" and \"parameters\" fields.\n When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.\n2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.\n Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its \"tool_call_id\".\n3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.\n NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.\n\nYou can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user.\n\n4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.\n{% if enable_citations %}\n\n## Grounding\nImportantly, note that \"Reflection\" and \"Response\" above can be grounded.\nGrounding means you associate pieces of texts (called \"spans\") with those specific tool results that support them (called \"sources\"). And you use a pair of tags \"\" and \"\" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as \"{tool_call_id}:[{list of result indices}]\", before they are joined together by \",\". E.g., \"span\" means that \"span\" is supported by result 1 and 2 from \"tool_call_id=0\" as well as result 0 from \"tool_call_id=1\".\n{% endif %}\n\n## Available Tools\nHere is the list of tools that you have available to you.\nYou can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.\nEach tool is represented as a JSON object with fields like \"name\", \"description\", \"parameters\" (per JSON Schema), and optionally, \"responses\" (per JSON Schema).\n\n```json\n[\n{% if documents %}\n {\"name\": \"direct-injected-document\", \"description\": \"This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!\", \"parameters\": {\"type\": \"object\", \"properties\": {}, \"required\": []}, \"responses\": {\"200\": {\"description\": \"Successfully returned a list of chunked text snippets from the directly uploaded documents.\", \"content\": {\"application/json\": {\"schema\": {\"type\": \"array\", \"items\": {\"type\": \"object\", \"required\": [\"url\", \"snippet\"], \"properties\": {\"url\": {\"type\": \"string\", \"description\": \"The url of the uploaded document.\"}, \"snippet\": {\"type\": \"string\", \"description\": \"The text snippet for the returned document chunk.\"}}}}}}}}}{%- if tools %},{% endif %}\n\n{% endif %}\n{% for tool in tools %}\n {\"name\": \"{{ tool['function']['name'] }}\", \"description\": \"{{tool['function']['description']}}\", \"parameters\": {{ tool['function']['parameters']|tojson }}, \"responses\": null}{%- if not loop.last %},{% endif %}\n\n{% endfor %}\n]\n```\n\n{% endif %}\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n {%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == 'user' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}\n {%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[\n {% for tc in message.tool_calls %}\n {\"tool_call_id\": \"{{ tool_idx.value }}\", \"tool_name\": \"{{ tc['function']['name'] }}\", \"parameters\": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %}\n\n {% set tool_idx.value = tool_idx.value + 1 %}\n {% endfor %}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}\n {% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n{{ format_tool_message(messages, message) }}\n {%- set stopped = namespace(value=false) %}\n {%- for msg in messages[loop.index0 + 1:] %}\n {%- if not stopped.value and msg.role|lower == 'tool' %},\n{{ format_tool_message(messages, msg) }}\n {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}\n {%- else %}\n {%- set stopped.value = true %}\n {%- endif %}\n {%- endfor %}\n\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>\n {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", - "command_a_rag": "{{ bos_token }}{% set tools = [] %}\n{%- macro document_turn(documents) -%}\n{# format documents into chat turn #}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[\n {\"tool_call_id\": \"0\", \"tool_name\": \"direct-injected-document\", \"parameters\": {}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n {\n \"tool_call_id\": \"0\",\n \"results\": {\n{% for doc in documents %}\n \"{{ loop.index0 }}\": {{doc|tojson}}{% if not loop.last %},\n {% endif %}\n{% endfor %}\n\n },\n \"is_error\": null\n }\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}\n{%- macro tool_call_id_to_int(messages, tool_call_id) %}\n{%- set counter = namespace(value=0) %}\n{%- set tool_call_id_seen = namespace(value=false) %}\n{%- for msg in messages %}\n {%- if msg.tool_calls %}\n {%- for tool_call in msg.tool_calls %}\n {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}\n {{ counter.value }}\n {%- set tool_call_id_seen.value = true %}\n {%- endif %}\n {%- set counter.value = counter.value + 1 %}\n {%- endfor %}\n {%- endif %}\n{%- endfor %}\n{%- endmacro %}\n{%- macro format_tool_message(messages, tool_msg) -%}\n{# format tool message #}\n {\n \"tool_call_id\": \"{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}\",\n \"results\": {\n \"0\": {{ tool_msg.content|tojson }}\n },\n \"is_error\": null\n }\n{%- endmacro -%}\n{%- if messages and messages[0]['role']|lower == 'system' %}{%- set developer_preamble = messages[0]['content'] %}{% endif %}\n{%- set tool_idx = namespace(value=0) %}\n{%- set tool_ids_seen = namespace(value=[]) %}\n{%- set sent_documents = namespace(value=false) %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n{% if tools or documents %}\n\nYou have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user's requests.\n\n## Tool Use\nThink about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.\n\n0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.\n NOTE: You MUST skip this step when you are directly responding to the user's request without using any tools.\n\nThen carry out your plan by repeatedly executing the following steps.\n1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing \"tool_name\" and \"parameters\" fields.\n When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.\n2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.\n Every tool call produces a list of results (when a tool call produces no result or a single result, it'll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its \"tool_call_id\".\n3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you've figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.\n NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.\n\nYou can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it's time to finally respond to the user.\n\n4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user's last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.\n{% if enable_citations %}\n\n## Grounding\nImportantly, note that \"Reflection\" and \"Response\" above can be grounded.\nGrounding means you associate pieces of texts (called \"spans\") with those specific tool results that support them (called \"sources\"). And you use a pair of tags \"\" and \"\" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as \"{tool_call_id}:[{list of result indices}]\", before they are joined together by \",\". E.g., \"span\" means that \"span\" is supported by result 1 and 2 from \"tool_call_id=0\" as well as result 0 from \"tool_call_id=1\".\n{% endif %}\n\n## Available Tools\nHere is the list of tools that you have available to you.\nYou can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.\nEach tool is represented as a JSON object with fields like \"name\", \"description\", \"parameters\" (per JSON Schema), and optionally, \"responses\" (per JSON Schema).\n\n```json\n[\n{% if documents %}\n {\"name\": \"direct-injected-document\", \"description\": \"This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!\", \"parameters\": {\"type\": \"object\", \"properties\": {}, \"required\": []}, \"responses\": {\"200\": {\"description\": \"Successfully returned a list of chunked text snippets from the directly uploaded documents.\", \"content\": {\"application/json\": {\"schema\": {\"type\": \"array\", \"items\": {\"type\": \"object\", \"required\": [\"url\", \"snippet\"], \"properties\": {\"url\": {\"type\": \"string\", \"description\": \"The url of the uploaded document.\"}, \"snippet\": {\"type\": \"string\", \"description\": \"The text snippet for the returned document chunk.\"}}}}}}}}}{%- if tools %},{% endif %}\n\n{% endif %}\n{% for tool in tools %}\n {\"name\": \"{{ tool['function']['name'] }}\", \"description\": \"{{tool['function']['description']}}\", \"parameters\": {{ tool['function']['parameters']|tojson }}, \"responses\": null}{%- if not loop.last %},{% endif %}\n\n{% endfor %}\n]\n```\n\n{% endif %}\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n {%- if message.role|lower == 'system' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == 'user' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}\n {%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[\n {% for tc in message.tool_calls %}\n {\"tool_call_id\": \"{{ tool_idx.value }}\", \"tool_name\": \"{{ tc['function']['name'] }}\", \"parameters\": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %}\n\n {% set tool_idx.value = tool_idx.value + 1 %}\n {% endfor %}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}\n {% elif message.role|lower == 'tool' and message.tool_call_id not in tool_ids_seen.value %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n{{ format_tool_message(messages, message) }}\n {%- set stopped = namespace(value=false) %}\n {%- for msg in messages[loop.index0 + 1:] %}\n {%- if not stopped.value and msg.role|lower == 'tool' %},\n{{ format_tool_message(messages, msg) }}\n {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}\n {%- else %}\n {%- set stopped.value = true %}\n {%- endif %}\n {%- endfor %}\n\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>\n {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + "command_a": '{{ bos_token }}{% if documents %}\n{% set tools = [] %}\n{%- macro document_turn(documents) -%}\n{# format documents into chat turn #}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n {\n "tool_call_id": "0",\n "results": {\n{% for doc in documents %}\n "{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %},\n {% endif %}\n{% endfor %}\n\n },\n "is_error": null\n }\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}\n{%- macro tool_call_id_to_int(messages, tool_call_id) %}\n{%- set counter = namespace(value=0) %}\n{%- set tool_call_id_seen = namespace(value=false) %}\n{%- for msg in messages %}\n {%- if msg.tool_calls %}\n {%- for tool_call in msg.tool_calls %}\n {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}\n {{ counter.value }}\n {%- set tool_call_id_seen.value = true %}\n {%- endif %}\n {%- set counter.value = counter.value + 1 %}\n {%- endfor %}\n {%- endif %}\n{%- endfor %}\n{%- endmacro %}\n{%- macro format_tool_message(messages, tool_msg) -%}\n{# format tool message #}\n {\n "tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}",\n "results": {\n "0": {{ tool_msg.content|tojson }}\n },\n "is_error": null\n }\n{%- endmacro -%}\n{%- if messages and messages[0][\'role\']|lower == \'system\' %}{%- set developer_preamble = messages[0][\'content\'] %}{% endif %}\n{%- set tool_idx = namespace(value=0) %}\n{%- set tool_ids_seen = namespace(value=[]) %}\n{%- set sent_documents = namespace(value=false) %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n{% if tools or documents %}\n\nYou have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user\'s requests.\n\n## Tool Use\nThink about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.\n\n0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.\n NOTE: You MUST skip this step when you are directly responding to the user\'s request without using any tools.\n\nThen carry out your plan by repeatedly executing the following steps.\n1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields.\n When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.\n2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.\n Every tool call produces a list of results (when a tool call produces no result or a single result, it\'ll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id".\n3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you\'ve figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.\n NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.\n\nYou can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it\'s time to finally respond to the user.\n\n4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user\'s last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.\n{% if enable_citations %}\n\n## Grounding\nImportantly, note that "Reflection" and "Response" above can be grounded.\nGrounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "" and "" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "span" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1".\n{% endif %}\n\n## Available Tools\nHere is the list of tools that you have available to you.\nYou can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.\nEach tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema).\n\n```json\n[\n{% if documents %}\n {"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %}\n\n{% endif %}\n{% for tool in tools %}\n {"name": "{{ tool[\'function\'][\'name\'] }}", "description": "{{tool[\'function\'][\'description\']}}", "parameters": {{ tool[\'function\'][\'parameters\']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %}\n\n{% endfor %}\n]\n```\n\n{% endif %}\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n {%- if message.role|lower == \'system\' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == \'user\' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}\n {%- elif message.role|lower == \'assistant\' or message.role|lower == \'chatbot\' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[\n {% for tc in message.tool_calls %}\n {"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc[\'function\'][\'name\'] }}", "parameters": {{ tc[\'function\'][\'arguments\']|tojson }}}{% if not loop.last %},{% endif %}\n\n {% set tool_idx.value = tool_idx.value + 1 %}\n {% endfor %}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}\n {% elif message.role|lower == \'tool\' and message.tool_call_id not in tool_ids_seen.value %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n{{ format_tool_message(messages, message) }}\n {%- set stopped = namespace(value=false) %}\n {%- for msg in messages[loop.index0 + 1:] %}\n {%- if not stopped.value and msg.role|lower == \'tool\' %},\n{{ format_tool_message(messages, msg) }}\n {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}\n {%- else %}\n {%- set stopped.value = true %}\n {%- endif %}\n {%- endfor %}\n\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>\n {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>\n{%- else -%}\n{%- if messages and messages[0][\'role\']|lower == \'system\' %}{%- set developer_preamble = messages[0][\'content\'] %}{% endif %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\n{% if safety_mode|upper == \'STRICT\' -%}\nYou are in strict safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will reject requests to generate content related to violence, hate, misinformation or sex to any amount. You will avoid using profanity. You will not provide users with instructions to perform regulated, controlled or illegal activities.\n{%- else -%}\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n{%- endif %}\n\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n {%- if message.role|lower == \'system\' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == \'user\' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == \'assistant\' or message.role|lower == \'chatbot\' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>\n {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{%- if add_generation_prompt -%}<|START_RESPONSE|>{%- endif %}\n{% endif %}', + "command_a_tool_use": '{{ bos_token }}{%- macro document_turn(documents) -%}\n{# format documents into chat turn #}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n {\n "tool_call_id": "0",\n "results": {\n{% for doc in documents %}\n "{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %},\n {% endif %}\n{% endfor %}\n\n },\n "is_error": null\n }\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}\n{%- macro tool_call_id_to_int(messages, tool_call_id) %}\n{%- set counter = namespace(value=0) %}\n{%- set tool_call_id_seen = namespace(value=false) %}\n{%- for msg in messages %}\n {%- if msg.tool_calls %}\n {%- for tool_call in msg.tool_calls %}\n {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}\n {{ counter.value }}\n {%- set tool_call_id_seen.value = true %}\n {%- endif %}\n {%- set counter.value = counter.value + 1 %}\n {%- endfor %}\n {%- endif %}\n{%- endfor %}\n{%- endmacro %}\n{%- macro format_tool_message(messages, tool_msg) -%}\n{# format tool message #}\n {\n "tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}",\n "results": {\n "0": {{ tool_msg.content|tojson }}\n },\n "is_error": null\n }\n{%- endmacro -%}\n{%- if messages and messages[0][\'role\']|lower == \'system\' %}{%- set developer_preamble = messages[0][\'content\'] %}{% endif %}\n{%- set tool_idx = namespace(value=0) %}\n{%- set tool_ids_seen = namespace(value=[]) %}\n{%- set sent_documents = namespace(value=false) %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n{% if tools or documents %}\n\nYou have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user\'s requests.\n\n## Tool Use\nThink about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.\n\n0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.\n NOTE: You MUST skip this step when you are directly responding to the user\'s request without using any tools.\n\nThen carry out your plan by repeatedly executing the following steps.\n1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields.\n When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.\n2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.\n Every tool call produces a list of results (when a tool call produces no result or a single result, it\'ll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id".\n3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you\'ve figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.\n NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.\n\nYou can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it\'s time to finally respond to the user.\n\n4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user\'s last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.\n{% if enable_citations %}\n\n## Grounding\nImportantly, note that "Reflection" and "Response" above can be grounded.\nGrounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "" and "" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "span" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1".\n{% endif %}\n\n## Available Tools\nHere is the list of tools that you have available to you.\nYou can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.\nEach tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema).\n\n```json\n[\n{% if documents %}\n {"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %}\n\n{% endif %}\n{% for tool in tools %}\n {"name": "{{ tool[\'function\'][\'name\'] }}", "description": "{{tool[\'function\'][\'description\']}}", "parameters": {{ tool[\'function\'][\'parameters\']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %}\n\n{% endfor %}\n]\n```\n\n{% endif %}\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n {%- if message.role|lower == \'system\' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == \'user\' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}\n {%- elif message.role|lower == \'assistant\' or message.role|lower == \'chatbot\' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[\n {% for tc in message.tool_calls %}\n {"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc[\'function\'][\'name\'] }}", "parameters": {{ tc[\'function\'][\'arguments\']|tojson }}}{% if not loop.last %},{% endif %}\n\n {% set tool_idx.value = tool_idx.value + 1 %}\n {% endfor %}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}\n {% elif message.role|lower == \'tool\' and message.tool_call_id not in tool_ids_seen.value %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n{{ format_tool_message(messages, message) }}\n {%- set stopped = namespace(value=false) %}\n {%- for msg in messages[loop.index0 + 1:] %}\n {%- if not stopped.value and msg.role|lower == \'tool\' %},\n{{ format_tool_message(messages, msg) }}\n {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}\n {%- else %}\n {%- set stopped.value = true %}\n {%- endif %}\n {%- endfor %}\n\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>\n {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>', + "command_a_rag": '{{ bos_token }}{% set tools = [] %}\n{%- macro document_turn(documents) -%}\n{# format documents into chat turn #}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|>I will look through the document to address the users needs.<|END_THINKING|><|START_ACTION|>[\n {"tool_call_id": "0", "tool_name": "direct-injected-document", "parameters": {}}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n {\n "tool_call_id": "0",\n "results": {\n{% for doc in documents %}\n "{{ loop.index0 }}": {{doc|tojson}}{% if not loop.last %},\n {% endif %}\n{% endfor %}\n\n },\n "is_error": null\n }\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>{%- endmacro %}\n{%- macro tool_call_id_to_int(messages, tool_call_id) %}\n{%- set counter = namespace(value=0) %}\n{%- set tool_call_id_seen = namespace(value=false) %}\n{%- for msg in messages %}\n {%- if msg.tool_calls %}\n {%- for tool_call in msg.tool_calls %}\n {%- if tool_call.id == tool_call_id and not tool_call_id_seen.value -%}\n {{ counter.value }}\n {%- set tool_call_id_seen.value = true %}\n {%- endif %}\n {%- set counter.value = counter.value + 1 %}\n {%- endfor %}\n {%- endif %}\n{%- endfor %}\n{%- endmacro %}\n{%- macro format_tool_message(messages, tool_msg) -%}\n{# format tool message #}\n {\n "tool_call_id": "{{ tool_call_id_to_int(messages, tool_msg.tool_call_id) }}",\n "results": {\n "0": {{ tool_msg.content|tojson }}\n },\n "is_error": null\n }\n{%- endmacro -%}\n{%- if messages and messages[0][\'role\']|lower == \'system\' %}{%- set developer_preamble = messages[0][\'content\'] %}{% endif %}\n{%- set tool_idx = namespace(value=0) %}\n{%- set tool_ids_seen = namespace(value=[]) %}\n{%- set sent_documents = namespace(value=false) %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\nYou are in contextual safety mode. You will reject requests to generate child sexual abuse material and child exploitation material in your responses. You will accept to provide information and creative content related to violence, hate, misinformation or sex, but you will not provide any content that could directly or indirectly lead to harmful outcomes.\n\nYour information cutoff date is June 2024.\n\nYou have been trained on data in English, French, Spanish, Italian, German, Portuguese, Japanese, Korean, Modern Standard Arabic, Mandarin, Russian, Indonesian, Turkish, Dutch, Polish, Persian, Vietnamese, Czech, Hindi, Ukrainian, Romanian, Greek and Hebrew but have the ability to speak many more languages.\n{% if tools or documents %}\n\nYou have been trained to have advanced reasoning and tool-use capabilities and you should make best use of these skills to serve user\'s requests.\n\n## Tool Use\nThink about how you can make best use of the provided tools to help with the task and come up with a high level plan that you will execute first.\n\n0. Start by writing <|START_THINKING|> followed by a detailed step by step plan of how you will solve the problem. For each step explain your thinking fully and give details of required tool calls (if needed). Unless specified otherwise, you write your plan in natural language. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when the user request is so straightforward to address that only a trivial plan would be needed.\n NOTE: You MUST skip this step when you are directly responding to the user\'s request without using any tools.\n\nThen carry out your plan by repeatedly executing the following steps.\n1. Action: write <|START_ACTION|> followed by a list of JSON-formatted tool calls, with each one containing "tool_name" and "parameters" fields.\n When there are multiple tool calls which are completely independent of each other (i.e. they can be executed in parallel), you should list them out all together in one step. When you finish, close it out with <|END_ACTION|>.\n2. Observation: you will then receive results of those tool calls in JSON format in the very next turn, wrapped around by <|START_TOOL_RESULT|> and <|END_TOOL_RESULT|>. Carefully observe those results and think about what to do next. Note that these results will be provided to you in a separate turn. NEVER hallucinate results.\n Every tool call produces a list of results (when a tool call produces no result or a single result, it\'ll still get wrapped inside a list). Each result is clearly linked to its originating tool call via its "tool_call_id".\n3. Reflection: start the next turn by writing <|START_THINKING|> followed by what you\'ve figured out so far, any changes you need to make to your plan, and what you will do next. When you finish, close it out with <|END_THINKING|>.\n You can optionally choose to skip this step when everything is going according to plan and no special pieces of information or reasoning chains need to be recorded.\n NOTE: You MUST skip this step when you are done with tool-use actions and are ready to respond to the user.\n\nYou can repeat the above 3 steps multiple times (could be 0 times too if no suitable tool calls are available or needed), until you decide it\'s time to finally respond to the user.\n\n4. Response: then break out of the loop and write <|START_RESPONSE|> followed by a piece of text which serves as a response to the user\'s last request. Use all previous tool calls and results to help you when formulating your response. When you finish, close it out with <|END_RESPONSE|>.\n{% if enable_citations %}\n\n## Grounding\nImportantly, note that "Reflection" and "Response" above can be grounded.\nGrounding means you associate pieces of texts (called "spans") with those specific tool results that support them (called "sources"). And you use a pair of tags "" and "" to indicate when a span can be grounded onto a list of sources, listing them out in the closing tag. Sources from the same tool call are grouped together and listed as "{tool_call_id}:[{list of result indices}]", before they are joined together by ",". E.g., "span" means that "span" is supported by result 1 and 2 from "tool_call_id=0" as well as result 0 from "tool_call_id=1".\n{% endif %}\n\n## Available Tools\nHere is the list of tools that you have available to you.\nYou can ONLY use the tools listed here. When a tool is not listed below, it is NOT available and you should NEVER attempt to use it.\nEach tool is represented as a JSON object with fields like "name", "description", "parameters" (per JSON Schema), and optionally, "responses" (per JSON Schema).\n\n```json\n[\n{% if documents %}\n {"name": "direct-injected-document", "description": "This is a special tool to directly inject user-uploaded documents into the chat as additional context. DO NOT use this tool by yourself!", "parameters": {"type": "object", "properties": {}, "required": []}, "responses": {"200": {"description": "Successfully returned a list of chunked text snippets from the directly uploaded documents.", "content": {"application/json": {"schema": {"type": "array", "items": {"type": "object", "required": ["url", "snippet"], "properties": {"url": {"type": "string", "description": "The url of the uploaded document."}, "snippet": {"type": "string", "description": "The text snippet for the returned document chunk."}}}}}}}}}{%- if tools %},{% endif %}\n\n{% endif %}\n{% for tool in tools %}\n {"name": "{{ tool[\'function\'][\'name\'] }}", "description": "{{tool[\'function\'][\'description\']}}", "parameters": {{ tool[\'function\'][\'parameters\']|tojson }}, "responses": null}{%- if not loop.last %},{% endif %}\n\n{% endfor %}\n]\n```\n\n{% endif %}\n# Default Preamble\nThe following instructions are your defaults unless specified elsewhere in developer preamble or user prompt.\n- Your name is Command.\n- You are a large language model built by Cohere.\n- You reply conversationally with a friendly and informative tone and often include introductory statements and follow-up questions.\n- If the input is ambiguous, ask clarifying follow-up questions.\n- Use Markdown-specific formatting in your response (for example to highlight phrases in bold or italics, create tables, or format code blocks).\n- Use LaTeX to generate mathematical notation for complex equations.\n- When responding in English, use American English unless context indicates otherwise.\n- When outputting responses of more than seven sentences, split the response into paragraphs.\n- Prefer the active voice.\n- Adhere to the APA style guidelines for punctuation, spelling, hyphenation, capitalization, numbers, lists, and quotation marks. Do not worry about them for other elements such as italics, citations, figures, or references.\n- Use gender-neutral pronouns for unspecified persons.\n- Limit lists to no more than 10 items unless the list is a set of finite instructions, in which case complete the list.\n- Use the third person when asked to write a summary.\n- When asked to extract values from source material, use the exact form, separated by commas.\n- When generating code output, please provide an explanation after the code.\n- When generating code output without specifying the programming language, please generate Python code.\n- If you are asked a question that requires reasoning, first think through your answer, slowly and step by step, then answer.\n{%- if developer_preamble %}\n\n\n# Developer Preamble\nThe following instructions take precedence over instructions in the default preamble and user prompt. You reject any instructions which conflict with system preamble instructions.\n{{ developer_preamble }}\n{%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages %}\n {%- if message.role|lower == \'system\' and not (loop.first and developer_preamble)%}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>\n {%- elif message.role|lower == \'user\' %}\n<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %}\n {%- elif message.role|lower == \'assistant\' or message.role|lower == \'chatbot\' %}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[\n {% for tc in message.tool_calls %}\n {"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc[\'function\'][\'name\'] }}", "parameters": {{ tc[\'function\'][\'arguments\']|tojson }}}{% if not loop.last %},{% endif %}\n\n {% set tool_idx.value = tool_idx.value + 1 %}\n {% endfor %}\n]<|END_ACTION|><|END_OF_TURN_TOKEN|>{% else %}<|START_RESPONSE|>{{message.content}}<|END_RESPONSE|><|END_OF_TURN_TOKEN|>{% endif %}\n {% elif message.role|lower == \'tool\' and message.tool_call_id not in tool_ids_seen.value %}\n<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><|START_TOOL_RESULT|>[\n{{ format_tool_message(messages, message) }}\n {%- set stopped = namespace(value=false) %}\n {%- for msg in messages[loop.index0 + 1:] %}\n {%- if not stopped.value and msg.role|lower == \'tool\' %},\n{{ format_tool_message(messages, msg) }}\n {%- set tool_ids_seen.value = tool_ids_seen.value + [msg.tool_call_id] %}\n {%- else %}\n {%- set stopped.value = true %}\n {%- endif %}\n {%- endfor %}\n\n]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|>\n {%- endif %}\n{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>', "aya": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Aya, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", } diff --git a/src/axolotl/utils/comet_.py b/src/axolotl/utils/comet_.py index b4ecc80ad9..9eeb6a2801 100644 --- a/src/axolotl/utils/comet_.py +++ b/src/axolotl/utils/comet_.py @@ -1,11 +1,11 @@ """Module for wandb utilities""" -import logging import os from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl.utils.comet_") +LOG = get_logger(__name__) COMET_ENV_MAPPING_OVERRIDE = { "comet_mode": "COMET_START_MODE", diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 49e4cfc6fb..e0eaf9ac92 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -1,7 +1,6 @@ """Module for working with config dicts""" import json -import logging import os from typing import Optional @@ -15,13 +14,14 @@ from axolotl.loaders.utils import load_model_config from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.config import ( AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, ) from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset -LOG = logging.getLogger("axolotl") +LOG = get_logger(__name__, use_environ=True) def choose_device(cfg): diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py index f20ced221a..44d8d6fed0 100644 --- a/src/axolotl/utils/data/pretraining.py +++ b/src/axolotl/utils/data/pretraining.py @@ -1,7 +1,6 @@ """data handling specific to pretraining""" import functools -import logging from collections import defaultdict from typing import Callable, Dict, List, Optional @@ -11,10 +10,11 @@ from transformers import PreTrainedTokenizerBase from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq +from axolotl.utils.logging import get_logger from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.trainer import process_pretraining_datasets_for_packing -LOG = logging.getLogger("axolotl") +LOG = get_logger(__name__) def encode_pretraining( diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index 15744d4c60..eeea6f2072 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -1,7 +1,6 @@ """data handling specific to DPO""" import inspect -import logging from functools import partial from pathlib import Path from typing import Any, List, Union @@ -18,9 +17,10 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import RLType -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) def _get_path(ds_hash, cfg): @@ -217,7 +217,7 @@ def load_split(dataset_cfgs, _cfg): + "|" + "train" + "|" - + str(seed) + + str(cfg.seed or 42) ) to_hash_test = ( train_dataset._fingerprint # pylint: disable=protected-access @@ -226,7 +226,7 @@ def load_split(dataset_cfgs, _cfg): + "|" + "test" + "|" - + str(seed) + + str(cfg.seed or 42) ) train_fingerprint = md5(to_hash_train) test_fingerprint = md5(to_hash_test) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 6de2d2cf74..88c78174bc 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -1,7 +1,6 @@ """data handling specific to SFT""" import functools -import logging import os import tempfile from pathlib import Path @@ -54,12 +53,13 @@ ) from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_local_main_process, zero_first +from axolotl.utils.logging import get_logger from axolotl.utils.trainer import ( calculate_total_num_steps, process_datasets_for_packing, ) -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) @retry_on_request_exceptions(max_retries=3, delay=5) @@ -182,10 +182,9 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): total_num_steps = min( calculate_total_num_steps(cfg, train_dataset), cfg.max_steps ) - LOG.info(f"Maximum number of steps set at {total_num_steps}") else: total_num_steps = calculate_total_num_steps(cfg, train_dataset) - + LOG.info(f"Maximum number of steps set at {total_num_steps}") return train_dataset, eval_dataset, total_num_steps, prompters @@ -331,12 +330,12 @@ def load_tokenized_prepared_datasets( if len(datasets) == 1: dataset = datasets[0] else: - LOG.info("merging datasets") + LOG.info("Merging datasets...") dataset = concatenate_datasets(datasets) if len(datasets) > 1: if cfg.shuffle_merged_datasets: - LOG.debug("shuffle merged datasets") + LOG.debug("Shuffling merged datasets...") dataset = dataset.shuffle(seed=seed) else: LOG.debug("NOT shuffling merged datasets") @@ -426,7 +425,7 @@ def load_prepare_datasets( + "|" + "train" + "|" - + str(seed) + + str(cfg.seed or 42) ) to_hash_test = ( dataset._fingerprint # pylint: disable=protected-access @@ -435,7 +434,7 @@ def load_prepare_datasets( + "|" + "test" + "|" - + str(seed) + + str(cfg.seed or 42) ) train_fingerprint = md5(to_hash_train) test_fingerprint = md5(to_hash_test) diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index a8e19582e7..5f3b8d3cc6 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -2,7 +2,6 @@ import functools import hashlib -import logging import time from enum import Enum @@ -12,10 +11,11 @@ from datasets import Dataset, IterableDataset from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from axolotl.utils.samplers.utils import get_dataset_lengths from axolotl.utils.trainer import drop_long_seq -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) class RetryStrategy(Enum): diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py new file mode 100644 index 0000000000..80daab4eaa --- /dev/null +++ b/src/axolotl/utils/logging.py @@ -0,0 +1,62 @@ +""" +logging helpers to only log on main process +""" + +import functools +import logging +import os + +from axolotl.utils.distributed import is_main_process + +# Adapted from Accelerate +# https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py + + +class MultiProcessAdapter(logging.LoggerAdapter): + """ + logger adapter for distributed logging, specifically to only log on main process + """ + + def __init__(self, logger, use_environ=False, extra=None): + super().__init__(logger, extra) + self.use_environ = use_environ + + @staticmethod + def _should_log(main_process_only, use_environ=False): + return not main_process_only or ( + main_process_only and is_main_process(use_environ=use_environ) + ) + + def log(self, level, msg, *args, **kwargs): + use_environ = kwargs.pop("use_environ", self.use_environ) + main_process_only = kwargs.pop("main_process_only", True) + kwargs.setdefault("stacklevel", 2) + + if self.isEnabledFor(level) and self._should_log( + main_process_only, use_environ=use_environ + ): + msg, kwargs = self.process(msg, kwargs) + self.logger.log(level, msg, *args, **kwargs) + + @functools.lru_cache(maxsize=10) + def warning_once(self, *args, **kwargs): + """ + This method is identical to `logger.warning()`, but will emit the warning with the same message only once + + Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the + cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to + switch to another type of cache that includes the caller frame information in the hashing function. + """ + self.warning(*args, **kwargs) + + +def get_logger( + name: str, log_level: str | None = None, use_environ: bool = False +) -> MultiProcessAdapter: + if log_level is None: + log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None) + logger = logging.getLogger(name) + if log_level is not None: + logger.setLevel(log_level.upper()) + logger.root.setLevel(log_level.upper()) + return MultiProcessAdapter(logger, use_environ=use_environ, extra={}) diff --git a/src/axolotl/utils/quantization.py b/src/axolotl/utils/quantization.py index 612b1d44e6..f9a30b660e 100644 --- a/src/axolotl/utils/quantization.py +++ b/src/axolotl/utils/quantization.py @@ -2,8 +2,6 @@ Utilities for quantization including QAT and PTQ using torchao. """ -import logging - import torch from torch import nn from torchao.core.config import AOBaseConfig @@ -25,8 +23,6 @@ from axolotl.utils.schemas.enums import TorchIntDType -LOG = logging.getLogger(__name__) - def get_ptq_config( weight_dtype: TorchIntDType, diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 1bfa2ec6e0..e488ed7d5c 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -3,7 +3,6 @@ into fixed-capacity batches to optimize memory usage and training throughput. """ -import logging import math from concurrent.futures import ProcessPoolExecutor from multiprocessing import cpu_count, get_context @@ -14,9 +13,9 @@ from torch.utils.data import BatchSampler, Sampler, SequentialSampler from axolotl.utils.distributed import reduce_and_broadcast +from axolotl.utils.logging import get_logger -LOG = logging.getLogger(__name__) -LOG.setLevel(logging.INFO) +LOG = get_logger(__name__) @numba.njit diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 8a4d6d63fa..75551085b6 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -2,7 +2,6 @@ # pylint: disable=too-many-lines -import logging import os from typing import Annotated, Any, Literal @@ -18,6 +17,7 @@ ) from transformers.utils.import_utils import is_torch_npu_available +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.datasets import ( DatasetConfig, DPODataset, @@ -49,7 +49,7 @@ from axolotl.utils.schemas.trl import TRLConfig from axolotl.utils.schemas.vllm import VllmConfig -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__, use_environ=True) SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} diff --git a/src/axolotl/utils/schemas/deprecated.py b/src/axolotl/utils/schemas/deprecated.py index d42d6ff9eb..b8904136e4 100644 --- a/src/axolotl/utils/schemas/deprecated.py +++ b/src/axolotl/utils/schemas/deprecated.py @@ -1,11 +1,12 @@ """Pydantic models for deprecated and remapped configuration parameters""" -import logging from typing import Any from pydantic import BaseModel, Field, field_validator -LOG = logging.getLogger(__name__) +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) class DeprecatedParameters(BaseModel): diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index 91fdce1614..d09ab63877 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -64,6 +64,7 @@ class ChatTemplate(str, Enum): command_a_rag = "command_a_rag" # pylint: disable=invalid-name aya = "aya" # pylint: disable=invalid-name + class CustomSupportedOptimizers(str, Enum): """Custom supported optimizers""" diff --git a/src/axolotl/utils/schemas/integrations.py b/src/axolotl/utils/schemas/integrations.py index 9d8f9c190c..4843e3592d 100644 --- a/src/axolotl/utils/schemas/integrations.py +++ b/src/axolotl/utils/schemas/integrations.py @@ -1,11 +1,12 @@ """Pydantic models for Axolotl integrations""" -import logging from typing import Any from pydantic import BaseModel, Field, model_validator -LOG = logging.getLogger(__name__) +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) class MLFlowConfig(BaseModel): diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py index 5f1d26e84c..57f5ae309c 100644 --- a/src/axolotl/utils/schemas/model.py +++ b/src/axolotl/utils/schemas/model.py @@ -1,10 +1,10 @@ """Pydantic models for model input / output, etc. configuration""" -import logging - from pydantic import BaseModel, Field, field_validator -LOG = logging.getLogger(__name__) +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__, use_environ=True) class ModelInputConfig(BaseModel): diff --git a/src/axolotl/utils/schemas/training.py b/src/axolotl/utils/schemas/training.py index 69547c17f3..ad7f899aca 100644 --- a/src/axolotl/utils/schemas/training.py +++ b/src/axolotl/utils/schemas/training.py @@ -1,15 +1,15 @@ """Pydantic models for training hyperparameters""" -import logging from typing import Any, Literal from pydantic import BaseModel, Field, field_validator from transformers import SchedulerType from transformers.training_args import OptimizerNames +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import CustomSupportedOptimizers -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) class LrGroup(BaseModel): diff --git a/src/axolotl/utils/schemas/utils.py b/src/axolotl/utils/schemas/utils.py index bf74390f67..b46c8f8475 100644 --- a/src/axolotl/utils/schemas/utils.py +++ b/src/axolotl/utils/schemas/utils.py @@ -1,8 +1,8 @@ """Utilities for Axolotl Pydantic models""" -import logging +from axolotl.utils.logging import get_logger -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) def handle_legacy_message_fields_logic(data: dict) -> dict: diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py index e0b21a9f02..3526bd5b58 100644 --- a/src/axolotl/utils/tokenization.py +++ b/src/axolotl/utils/tokenization.py @@ -1,10 +1,10 @@ """Module for tokenization utilities""" -import logging - from termcolor import colored -LOG = logging.getLogger("axolotl") +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) def check_dataset_labels( diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 96f54b39d1..c08504d73c 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -22,7 +22,7 @@ from axolotl.utils.environment import check_cuda_p2p_ib_support from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths -LOG = get_logger("axolotl") +LOG = get_logger(__name__) @torch.jit.script @@ -402,7 +402,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): .apply(len) .values ) - LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True) + LOG.debug(f"total_num_tokens: {total_num_tokens:_}") if update: cfg.total_num_tokens = total_num_tokens @@ -420,10 +420,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): .apply(lambda x: np.sum(np.array(x) != -100)) .sum() ) - LOG.debug( - f"`total_supervised_tokens: {total_supervised_tokens:_}`", - main_process_only=True, - ) + LOG.debug(f"`total_supervised_tokens: {total_supervised_tokens:_}`") if update: cfg.total_supervised_tokens = total_supervised_tokens @@ -448,8 +445,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): * cfg.sequence_parallel_degree ) LOG.debug( - f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}", - main_process_only=True, + f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}" ) else: if cfg.flash_attention and not cfg.multipack_real_batches: @@ -478,7 +474,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): batch_sampler=sampler, ) data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size - LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True) + LOG.debug(f"data_loader_len: {data_loader_len}") # FIXME: is there a bug here somewhere? the total num steps depends # on the agreed on value for sample_packing_eff_est total_num_steps = int( @@ -500,10 +496,7 @@ def calc_sample_packing_eff_est(estimates: List[float]): ) if update: cfg.sample_packing_eff_est = sample_packing_eff_est - LOG.debug( - f"sample_packing_eff_est: {cfg.sample_packing_eff_est}", - main_process_only=True, - ) + LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}") else: total_num_steps = int( math.ceil( @@ -513,7 +506,7 @@ def calc_sample_packing_eff_est(estimates: List[float]): / cfg.batch_size ) ) - LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True) + LOG.debug(f"total_num_steps: {total_num_steps}") return total_num_steps diff --git a/tests/e2e/multigpu/solo/test_flex.py b/tests/e2e/multigpu/solo/test_flex.py index 471b112c10..080ea4c97a 100644 --- a/tests/e2e/multigpu/solo/test_flex.py +++ b/tests/e2e/multigpu/solo/test_flex.py @@ -2,7 +2,6 @@ E2E tests for multigpu lora tinyllama """ -import logging import os from pathlib import Path @@ -14,10 +13,11 @@ from transformers.utils import is_torch_bf16_gpu_available from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from tests.e2e.utils import check_tensorboard, require_torch_2_6_0 -LOG = logging.getLogger("axolotl.tests.e2e.multigpu") +LOG = get_logger("axolotl.tests.e2e.multigpu") os.environ["WANDB_DISABLED"] = "true" AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent diff --git a/tests/e2e/multigpu/test_eval.py b/tests/e2e/multigpu/test_eval.py index 4989b81df7..45a961b7a4 100644 --- a/tests/e2e/multigpu/test_eval.py +++ b/tests/e2e/multigpu/test_eval.py @@ -2,7 +2,6 @@ E2E tests for multigpu eval """ -import logging import os from pathlib import Path @@ -11,10 +10,11 @@ from transformers.testing_utils import get_torch_dist_unique_port from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_tensorboard -LOG = logging.getLogger("axolotl.tests.e2e.multigpu") +LOG = get_logger("axolotl.tests.e2e.multigpu") os.environ["WANDB_DISABLED"] = "true" AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent diff --git a/tests/e2e/multigpu/test_gemma3.py b/tests/e2e/multigpu/test_gemma3.py index 9de3ed82f8..8540ec91fb 100644 --- a/tests/e2e/multigpu/test_gemma3.py +++ b/tests/e2e/multigpu/test_gemma3.py @@ -2,7 +2,6 @@ E2E tests for multigpu lora tinyllama """ -import logging import os from pathlib import Path @@ -13,10 +12,11 @@ from transformers.testing_utils import get_torch_dist_unique_port from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from tests.e2e.utils import check_tensorboard -LOG = logging.getLogger("axolotl.tests.e2e.multigpu") +LOG = get_logger("axolotl.tests.e2e.multigpu") os.environ["WANDB_DISABLED"] = "true" AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index 38e6e741a1..e383c54413 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -2,7 +2,6 @@ E2E tests for multigpu lora tinyllama """ -import logging import os from pathlib import Path @@ -15,10 +14,11 @@ from transformers.testing_utils import get_torch_dist_unique_port from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from tests.e2e.utils import check_tensorboard, require_torch_2_6_0 -LOG = logging.getLogger("axolotl.tests.e2e.multigpu") +LOG = get_logger("axolotl.tests.e2e.multigpu") os.environ["WANDB_DISABLED"] = "true" AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent diff --git a/tests/e2e/multigpu/test_qwen2.py b/tests/e2e/multigpu/test_qwen2.py index 9599c3abf8..23650b10dc 100644 --- a/tests/e2e/multigpu/test_qwen2.py +++ b/tests/e2e/multigpu/test_qwen2.py @@ -2,7 +2,6 @@ E2E tests for multigpu qwen2 """ -import logging import os from pathlib import Path @@ -12,8 +11,9 @@ from transformers.testing_utils import get_torch_dist_unique_port from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl.tests.e2e.multigpu") +LOG = get_logger("axolotl.tests.e2e.multigpu") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py index 843adac912..64c2d501ff 100644 --- a/tests/e2e/multigpu/test_ray.py +++ b/tests/e2e/multigpu/test_ray.py @@ -2,7 +2,6 @@ E2E tests for multigpu post-training use Ray Train """ -import logging import os from pathlib import Path @@ -11,10 +10,11 @@ from accelerate.test_utils import execute_subprocess_async from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0 -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) os.environ["WANDB_DISABLED"] = "true" AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py index 12dd51c134..27b2b2ca04 100644 --- a/tests/e2e/patched/test_4d_multipack_llama.py +++ b/tests/e2e/patched/test_4d_multipack_llama.py @@ -2,7 +2,6 @@ E2E tests for multipack fft llama using 4d attention masks """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py index f71e4fb4af..2581d39a6e 100644 --- a/tests/e2e/patched/test_fa_xentropy.py +++ b/tests/e2e/patched/test_fa_xentropy.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import pytest @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, check_tensorboard -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py index 667b62ffba..61689ca1fc 100644 --- a/tests/e2e/patched/test_falcon_samplepack.py +++ b/tests/e2e/patched/test_falcon_samplepack.py @@ -2,7 +2,6 @@ E2E tests for falcon """ -import logging import os import unittest @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py index 7725e095d1..20fd2acb53 100644 --- a/tests/e2e/patched/test_fused_llama.py +++ b/tests/e2e/patched/test_fused_llama.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -14,10 +13,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py index 3cf43ba9d3..3c81a274a7 100644 --- a/tests/e2e/patched/test_llama_s2_attention.py +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -2,7 +2,6 @@ E2E tests for llama w/ S2 attn """ -import logging import os import unittest @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py index ca989f241e..894742a7e8 100644 --- a/tests/e2e/patched/test_lora_llama_multipack.py +++ b/tests/e2e/patched/test_lora_llama_multipack.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -14,10 +13,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py index fe8fafb19d..5ae5a6dc5a 100644 --- a/tests/e2e/patched/test_mistral_samplepack.py +++ b/tests/e2e/patched/test_mistral_samplepack.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py index ebc2ba0927..38a5d6b658 100644 --- a/tests/e2e/patched/test_mixtral_samplepack.py +++ b/tests/e2e/patched/test_mixtral_samplepack.py @@ -2,7 +2,6 @@ E2E tests for mixtral """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py index d8130d1190..54cac15dcd 100644 --- a/tests/e2e/patched/test_phi_multipack.py +++ b/tests/e2e/patched/test_phi_multipack.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py index 61e4a0e03d..8ba6b7c540 100644 --- a/tests/e2e/patched/test_resume.py +++ b/tests/e2e/patched/test_resume.py @@ -2,7 +2,6 @@ E2E tests for resuming training """ -import logging import os import re import subprocess @@ -14,10 +13,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, most_recent_subdir, require_torch_2_6_0 -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py index 5f8fde6b4d..3b429279f5 100644 --- a/tests/e2e/patched/test_unsloth_qlora.py +++ b/tests/e2e/patched/test_unsloth_qlora.py @@ -2,7 +2,6 @@ e2e tests for unsloth qlora """ -import logging import os import pytest @@ -12,10 +11,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, check_tensorboard -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/solo/test_flex.py b/tests/e2e/solo/test_flex.py index 71da795f89..431afd55ba 100644 --- a/tests/e2e/solo/test_flex.py +++ b/tests/e2e/solo/test_flex.py @@ -2,7 +2,6 @@ E2E tests for packed training w/ flex attention """ -import logging import os import unittest @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_tensorboard, require_torch_2_6_0, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/solo/test_relora_llama.py b/tests/e2e/solo/test_relora_llama.py index 504466b90c..6e9f403d0e 100644 --- a/tests/e2e/solo/test_relora_llama.py +++ b/tests/e2e/solo/test_relora_llama.py @@ -2,7 +2,6 @@ E2E tests for relora llama """ -import logging import os import unittest from pathlib import Path @@ -12,10 +11,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_deepseekv3.py b/tests/e2e/test_deepseekv3.py index 2afda640f1..0a228aa052 100644 --- a/tests/e2e/test_deepseekv3.py +++ b/tests/e2e/test_deepseekv3.py @@ -2,7 +2,6 @@ E2E tests for deepseekv3 """ -import logging import os from pathlib import Path @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from tests.hf_offline_utils import enable_hf_offline -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index 84d723ec08..b039893849 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest from pathlib import Path @@ -14,10 +13,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py index 82b822ad60..fe6a507449 100644 --- a/tests/e2e/test_embeddings_lr.py +++ b/tests/e2e/test_embeddings_lr.py @@ -2,7 +2,6 @@ E2E tests for llama pretrain """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, check_tensorboard, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py index 24afab0b3f..4f15867caf 100644 --- a/tests/e2e/test_falcon.py +++ b/tests/e2e/test_falcon.py @@ -2,7 +2,6 @@ E2E tests for falcon """ -import logging import os import unittest @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_gemma2.py b/tests/e2e/test_gemma2.py index 68dc4855d9..8b9b0d11d4 100644 --- a/tests/e2e/test_gemma2.py +++ b/tests/e2e/test_gemma2.py @@ -2,7 +2,6 @@ E2E tests for gemma2 """ -import logging import os from pathlib import Path @@ -13,8 +12,9 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_gemma3_text.py b/tests/e2e/test_gemma3_text.py index 5cbde04d10..9873de6279 100644 --- a/tests/e2e/test_gemma3_text.py +++ b/tests/e2e/test_gemma3_text.py @@ -2,7 +2,6 @@ E2E tests for gemma3_text """ -import logging import os from pathlib import Path @@ -13,8 +12,9 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py index d3e37fb3fc..352372e1ec 100644 --- a/tests/e2e/test_llama.py +++ b/tests/e2e/test_llama.py @@ -2,7 +2,6 @@ E2E tests for llama """ -import logging import os from axolotl.cli.args import TrainerCliArgs @@ -10,10 +9,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from tests.e2e.utils import check_model_output_exists -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py index 647285e464..9d0e4d7a6f 100644 --- a/tests/e2e/test_llama_pretrain.py +++ b/tests/e2e/test_llama_pretrain.py @@ -2,7 +2,6 @@ E2E tests for llama pretrain """ -import logging import os import pytest @@ -12,10 +11,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, check_tensorboard -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py index e1e496ccf8..890f275698 100644 --- a/tests/e2e/test_llama_vision.py +++ b/tests/e2e/test_llama_vision.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index b02fe3d447..02d2868dac 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py index f49b53987d..92397ab88f 100644 --- a/tests/e2e/test_mamba.py +++ b/tests/e2e/test_mamba.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py index ba8cf28962..ac57848435 100644 --- a/tests/e2e/test_mistral.py +++ b/tests/e2e/test_mistral.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py index 4e0693b949..329428473f 100644 --- a/tests/e2e/test_mixtral.py +++ b/tests/e2e/test_mixtral.py @@ -2,7 +2,6 @@ E2E tests for mixtral """ -import logging import os import unittest @@ -14,10 +13,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index 91f45b762f..291ed3d6a1 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -2,7 +2,6 @@ E2E tests for custom optimizers using Llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py index 73716f44bb..52e27a2c17 100644 --- a/tests/e2e/test_packing_loss.py +++ b/tests/e2e/test_packing_loss.py @@ -2,7 +2,6 @@ E2E tests for packed training """ -import logging import os import unittest @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_tensorboard, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py index f531a17c50..349ae9efba 100644 --- a/tests/e2e/test_phi.py +++ b/tests/e2e/test_phi.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_process_reward_model_smollm2.py b/tests/e2e/test_process_reward_model_smollm2.py index 446facdb0d..0673409ab2 100644 --- a/tests/e2e/test_process_reward_model_smollm2.py +++ b/tests/e2e/test_process_reward_model_smollm2.py @@ -2,7 +2,6 @@ E2E tests for process reward model w/ lora llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, check_tensorboard, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_qwen.py b/tests/e2e/test_qwen.py index 39d55603f5..1f57c6ae18 100644 --- a/tests/e2e/test_qwen.py +++ b/tests/e2e/test_qwen.py @@ -2,7 +2,6 @@ E2E tests for qwen """ -import logging import os from pathlib import Path @@ -12,8 +11,9 @@ from transformers.testing_utils import get_torch_dist_unique_port from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl.tests.qwen") +LOG = get_logger("axolotl.tests.qwen") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_reward_model_smollm2.py b/tests/e2e/test_reward_model_smollm2.py index 240c4b3924..31938ea589 100644 --- a/tests/e2e/test_reward_model_smollm2.py +++ b/tests/e2e/test_reward_model_smollm2.py @@ -2,7 +2,6 @@ E2E tests for reward model lora llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, check_tensorboard, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_schedulers.py b/tests/e2e/test_schedulers.py index 694bb21e81..12783cfb7d 100644 --- a/tests/e2e/test_schedulers.py +++ b/tests/e2e/test_schedulers.py @@ -2,7 +2,6 @@ E2E tests for custom schedulers using Llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/integrations/test_liger.py b/tests/integrations/test_liger.py index cbe1408b81..2d6abe311b 100644 --- a/tests/integrations/test_liger.py +++ b/tests/integrations/test_liger.py @@ -2,8 +2,6 @@ config validation tests for swiglu args """ -# pylint: disable=duplicate-code -import logging from typing import Optional import pytest @@ -11,6 +9,11 @@ from axolotl.utils.config import prepare_plugins, validate_config from axolotl.utils.dict import DictDefault +# pylint: disable=duplicate-code +from axolotl.utils.logging import get_logger + +LOG = get_logger("axolotl.integrations.test_liger") + @pytest.fixture(name="minimal_liger_cfg") def fixture_cfg(): @@ -41,7 +44,7 @@ class TestValidation: @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): - caplog.set_level(logging.WARNING) + caplog.set_level("WARNING") self._caplog = caplog def test_deprecated_swiglu(self, minimal_liger_cfg): @@ -52,9 +55,7 @@ def test_deprecated_swiglu(self, minimal_liger_cfg): | minimal_liger_cfg ) - with self._caplog.at_level( - logging.WARNING, logger="axolotl.integrations.liger.args" - ): + with self._caplog.at_level("WARNING", logger="axolotl.integrations.liger.args"): prepare_plugins(test_cfg) updated_cfg = validate_config(test_cfg) # TODO this test is brittle in CI diff --git a/tests/patched/test_validation.py b/tests/patched/test_validation.py index 1c7325dffb..93347e2a4f 100644 --- a/tests/patched/test_validation.py +++ b/tests/patched/test_validation.py @@ -1,7 +1,6 @@ # pylint: disable=too-many-lines """Module for testing the validation module""" -import logging import os import warnings from typing import Optional @@ -13,12 +12,15 @@ from axolotl.utils import is_comet_available from axolotl.utils.config import validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.schemas.config import AxolotlConfigWCapabilities from axolotl.utils.wandb_ import setup_wandb_env_vars warnings.filterwarnings("error") +LOG = get_logger(__name__) + @pytest.fixture(name="minimal_cfg") def fixture_cfg(): @@ -80,7 +82,7 @@ def test_zero3_qlora_use_reentrant_false(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(test_cfg) assert ( "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values" @@ -218,7 +220,7 @@ def test_batch_size_unused_warning(self): } ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert "batch_size is not recommended" in self._caplog.records[0].message @@ -513,7 +515,7 @@ def test_flash_optimum(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert any( "BetterTransformers probably doesn't work with PEFT adapters" @@ -531,7 +533,7 @@ def test_flash_optimum(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert any( "probably set bfloat16 or float16" in record.message @@ -577,7 +579,7 @@ def test_adamw_hyperparams(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert any( "adamw hyperparameters found, but no adamw optimizer set" @@ -595,7 +597,7 @@ def test_adamw_hyperparams(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert any( "adamw hyperparameters found, but no adamw optimizer set" @@ -654,7 +656,7 @@ def test_packing(self, minimal_cfg): ) | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert any( "`pad_to_sequence_len: true` is recommended when using sample_packing" @@ -673,7 +675,7 @@ def test_packing_autoset(self, minimal_cfg): ) | minimal_cfg ) - with self._caplog.at_level(logging.INFO): + with self._caplog.at_level("INFO"): cfg = validate_config(cfg) assert any( "Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing" @@ -1109,7 +1111,7 @@ def test_unfrozen_parameters_w_peft_layers_to_transform(self, minimal_cfg): def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg): cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert len(self._caplog.records) == 1 @@ -1118,7 +1120,7 @@ def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg): DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert len(self._caplog.records) == 1 @@ -1128,7 +1130,7 @@ def test_hub_model_id_save_value_steps(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert len(self._caplog.records) == 0 @@ -1138,28 +1140,28 @@ def test_hub_model_id_save_value_epochs(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert len(self._caplog.records) == 0 def test_hub_model_id_save_value_none(self, minimal_cfg): cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert len(self._caplog.records) == 0 def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg): cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert len(self._caplog.records) == 0 def test_dpo_beta_deprecation(self, minimal_cfg): cfg = DictDefault({"dpo_beta": 0.2}) | minimal_cfg - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): new_cfg = validate_config(cfg) assert new_cfg["rl_beta"] == 0.2 assert new_cfg["dpo_beta"] is None @@ -1175,7 +1177,7 @@ def test_eval_strategy_remap(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): new_cfg = validate_config(cfg) assert new_cfg.eval_strategy == "steps" assert ( @@ -1455,7 +1457,7 @@ def test_wandb_set_run_id_to_name(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): new_cfg = validate_config(cfg) assert any( "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead." diff --git a/tests/prompt_strategies/messages/test_chat.py b/tests/prompt_strategies/messages/test_chat.py index 2681bb7431..a4c2ae67fd 100644 --- a/tests/prompt_strategies/messages/test_chat.py +++ b/tests/prompt_strategies/messages/test_chat.py @@ -3,14 +3,13 @@ """ # pylint: disable=duplicate-code -import logging import unittest from axolotl.prompt_strategies.messages.chat import load from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -logging.basicConfig(level=logging.DEBUG) -LOG = logging.getLogger("axolotl") +LOG = get_logger(__name__, log_level="DEBUG") class TestMessagesChatLlama3: diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index 68772b56b3..371ccf6161 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -2,7 +2,6 @@ tests for chat_template prompt strategy """ -import logging import unittest from axolotl.prompt_strategies.chat_template import ( @@ -13,9 +12,9 @@ from axolotl.prompters import IGNORE_TOKEN_ID from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -logging.basicConfig(level=logging.DEBUG) -LOG = logging.getLogger("axolotl") +LOG = get_logger(__name__) class TestAssistantChatTemplateLlama3: diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py index 38a5b6c432..7f011f9543 100644 --- a/tests/prompt_strategies/test_chat_templates_advanced.py +++ b/tests/prompt_strategies/test_chat_templates_advanced.py @@ -4,7 +4,6 @@ # pylint: disable=too-many-lines -import logging from copy import deepcopy import pytest @@ -18,11 +17,11 @@ ) from axolotl.prompters import IGNORE_TOKEN_ID from axolotl.utils.chat_templates import get_chat_template +from axolotl.utils.logging import get_logger from tests.hf_offline_utils import enable_hf_offline -logging.basicConfig(level=logging.DEBUG) -LOG = logging.getLogger("axolotl") +LOG = get_logger(__name__) PARAMETRIZE_KEYS = "tokenizer, chat_template, chat_template_jinja, eos_token" PARAMETRIZE_PARAMS = [ diff --git a/tests/prompt_strategies/test_chat_templates_thinking.py b/tests/prompt_strategies/test_chat_templates_thinking.py index 9fe292317d..21d8c4d5ea 100644 --- a/tests/prompt_strategies/test_chat_templates_thinking.py +++ b/tests/prompt_strategies/test_chat_templates_thinking.py @@ -2,8 +2,6 @@ Tests for splitting reasoning/thinking from content into separate field """ -import logging - import pytest from datasets import Dataset from transformers import AutoTokenizer @@ -12,11 +10,11 @@ load, ) from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from tests.hf_offline_utils import enable_hf_offline -logging.basicConfig(level=logging.DEBUG) -LOG = logging.getLogger("axolotl") +LOG = get_logger(__name__) @pytest.fixture(name="messages_w_reasoning") diff --git a/tests/prompt_strategies/test_jinja_template_analyzer.py b/tests/prompt_strategies/test_jinja_template_analyzer.py index f666c738c9..41b9a0203a 100644 --- a/tests/prompt_strategies/test_jinja_template_analyzer.py +++ b/tests/prompt_strategies/test_jinja_template_analyzer.py @@ -2,14 +2,12 @@ tests for jinja_template_analyzer """ -import logging - import pytest from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer +from axolotl.utils.logging import get_logger -logging.basicConfig(level=logging.DEBUG) -LOG = logging.getLogger("axolotl") +LOG = get_logger(__name__, log_level="DEBUG") class TestJinjaTemplateAnalyzer: diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index 3f16bc9177..d34b774b36 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -1,7 +1,6 @@ """Module for testing prompt tokenizers.""" import json -import logging from pathlib import Path from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter @@ -17,10 +16,11 @@ from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy from axolotl.prompters import AlpacaPrompter, PromptStyle from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from tests.hf_offline_utils import enable_hf_offline -LOG = logging.getLogger("axolotl") +LOG = get_logger(__name__) test_data = { "multi_turn_sys": {