From 2d68a1f58e29a35bcd955fd53c8b74578812579c Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 1 May 2025 14:16:05 +0100 Subject: [PATCH 01/29] quick start --- src/axolotl/utils/data/sft.py | 83 +++++++++++++++++++++++++-------- src/axolotl/utils/data/utils.py | 21 +++++++-- src/axolotl/utils/logging.py | 18 +++++++ 3 files changed, 97 insertions(+), 25 deletions(-) create mode 100644 src/axolotl/utils/logging.py diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 12f0701f06..82ab9dace2 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -54,6 +54,11 @@ ) from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_local_main_process, zero_first +from axolotl.utils.logging import ( + log_debug_rank_zero, + log_info_rank_zero, + log_warning_rank_zero, +) from axolotl.utils.trainer import ( calculate_total_num_steps, process_datasets_for_packing, @@ -167,7 +172,10 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): ) if cfg.dataset_exact_deduplication: - LOG.info("Deduplication not available for pretrained datasets") + log_info_rank_zero( + LOG, + "Deduplication not available for pretrained datasets", + ) return train_dataset, eval_dataset, cfg.max_steps, prompters @@ -182,10 +190,12 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): total_num_steps = min( calculate_total_num_steps(cfg, train_dataset), cfg.max_steps ) - LOG.info(f"Maximum number of steps set at {total_num_steps}") else: total_num_steps = calculate_total_num_steps(cfg, train_dataset) - + log_info_rank_zero( + LOG, + f"Maximum number of steps set at {total_num_steps}", + ) return train_dataset, eval_dataset, total_num_steps, prompters @@ -235,8 +245,9 @@ def load_tokenized_prepared_datasets( use_auth_token = cfg.hf_use_auth_token try: if cfg.push_dataset_to_hub: - LOG.info( - f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..." + log_info_rank_zero( + LOG, + f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})...", ) dataset = load_dataset( cfg.push_dataset_to_hub, @@ -257,28 +268,48 @@ def load_tokenized_prepared_datasets( and not cfg.is_preprocess and not cfg.skip_prepare_dataset ): - LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") + log_info_rank_zero( + LOG, + f"Loading prepared dataset from disk at {prepared_ds_path}...", + ) dataset = load_from_disk(str(prepared_ds_path)) - LOG.info("Prepared dataset loaded from disk...") + log_info_rank_zero( + LOG, + "Prepared dataset loaded from disk...", + ) else: if cfg.push_dataset_to_hub: - LOG.info("Unable to find prepared dataset in Huggingface hub") + log_info_rank_zero( + LOG, + "Unable to find prepared dataset in Huggingface hub", + ) if cfg.is_preprocess: - LOG.info( - f"Skipping prepared dataset in {prepared_ds_path} for pre-processing..." + log_info_rank_zero( + LOG, + f"Skipping prepared dataset in {prepared_ds_path} for pre-processing...", ) else: - LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}") - LOG.info("Loading raw datasets...") + log_info_rank_zero( + LOG, + f"Unable to find prepared dataset in {prepared_ds_path}", + ) + log_info_rank_zero( + LOG, + "Loading raw datasets...", + ) if not cfg.is_preprocess: - LOG.warning( - "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset." + log_warning_rank_zero( + LOG, + "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset.", ) if cfg.seed: seed = cfg.seed else: - LOG.info("No seed provided, using default seed of 42") + log_info_rank_zero( + LOG, + "No seed provided, using default seed of 42", + ) seed = 42 datasets = [] @@ -331,15 +362,24 @@ def load_tokenized_prepared_datasets( if len(datasets) == 1: dataset = datasets[0] else: - LOG.info("merging datasets") + log_info_rank_zero( + LOG, + "Merging datasets...", + ) dataset = concatenate_datasets(datasets) if len(datasets) > 1: if cfg.shuffle_merged_datasets: - LOG.debug("shuffle merged datasets") + log_debug_rank_zero( + LOG, + "Shuffling merged datasets...", + ) dataset = dataset.shuffle(seed=seed) else: - LOG.debug("NOT shuffling merged datasets") + log_debug_rank_zero( + LOG, + "NOT shuffling merged datasets", + ) if not cfg.skip_prepare_dataset: dataset = drop_long_seq_in_dataset(dataset, cfg) @@ -348,7 +388,10 @@ def load_tokenized_prepared_datasets( dataset, _ = process_datasets_for_packing(cfg, dataset, None) if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: - LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") + log_info_rank_zero( + LOG, + f"Saving merged prepared dataset to disk... {prepared_ds_path}", + ) if isinstance(dataset, IterableDataset): num_workers = cfg.dataset_processes @@ -482,7 +525,7 @@ def get_dataset_wrapper( } LOG.info( - f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}" + f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}", ) if ( diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index a8e19582e7..1e71853fbd 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -12,6 +12,7 @@ from datasets import Dataset, IterableDataset from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import log_warning_rank_zero from axolotl.utils.samplers.utils import get_dataset_lengths from axolotl.utils.trainer import drop_long_seq @@ -160,8 +161,9 @@ def deduplicate_and_log_datasets( def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): if "input_ids" not in dataset.column_names: - LOG.warning( - "Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling." + log_warning_rank_zero( + LOG, + "Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling.", ) return dataset @@ -174,9 +176,15 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): try: ds_lengths = get_dataset_lengths(dataset, from_arrow=True) min_input_len = np.min(ds_lengths) - LOG.info(f"min_input_len: {min_input_len}") + LOG.info( + LOG, + f"min_input_len: {min_input_len}", + ) max_input_len = np.max(ds_lengths) - LOG.info(f"max_input_len: {max_input_len}") + LOG.info( + LOG, + f"max_input_len: {max_input_len}", + ) except AttributeError: pass @@ -204,6 +212,9 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): if prior_len: dropped = prior_len - len(dataset) if dropped: - LOG.warning(f"Dropped {dropped} long samples from dataset") + LOG.warning( + LOG, + f"Dropped {dropped} long samples from dataset", + ) return dataset diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py new file mode 100644 index 0000000000..e72439c74c --- /dev/null +++ b/src/axolotl/utils/logging.py @@ -0,0 +1,18 @@ +import logging + +from axolotl.utils.distributed import is_main_process + + +def log_info_rank_zero(log: logging.Logger, message: str): + if is_main_process(): + log.info(message) + + +def log_debug_rank_zero(log: logging.Logger, message: str): + if is_main_process(): + log.debug(message) + + +def log_warning_rank_zero(log: logging.Logger, message: str): + if is_main_process(): + log.warning(message) From 730fe0d3e63ba1982c7a77003391d0a4fe6b3925 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 1 May 2025 14:16:05 +0100 Subject: [PATCH 02/29] quick start --- src/axolotl/utils/data/sft.py | 83 +++++++++++++++++++++++++-------- src/axolotl/utils/data/utils.py | 21 +++++++-- src/axolotl/utils/logging.py | 18 +++++++ 3 files changed, 97 insertions(+), 25 deletions(-) create mode 100644 src/axolotl/utils/logging.py diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 12f0701f06..82ab9dace2 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -54,6 +54,11 @@ ) from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_local_main_process, zero_first +from axolotl.utils.logging import ( + log_debug_rank_zero, + log_info_rank_zero, + log_warning_rank_zero, +) from axolotl.utils.trainer import ( calculate_total_num_steps, process_datasets_for_packing, @@ -167,7 +172,10 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): ) if cfg.dataset_exact_deduplication: - LOG.info("Deduplication not available for pretrained datasets") + log_info_rank_zero( + LOG, + "Deduplication not available for pretrained datasets", + ) return train_dataset, eval_dataset, cfg.max_steps, prompters @@ -182,10 +190,12 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): total_num_steps = min( calculate_total_num_steps(cfg, train_dataset), cfg.max_steps ) - LOG.info(f"Maximum number of steps set at {total_num_steps}") else: total_num_steps = calculate_total_num_steps(cfg, train_dataset) - + log_info_rank_zero( + LOG, + f"Maximum number of steps set at {total_num_steps}", + ) return train_dataset, eval_dataset, total_num_steps, prompters @@ -235,8 +245,9 @@ def load_tokenized_prepared_datasets( use_auth_token = cfg.hf_use_auth_token try: if cfg.push_dataset_to_hub: - LOG.info( - f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..." + log_info_rank_zero( + LOG, + f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})...", ) dataset = load_dataset( cfg.push_dataset_to_hub, @@ -257,28 +268,48 @@ def load_tokenized_prepared_datasets( and not cfg.is_preprocess and not cfg.skip_prepare_dataset ): - LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") + log_info_rank_zero( + LOG, + f"Loading prepared dataset from disk at {prepared_ds_path}...", + ) dataset = load_from_disk(str(prepared_ds_path)) - LOG.info("Prepared dataset loaded from disk...") + log_info_rank_zero( + LOG, + "Prepared dataset loaded from disk...", + ) else: if cfg.push_dataset_to_hub: - LOG.info("Unable to find prepared dataset in Huggingface hub") + log_info_rank_zero( + LOG, + "Unable to find prepared dataset in Huggingface hub", + ) if cfg.is_preprocess: - LOG.info( - f"Skipping prepared dataset in {prepared_ds_path} for pre-processing..." + log_info_rank_zero( + LOG, + f"Skipping prepared dataset in {prepared_ds_path} for pre-processing...", ) else: - LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}") - LOG.info("Loading raw datasets...") + log_info_rank_zero( + LOG, + f"Unable to find prepared dataset in {prepared_ds_path}", + ) + log_info_rank_zero( + LOG, + "Loading raw datasets...", + ) if not cfg.is_preprocess: - LOG.warning( - "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset." + log_warning_rank_zero( + LOG, + "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset.", ) if cfg.seed: seed = cfg.seed else: - LOG.info("No seed provided, using default seed of 42") + log_info_rank_zero( + LOG, + "No seed provided, using default seed of 42", + ) seed = 42 datasets = [] @@ -331,15 +362,24 @@ def load_tokenized_prepared_datasets( if len(datasets) == 1: dataset = datasets[0] else: - LOG.info("merging datasets") + log_info_rank_zero( + LOG, + "Merging datasets...", + ) dataset = concatenate_datasets(datasets) if len(datasets) > 1: if cfg.shuffle_merged_datasets: - LOG.debug("shuffle merged datasets") + log_debug_rank_zero( + LOG, + "Shuffling merged datasets...", + ) dataset = dataset.shuffle(seed=seed) else: - LOG.debug("NOT shuffling merged datasets") + log_debug_rank_zero( + LOG, + "NOT shuffling merged datasets", + ) if not cfg.skip_prepare_dataset: dataset = drop_long_seq_in_dataset(dataset, cfg) @@ -348,7 +388,10 @@ def load_tokenized_prepared_datasets( dataset, _ = process_datasets_for_packing(cfg, dataset, None) if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: - LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") + log_info_rank_zero( + LOG, + f"Saving merged prepared dataset to disk... {prepared_ds_path}", + ) if isinstance(dataset, IterableDataset): num_workers = cfg.dataset_processes @@ -482,7 +525,7 @@ def get_dataset_wrapper( } LOG.info( - f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}" + f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}", ) if ( diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index a8e19582e7..1e71853fbd 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -12,6 +12,7 @@ from datasets import Dataset, IterableDataset from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import log_warning_rank_zero from axolotl.utils.samplers.utils import get_dataset_lengths from axolotl.utils.trainer import drop_long_seq @@ -160,8 +161,9 @@ def deduplicate_and_log_datasets( def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): if "input_ids" not in dataset.column_names: - LOG.warning( - "Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling." + log_warning_rank_zero( + LOG, + "Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling.", ) return dataset @@ -174,9 +176,15 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): try: ds_lengths = get_dataset_lengths(dataset, from_arrow=True) min_input_len = np.min(ds_lengths) - LOG.info(f"min_input_len: {min_input_len}") + LOG.info( + LOG, + f"min_input_len: {min_input_len}", + ) max_input_len = np.max(ds_lengths) - LOG.info(f"max_input_len: {max_input_len}") + LOG.info( + LOG, + f"max_input_len: {max_input_len}", + ) except AttributeError: pass @@ -204,6 +212,9 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): if prior_len: dropped = prior_len - len(dataset) if dropped: - LOG.warning(f"Dropped {dropped} long samples from dataset") + LOG.warning( + LOG, + f"Dropped {dropped} long samples from dataset", + ) return dataset diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py new file mode 100644 index 0000000000..e72439c74c --- /dev/null +++ b/src/axolotl/utils/logging.py @@ -0,0 +1,18 @@ +import logging + +from axolotl.utils.distributed import is_main_process + + +def log_info_rank_zero(log: logging.Logger, message: str): + if is_main_process(): + log.info(message) + + +def log_debug_rank_zero(log: logging.Logger, message: str): + if is_main_process(): + log.debug(message) + + +def log_warning_rank_zero(log: logging.Logger, message: str): + if is_main_process(): + log.warning(message) From 66162cbe9b4f3a432e86878218898d44f0f0d4d8 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 2 May 2025 09:36:19 -0400 Subject: [PATCH 03/29] refactor log rank zero funcs --- src/axolotl/utils/logging.py | 23 ++++++++++++----------- src/axolotl/utils/models.py | 11 +++++------ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index e72439c74c..d220efaf02 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -1,18 +1,19 @@ +""" +logging helpers to only log on main process +""" + import logging +from functools import partial from axolotl.utils.distributed import is_main_process -def log_info_rank_zero(log: logging.Logger, message: str): - if is_main_process(): - log.info(message) - - -def log_debug_rank_zero(log: logging.Logger, message: str): - if is_main_process(): - log.debug(message) +def log_rank_zero(log: logging.Logger, message: str, level: str = "info"): + if is_main_process(use_environ=True): + getattr(log, level.lower())(message) -def log_warning_rank_zero(log: logging.Logger, message: str): - if is_main_process(): - log.warning(message) +log_info_rank_zero = partial(log_rank_zero, level="info") +log_debug_rank_zero = partial(log_rank_zero, level="debug") +log_warning_rank_zero = partial(log_rank_zero, level="warning") +log_error_rank_zero = partial(log_rank_zero, level="error") diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index ba71ea4598..3c70b25b6a 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -68,9 +68,9 @@ get_device_count, get_device_type, is_local_main_process, - is_main_process, ) from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper +from axolotl.utils.logging import log_debug_rank_zero from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant @@ -453,11 +453,10 @@ def load_tokenizer(cfg): {"additional_special_tokens": additional_special_tokens} ) - if is_main_process(use_environ=True): - LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") - LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") - LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") - LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") + log_debug_rank_zero(LOG, f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") + log_debug_rank_zero(LOG, f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") + log_debug_rank_zero(LOG, f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") + log_debug_rank_zero(LOG, f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") if cfg.chat_template: chat_template_string = get_chat_template_from_config( From 89d44dd36f20f07cf89968b720fb04baffa823a0 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 2 May 2025 10:13:23 -0400 Subject: [PATCH 04/29] use multi process logging adapter similar to accelerate --- src/axolotl/cli/train.py | 3 -- src/axolotl/train.py | 2 +- src/axolotl/utils/data/utils.py | 15 +++--- src/axolotl/utils/logging.py | 34 ++++++++++++ src/axolotl/utils/models.py | 92 +++++++++++++++++++++++---------- 5 files changed, 107 insertions(+), 39 deletions(-) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 4f258313d2..517f6b0661 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -1,7 +1,6 @@ """CLI to run training on a model.""" import gc -import logging import os from pathlib import Path from typing import Union @@ -22,8 +21,6 @@ from axolotl.utils.config import normalize_config, resolve_dtype from axolotl.utils.dict import DictDefault -LOG = logging.getLogger(__name__) - def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): """ diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 30d26b7063..ebbbb45040 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -12,7 +12,6 @@ import torch import transformers.modelcard -from accelerate.logging import get_logger from accelerate.utils import save_fsdp_model from datasets import Dataset from huggingface_hub.errors import OfflineModeIsEnabled @@ -33,6 +32,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.freeze import freeze_layers_except +from axolotl.utils.logging import get_logger from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.trainer import setup_trainer diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index 1e71853fbd..7bc375e5a2 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -2,7 +2,6 @@ import functools import hashlib -import logging import time from enum import Enum @@ -12,11 +11,11 @@ from datasets import Dataset, IterableDataset from axolotl.utils.dict import DictDefault -from axolotl.utils.logging import log_warning_rank_zero +from axolotl.utils.logging import get_logger from axolotl.utils.samplers.utils import get_dataset_lengths from axolotl.utils.trainer import drop_long_seq -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) class RetryStrategy(Enum): @@ -161,9 +160,9 @@ def deduplicate_and_log_datasets( def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): if "input_ids" not in dataset.column_names: - log_warning_rank_zero( - LOG, + LOG.warning( "Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling.", + main_process_only=True, ) return dataset @@ -177,13 +176,13 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): ds_lengths = get_dataset_lengths(dataset, from_arrow=True) min_input_len = np.min(ds_lengths) LOG.info( - LOG, f"min_input_len: {min_input_len}", + main_process_only=True, ) max_input_len = np.max(ds_lengths) LOG.info( - LOG, f"max_input_len: {max_input_len}", + main_process_only=True, ) except AttributeError: pass @@ -213,8 +212,8 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): dropped = prior_len - len(dataset) if dropped: LOG.warning( - LOG, f"Dropped {dropped} long samples from dataset", + main_process_only=True, ) return dataset diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index d220efaf02..4a65af60b3 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -3,6 +3,7 @@ """ import logging +import os from functools import partial from axolotl.utils.distributed import is_main_process @@ -17,3 +18,36 @@ def log_rank_zero(log: logging.Logger, message: str, level: str = "info"): log_debug_rank_zero = partial(log_rank_zero, level="debug") log_warning_rank_zero = partial(log_rank_zero, level="warning") log_error_rank_zero = partial(log_rank_zero, level="error") + + +# Adapted from Accelerate +# https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py +class MultiProcessAdapter(logging.LoggerAdapter): + """ + logger adapter for distributed logging, specifically to only log on main process + """ + + @staticmethod + def _should_log(main_process_only): + return not main_process_only or ( + main_process_only and is_main_process(use_environ=True) + ) + + def log(self, level, msg, *args, **kwargs): + main_process_only = kwargs.pop("main_process_only", True) + kwargs.setdefault("stacklevel", 2) + + if self.isEnabledFor(level): + if self._should_log(main_process_only): + msg, kwargs = self.process(msg, kwargs) + self.logger.log(level, msg, *args, **kwargs) + + +def get_logger(name: str, log_level: str | None = None): + if log_level is None: + log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None) + logger = logging.getLogger(name) + if log_level is not None: + logger.setLevel(log_level.upper()) + logger.root.setLevel(log_level.upper()) + return MultiProcessAdapter(logger, {}) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 3c70b25b6a..9640c79796 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -3,7 +3,6 @@ # pylint: disable=too-many-lines import gc import importlib -import logging import math import os import types @@ -70,11 +69,11 @@ is_local_main_process, ) from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper -from axolotl.utils.logging import log_debug_rank_zero +from axolotl.utils.logging import get_logger from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) PLUGIN_MANAGER = PluginManager.get_instance() MULTIMODAL_AUTO_MODEL_MAPPING = { @@ -135,7 +134,10 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig): and hasattr(model_config.vision_config, "image_size") ): cfg.image_size = model_config.vision_config.image_size - LOG.debug(f"Loaded image size: {cfg.image_size} from model config") + LOG.debug( + f"Loaded image size: {cfg.image_size} from model config", + main_process_only=True, + ) quant_config_exists = ( hasattr(model_config, "quantization_config") @@ -152,7 +154,8 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig): if model_config.quantization_config.get("config_groups"): LOG.warning( "Found `config_groups` in a compressed-tensors config. " - "QAT integration with llmcompressor is not tested." + "QAT integration with llmcompressor is not tested.", + main_process_only=True, ) # Skip further quant checks for compressed-tensors return @@ -453,10 +456,15 @@ def load_tokenizer(cfg): {"additional_special_tokens": additional_special_tokens} ) - log_debug_rank_zero(LOG, f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") - log_debug_rank_zero(LOG, f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") - log_debug_rank_zero(LOG, f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") - log_debug_rank_zero(LOG, f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") + LOG.debug( + f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}", main_process_only=True + ) + LOG.debug( + f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}", main_process_only=True + ) + LOG.debug( + f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}", main_process_only=True + ) if cfg.chat_template: chat_template_string = get_chat_template_from_config( @@ -513,7 +521,10 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): elif im_height is not None: cfg.image_size = im_height - LOG.debug(f"Loaded image size: {cfg.image_size} from processor") + LOG.debug( + f"Loaded image size: {cfg.image_size} from processor", + main_process_only=True, + ) return processor @@ -740,14 +751,20 @@ def patch_llama_derived_model(self): if self.cfg.sample_packing: if self.cfg.device not in ["mps", "cpu"] and not self.inference: - LOG.info("patching with flash attention for sample packing") + LOG.info( + "patching with flash attention for sample packing", + main_process_only=True, + ) replace_llama_attn_with_flash_attn( packed=True, cross_entropy=self.cfg.flash_attn_cross_entropy, rms_norm=self.cfg.flash_attn_rms_norm, ) elif self.cfg.s2_attention: - LOG.info("patching w/ flash-enabled, shifted-sparse attention") + LOG.info( + "patching w/ flash-enabled, shifted-sparse attention", + main_process_only=True, + ) replace_llama_attn_with_flash_attn( packed=False, cross_entropy=self.cfg.flash_attn_cross_entropy, @@ -765,14 +782,17 @@ def patch_llama_derived_model(self): hijack_llama_attention, ) - LOG.info("patching with xformers attention") + LOG.info("patching with xformers attention", main_process_only=True) hijack_llama_attention() elif self.cfg.sample_packing: from axolotl.monkeypatch.llama_patch_multipack import ( hijack_llama_prepare_4d_mask, ) - LOG.info("patching llama _prepare_4d_causal_attention_mask*") + LOG.info( + "patching llama _prepare_4d_causal_attention_mask*", + main_process_only=True, + ) hijack_llama_prepare_4d_mask() elif self.cfg.s2_attention: raise NotImplementedError( @@ -854,7 +874,8 @@ def set_quantization_config(self) -> None: if self.cfg.gptq: if not hasattr(self.model_config, "quantization_config"): LOG.warning( - "model config does not contain quantization_config information" + "model config does not contain quantization_config information", + main_process_only=True, ) else: if self.cfg.gptq_disable_exllama is not None: @@ -1066,11 +1087,11 @@ def _configure_zero3_memory_efficient_loading(): ) if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): - LOG.info("patching with SwiGLU") + LOG.info("patching with SwiGLU", main_process_only=True) replace_llama_mlp_with_swiglu(self.model) if self.cfg.flash_attn_fuse_qkv: - LOG.info("patching with fused QKV") + LOG.info("patching with fused QKV", main_process_only=True) replace_llama_qkv_with_fused(self.model) elif self.model_type == "MambaLMHeadModel": # FIXME this is janky at best and hacked together to make it work @@ -1143,7 +1164,8 @@ def adjust_model_config(self) -> None: and self.cfg.sequence_len > self.model.config.max_position_embeddings ): LOG.warning( - f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}" + f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}", + main_process_only=True, ) self.model.config.max_position_embeddings = self.cfg.sequence_len @@ -1207,7 +1229,10 @@ def prepare_model(self, qlora_fsdp) -> None: and self.cfg.adapter in ["lora", "qlora"] and (self.cfg.load_in_8bit or self.cfg.load_in_4bit) ): - LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") + LOG.info( + "converting PEFT model w/ prepare_model_for_kbit_training", + main_process_only=True, + ) self.model = prepare_model_for_kbit_training( self.model, use_gradient_checkpointing=self.cfg.gradient_checkpointing ) @@ -1269,7 +1294,7 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: skip_move_to_device = self.build_model(qlora_fsdp) PLUGIN_MANAGER.post_model_build(self.cfg, self.model) except Exception as err: # pylint: disable=broad-exception-caught - LOG.exception(err) + LOG.exception(err, main_process_only=True) raise err if isinstance(self.model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp: @@ -1340,7 +1365,9 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: ) if should_convert: - LOG.info("Converting modules to %s", self.cfg.torch_dtype) + LOG.info( + "Converting modules to %s", self.cfg.torch_dtype, main_process_only=True + ) self.convert_embedding_modules_dtype( embedding_modules=embedding_modules, dist_dtype=self.cfg.torch_dtype, @@ -1393,7 +1420,10 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: if param.requires_grad: requires_grad.append(f"{name}: {param.requires_grad}") if len(requires_grad) == 0: - LOG.warning("there are no parameters that require gradient updates") + LOG.warning( + "there are no parameters that require gradient updates", + main_process_only=True, + ) if self.cfg.flash_optimum: from optimum.bettertransformer import BetterTransformer @@ -1467,7 +1497,7 @@ def load_llama_adapter(model, cfg): ) if cfg.lora_model_dir: - LOG.debug("Loading pretrained PEFT - llama_adapter") + LOG.debug("Loading pretrained PEFT - llama_adapter", main_process_only=True) model = PeftModel.from_pretrained( model, cfg.lora_model_dir, @@ -1534,7 +1564,10 @@ def load_lora(model, cfg, inference=False, config_only=False): if cfg.lora_target_linear: linear_names = find_all_linear_names(model) - LOG.info(f"found linear modules: {repr(sorted(linear_names))}") + LOG.info( + f"found linear modules: {repr(sorted(linear_names))}", + main_process_only=True, + ) lora_target_modules_as_list = ( lora_target_modules if isinstance(lora_target_modules, list) @@ -1551,7 +1584,10 @@ def load_lora(model, cfg, inference=False, config_only=False): lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights if cfg.peft_use_dora: lora_config_kwargs["use_dora"] = cfg.peft_use_dora - LOG.info("Initializing LoRA weights using dora. This might take longer.") + LOG.info( + "Initializing LoRA weights using dora. This might take longer.", + main_process_only=True, + ) if cfg.peft_use_rslora: lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora if cfg.peft_layer_replication: @@ -1585,7 +1621,7 @@ def load_lora(model, cfg, inference=False, config_only=False): setup_quantized_meta_for_peft(model) if cfg.lora_model_dir: - LOG.debug("Loading pretrained PEFT - LoRA") + LOG.debug("Loading pretrained PEFT - LoRA", main_process_only=True) model_kwargs: Any = {} if cfg.lora_on_cpu: model_kwargs["max_memory"] = {"cpu": "256GiB"} @@ -1604,7 +1640,9 @@ def load_lora(model, cfg, inference=False, config_only=False): model.print_trainable_parameters() except AttributeError as exc: LOG.warning( - "Exception caught during model.print_trainable_parameters(): %s", exc + "Exception caught during model.print_trainable_parameters(): %s", + exc, + main_process_only=True, ) elif ( cfg.fsdp From 39c26902ca3e0ade92a658d50b754ee8c36fb367 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 1 May 2025 14:16:05 +0100 Subject: [PATCH 05/29] quick start --- src/axolotl/utils/data/sft.py | 83 +++++++++++++++++++++++++-------- src/axolotl/utils/data/utils.py | 21 +++++++-- src/axolotl/utils/logging.py | 18 +++++++ 3 files changed, 97 insertions(+), 25 deletions(-) create mode 100644 src/axolotl/utils/logging.py diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 5fa0cb60d6..7d7f114dae 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -54,6 +54,11 @@ ) from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_local_main_process, zero_first +from axolotl.utils.logging import ( + log_debug_rank_zero, + log_info_rank_zero, + log_warning_rank_zero, +) from axolotl.utils.trainer import ( calculate_total_num_steps, process_datasets_for_packing, @@ -167,7 +172,10 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): ) if cfg.dataset_exact_deduplication: - LOG.info("Deduplication not available for pretrained datasets") + log_info_rank_zero( + LOG, + "Deduplication not available for pretrained datasets", + ) return train_dataset, eval_dataset, cfg.max_steps, prompters @@ -182,10 +190,12 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): total_num_steps = min( calculate_total_num_steps(cfg, train_dataset), cfg.max_steps ) - LOG.info(f"Maximum number of steps set at {total_num_steps}") else: total_num_steps = calculate_total_num_steps(cfg, train_dataset) - + log_info_rank_zero( + LOG, + f"Maximum number of steps set at {total_num_steps}", + ) return train_dataset, eval_dataset, total_num_steps, prompters @@ -235,8 +245,9 @@ def load_tokenized_prepared_datasets( use_auth_token = cfg.hf_use_auth_token try: if cfg.push_dataset_to_hub: - LOG.info( - f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..." + log_info_rank_zero( + LOG, + f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})...", ) dataset = load_dataset( cfg.push_dataset_to_hub, @@ -257,28 +268,48 @@ def load_tokenized_prepared_datasets( and not cfg.is_preprocess and not cfg.skip_prepare_dataset ): - LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") + log_info_rank_zero( + LOG, + f"Loading prepared dataset from disk at {prepared_ds_path}...", + ) dataset = load_from_disk(str(prepared_ds_path)) - LOG.info("Prepared dataset loaded from disk...") + log_info_rank_zero( + LOG, + "Prepared dataset loaded from disk...", + ) else: if cfg.push_dataset_to_hub: - LOG.info("Unable to find prepared dataset in Huggingface hub") + log_info_rank_zero( + LOG, + "Unable to find prepared dataset in Huggingface hub", + ) if cfg.is_preprocess: - LOG.info( - f"Skipping prepared dataset in {prepared_ds_path} for pre-processing..." + log_info_rank_zero( + LOG, + f"Skipping prepared dataset in {prepared_ds_path} for pre-processing...", ) else: - LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}") - LOG.info("Loading raw datasets...") + log_info_rank_zero( + LOG, + f"Unable to find prepared dataset in {prepared_ds_path}", + ) + log_info_rank_zero( + LOG, + "Loading raw datasets...", + ) if not cfg.is_preprocess: - LOG.warning( - "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset." + log_warning_rank_zero( + LOG, + "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset.", ) if cfg.seed: seed = cfg.seed else: - LOG.info("No seed provided, using default seed of 42") + log_info_rank_zero( + LOG, + "No seed provided, using default seed of 42", + ) seed = 42 datasets = [] @@ -331,15 +362,24 @@ def load_tokenized_prepared_datasets( if len(datasets) == 1: dataset = datasets[0] else: - LOG.info("merging datasets") + log_info_rank_zero( + LOG, + "Merging datasets...", + ) dataset = concatenate_datasets(datasets) if len(datasets) > 1: if cfg.shuffle_merged_datasets: - LOG.debug("shuffle merged datasets") + log_debug_rank_zero( + LOG, + "Shuffling merged datasets...", + ) dataset = dataset.shuffle(seed=seed) else: - LOG.debug("NOT shuffling merged datasets") + log_debug_rank_zero( + LOG, + "NOT shuffling merged datasets", + ) if not cfg.skip_prepare_dataset: dataset = drop_long_seq_in_dataset(dataset, cfg) @@ -348,7 +388,10 @@ def load_tokenized_prepared_datasets( dataset, _ = process_datasets_for_packing(cfg, dataset, None) if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: - LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") + log_info_rank_zero( + LOG, + f"Saving merged prepared dataset to disk... {prepared_ds_path}", + ) if isinstance(dataset, IterableDataset): num_workers = cfg.dataset_processes @@ -484,7 +527,7 @@ def get_dataset_wrapper( } LOG.info( - f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}" + f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}", ) if ( diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index a8e19582e7..1e71853fbd 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -12,6 +12,7 @@ from datasets import Dataset, IterableDataset from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import log_warning_rank_zero from axolotl.utils.samplers.utils import get_dataset_lengths from axolotl.utils.trainer import drop_long_seq @@ -160,8 +161,9 @@ def deduplicate_and_log_datasets( def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): if "input_ids" not in dataset.column_names: - LOG.warning( - "Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling." + log_warning_rank_zero( + LOG, + "Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling.", ) return dataset @@ -174,9 +176,15 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): try: ds_lengths = get_dataset_lengths(dataset, from_arrow=True) min_input_len = np.min(ds_lengths) - LOG.info(f"min_input_len: {min_input_len}") + LOG.info( + LOG, + f"min_input_len: {min_input_len}", + ) max_input_len = np.max(ds_lengths) - LOG.info(f"max_input_len: {max_input_len}") + LOG.info( + LOG, + f"max_input_len: {max_input_len}", + ) except AttributeError: pass @@ -204,6 +212,9 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): if prior_len: dropped = prior_len - len(dataset) if dropped: - LOG.warning(f"Dropped {dropped} long samples from dataset") + LOG.warning( + LOG, + f"Dropped {dropped} long samples from dataset", + ) return dataset diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py new file mode 100644 index 0000000000..e72439c74c --- /dev/null +++ b/src/axolotl/utils/logging.py @@ -0,0 +1,18 @@ +import logging + +from axolotl.utils.distributed import is_main_process + + +def log_info_rank_zero(log: logging.Logger, message: str): + if is_main_process(): + log.info(message) + + +def log_debug_rank_zero(log: logging.Logger, message: str): + if is_main_process(): + log.debug(message) + + +def log_warning_rank_zero(log: logging.Logger, message: str): + if is_main_process(): + log.warning(message) From adb78c710b881490cec29cb2bb3da0875341f32a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 2 May 2025 09:36:19 -0400 Subject: [PATCH 06/29] refactor log rank zero funcs --- src/axolotl/utils/logging.py | 23 ++++++++++++----------- src/axolotl/utils/models.py | 11 +++++------ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index e72439c74c..d220efaf02 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -1,18 +1,19 @@ +""" +logging helpers to only log on main process +""" + import logging +from functools import partial from axolotl.utils.distributed import is_main_process -def log_info_rank_zero(log: logging.Logger, message: str): - if is_main_process(): - log.info(message) - - -def log_debug_rank_zero(log: logging.Logger, message: str): - if is_main_process(): - log.debug(message) +def log_rank_zero(log: logging.Logger, message: str, level: str = "info"): + if is_main_process(use_environ=True): + getattr(log, level.lower())(message) -def log_warning_rank_zero(log: logging.Logger, message: str): - if is_main_process(): - log.warning(message) +log_info_rank_zero = partial(log_rank_zero, level="info") +log_debug_rank_zero = partial(log_rank_zero, level="debug") +log_warning_rank_zero = partial(log_rank_zero, level="warning") +log_error_rank_zero = partial(log_rank_zero, level="error") diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 316fbec8ce..776f25b351 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -68,12 +68,12 @@ get_device_count, get_device_type, is_local_main_process, - is_main_process, ) from axolotl.utils.gradient_checkpointing import ( hf_grad_checkpoint_disk_offload_wrapper, hf_grad_checkpoint_offload_wrapper, ) +from axolotl.utils.logging import log_debug_rank_zero from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant from axolotl.utils.schemas.enums import RLType @@ -457,11 +457,10 @@ def load_tokenizer(cfg): {"additional_special_tokens": additional_special_tokens} ) - if is_main_process(use_environ=True): - LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") - LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") - LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") - LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") + log_debug_rank_zero(LOG, f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") + log_debug_rank_zero(LOG, f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") + log_debug_rank_zero(LOG, f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") + log_debug_rank_zero(LOG, f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") if cfg.chat_template: chat_template_string = get_chat_template_from_config( From aa97c92a5875e59304a11567c5dea41f48464dfb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 2 May 2025 10:13:23 -0400 Subject: [PATCH 07/29] use multi process logging adapter similar to accelerate --- src/axolotl/cli/train.py | 3 -- src/axolotl/train.py | 1 + src/axolotl/utils/data/utils.py | 15 +++--- src/axolotl/utils/logging.py | 34 ++++++++++++ src/axolotl/utils/models.py | 92 +++++++++++++++++++++++---------- 5 files changed, 107 insertions(+), 38 deletions(-) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 777d848853..fef80fdbaf 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -1,7 +1,6 @@ """CLI to run training on a model.""" import gc -import logging import os from pathlib import Path from typing import Union @@ -22,8 +21,6 @@ from axolotl.utils.config import normalize_config, resolve_dtype from axolotl.utils.dict import DictDefault -LOG = logging.getLogger(__name__) - def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): """ diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 90ab10e9f9..9d80fde1b5 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -32,6 +32,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.freeze import freeze_layers_except +from axolotl.utils.logging import get_logger from axolotl.utils.models import load_model, load_processor, load_tokenizer from axolotl.utils.schemas.enums import RLType from axolotl.utils.trainer import setup_trainer diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index 1e71853fbd..7bc375e5a2 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -2,7 +2,6 @@ import functools import hashlib -import logging import time from enum import Enum @@ -12,11 +11,11 @@ from datasets import Dataset, IterableDataset from axolotl.utils.dict import DictDefault -from axolotl.utils.logging import log_warning_rank_zero +from axolotl.utils.logging import get_logger from axolotl.utils.samplers.utils import get_dataset_lengths from axolotl.utils.trainer import drop_long_seq -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) class RetryStrategy(Enum): @@ -161,9 +160,9 @@ def deduplicate_and_log_datasets( def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): if "input_ids" not in dataset.column_names: - log_warning_rank_zero( - LOG, + LOG.warning( "Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling.", + main_process_only=True, ) return dataset @@ -177,13 +176,13 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): ds_lengths = get_dataset_lengths(dataset, from_arrow=True) min_input_len = np.min(ds_lengths) LOG.info( - LOG, f"min_input_len: {min_input_len}", + main_process_only=True, ) max_input_len = np.max(ds_lengths) LOG.info( - LOG, f"max_input_len: {max_input_len}", + main_process_only=True, ) except AttributeError: pass @@ -213,8 +212,8 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): dropped = prior_len - len(dataset) if dropped: LOG.warning( - LOG, f"Dropped {dropped} long samples from dataset", + main_process_only=True, ) return dataset diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index d220efaf02..4a65af60b3 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -3,6 +3,7 @@ """ import logging +import os from functools import partial from axolotl.utils.distributed import is_main_process @@ -17,3 +18,36 @@ def log_rank_zero(log: logging.Logger, message: str, level: str = "info"): log_debug_rank_zero = partial(log_rank_zero, level="debug") log_warning_rank_zero = partial(log_rank_zero, level="warning") log_error_rank_zero = partial(log_rank_zero, level="error") + + +# Adapted from Accelerate +# https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py +class MultiProcessAdapter(logging.LoggerAdapter): + """ + logger adapter for distributed logging, specifically to only log on main process + """ + + @staticmethod + def _should_log(main_process_only): + return not main_process_only or ( + main_process_only and is_main_process(use_environ=True) + ) + + def log(self, level, msg, *args, **kwargs): + main_process_only = kwargs.pop("main_process_only", True) + kwargs.setdefault("stacklevel", 2) + + if self.isEnabledFor(level): + if self._should_log(main_process_only): + msg, kwargs = self.process(msg, kwargs) + self.logger.log(level, msg, *args, **kwargs) + + +def get_logger(name: str, log_level: str | None = None): + if log_level is None: + log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None) + logger = logging.getLogger(name) + if log_level is not None: + logger.setLevel(log_level.upper()) + logger.root.setLevel(log_level.upper()) + return MultiProcessAdapter(logger, {}) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 776f25b351..3890e3f834 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -3,7 +3,6 @@ # pylint: disable=too-many-lines import gc import importlib -import logging import math import os import types @@ -73,12 +72,12 @@ hf_grad_checkpoint_disk_offload_wrapper, hf_grad_checkpoint_offload_wrapper, ) -from axolotl.utils.logging import log_debug_rank_zero +from axolotl.utils.logging import get_logger from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant from axolotl.utils.schemas.enums import RLType -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) PLUGIN_MANAGER = PluginManager.get_instance() MULTIMODAL_AUTO_MODEL_MAPPING = { @@ -139,7 +138,10 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig): and hasattr(model_config.vision_config, "image_size") ): cfg.image_size = model_config.vision_config.image_size - LOG.debug(f"Loaded image size: {cfg.image_size} from model config") + LOG.debug( + f"Loaded image size: {cfg.image_size} from model config", + main_process_only=True, + ) quant_config_exists = ( hasattr(model_config, "quantization_config") @@ -156,7 +158,8 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig): if model_config.quantization_config.get("config_groups"): LOG.warning( "Found `config_groups` in a compressed-tensors config. " - "QAT integration with llmcompressor is not tested." + "QAT integration with llmcompressor is not tested.", + main_process_only=True, ) # Skip further quant checks for compressed-tensors return @@ -457,10 +460,15 @@ def load_tokenizer(cfg): {"additional_special_tokens": additional_special_tokens} ) - log_debug_rank_zero(LOG, f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") - log_debug_rank_zero(LOG, f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") - log_debug_rank_zero(LOG, f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") - log_debug_rank_zero(LOG, f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") + LOG.debug( + f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}", main_process_only=True + ) + LOG.debug( + f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}", main_process_only=True + ) + LOG.debug( + f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}", main_process_only=True + ) if cfg.chat_template: chat_template_string = get_chat_template_from_config( @@ -517,7 +525,10 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): elif im_height is not None: cfg.image_size = im_height - LOG.debug(f"Loaded image size: {cfg.image_size} from processor") + LOG.debug( + f"Loaded image size: {cfg.image_size} from processor", + main_process_only=True, + ) return processor @@ -758,14 +769,20 @@ def patch_llama_derived_model(self): if self.cfg.sample_packing: if self.cfg.device not in ["mps", "cpu"] and not self.inference: - LOG.info("patching with flash attention for sample packing") + LOG.info( + "patching with flash attention for sample packing", + main_process_only=True, + ) replace_llama_attn_with_flash_attn( packed=True, cross_entropy=self.cfg.flash_attn_cross_entropy, rms_norm=self.cfg.flash_attn_rms_norm, ) elif self.cfg.s2_attention: - LOG.info("patching w/ flash-enabled, shifted-sparse attention") + LOG.info( + "patching w/ flash-enabled, shifted-sparse attention", + main_process_only=True, + ) replace_llama_attn_with_flash_attn( packed=False, cross_entropy=self.cfg.flash_attn_cross_entropy, @@ -783,14 +800,17 @@ def patch_llama_derived_model(self): hijack_llama_attention, ) - LOG.info("patching with xformers attention") + LOG.info("patching with xformers attention", main_process_only=True) hijack_llama_attention() elif self.cfg.sample_packing: from axolotl.monkeypatch.llama_patch_multipack import ( hijack_llama_prepare_4d_mask, ) - LOG.info("patching llama _prepare_4d_causal_attention_mask*") + LOG.info( + "patching llama _prepare_4d_causal_attention_mask*", + main_process_only=True, + ) hijack_llama_prepare_4d_mask() elif self.cfg.s2_attention: raise NotImplementedError( @@ -872,7 +892,8 @@ def set_quantization_config(self) -> None: if self.cfg.gptq: if not hasattr(self.model_config, "quantization_config"): LOG.warning( - "model config does not contain quantization_config information" + "model config does not contain quantization_config information", + main_process_only=True, ) else: if self.cfg.gptq_disable_exllama is not None: @@ -1084,11 +1105,11 @@ def _configure_zero3_memory_efficient_loading(): ) if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): - LOG.info("patching with SwiGLU") + LOG.info("patching with SwiGLU", main_process_only=True) replace_llama_mlp_with_swiglu(self.model) if self.cfg.flash_attn_fuse_qkv: - LOG.info("patching with fused QKV") + LOG.info("patching with fused QKV", main_process_only=True) replace_llama_qkv_with_fused(self.model) elif self.model_type == "MambaLMHeadModel": # FIXME this is janky at best and hacked together to make it work @@ -1161,7 +1182,8 @@ def adjust_model_config(self) -> None: and self.cfg.sequence_len > self.model.config.max_position_embeddings ): LOG.warning( - f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}" + f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}", + main_process_only=True, ) self.model.config.max_position_embeddings = self.cfg.sequence_len @@ -1225,7 +1247,10 @@ def prepare_model(self, qlora_fsdp: bool) -> None: and self.cfg.adapter in ["lora", "qlora"] and (self.cfg.load_in_8bit or self.cfg.load_in_4bit) ): - LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") + LOG.info( + "converting PEFT model w/ prepare_model_for_kbit_training", + main_process_only=True, + ) self.model = prepare_model_for_kbit_training( self.model, use_gradient_checkpointing=self.cfg.gradient_checkpointing ) @@ -1287,7 +1312,7 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: skip_move_to_device = self.build_model(qlora_fsdp) PLUGIN_MANAGER.post_model_build(self.cfg, self.model) except Exception as err: # pylint: disable=broad-exception-caught - LOG.exception(err) + LOG.exception(err, main_process_only=True) raise err if isinstance(self.model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp: @@ -1361,7 +1386,9 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: ) if should_convert: - LOG.info("Converting modules to %s", self.cfg.torch_dtype) + LOG.info( + "Converting modules to %s", self.cfg.torch_dtype, main_process_only=True + ) self.convert_embedding_modules_dtype( embedding_modules=embedding_modules, dist_dtype=self.cfg.torch_dtype, @@ -1414,7 +1441,10 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: if param.requires_grad: requires_grad.append(f"{name}: {param.requires_grad}") if len(requires_grad) == 0: - LOG.warning("there are no parameters that require gradient updates") + LOG.warning( + "there are no parameters that require gradient updates", + main_process_only=True, + ) if self.cfg.flash_optimum: from optimum.bettertransformer import BetterTransformer @@ -1488,7 +1518,7 @@ def load_llama_adapter(model, cfg): ) if cfg.lora_model_dir: - LOG.debug("Loading pretrained PEFT - llama_adapter") + LOG.debug("Loading pretrained PEFT - llama_adapter", main_process_only=True) model = PeftModel.from_pretrained( model, cfg.lora_model_dir, @@ -1555,7 +1585,10 @@ def load_lora(model, cfg, inference=False, config_only=False): if cfg.lora_target_linear: linear_names = find_all_linear_names(model) - LOG.info(f"found linear modules: {repr(sorted(linear_names))}") + LOG.info( + f"found linear modules: {repr(sorted(linear_names))}", + main_process_only=True, + ) lora_target_modules_as_list = ( lora_target_modules if isinstance(lora_target_modules, list) @@ -1572,7 +1605,10 @@ def load_lora(model, cfg, inference=False, config_only=False): lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights if cfg.peft_use_dora: lora_config_kwargs["use_dora"] = cfg.peft_use_dora - LOG.info("Initializing LoRA weights using dora. This might take longer.") + LOG.info( + "Initializing LoRA weights using dora. This might take longer.", + main_process_only=True, + ) if cfg.peft_use_rslora: lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora if cfg.peft_layer_replication: @@ -1606,7 +1642,7 @@ def load_lora(model, cfg, inference=False, config_only=False): setup_quantized_meta_for_peft(model) if cfg.lora_model_dir: - LOG.debug("Loading pretrained PEFT - LoRA") + LOG.debug("Loading pretrained PEFT - LoRA", main_process_only=True) model_kwargs: Any = {} if cfg.lora_on_cpu: model_kwargs["max_memory"] = {"cpu": "256GiB"} @@ -1625,7 +1661,9 @@ def load_lora(model, cfg, inference=False, config_only=False): model.print_trainable_parameters() except AttributeError as exc: LOG.warning( - "Exception caught during model.print_trainable_parameters(): %s", exc + "Exception caught during model.print_trainable_parameters(): %s", + exc, + main_process_only=True, ) elif ( cfg.fsdp From 014f4992df6a32e80b5e2ee79afc24d7e6fa1f50 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Fri, 16 May 2025 17:20:42 +0100 Subject: [PATCH 08/29] replacing more calls - testing on dist setup --- src/axolotl/common/datasets.py | 12 ------- src/axolotl/core/trainers/grpo/__init__.py | 14 +++----- .../core/trainers/mixins/sequence_parallel.py | 7 +--- src/axolotl/utils/data/rl.py | 23 ------------- src/axolotl/utils/data/sft.py | 19 ----------- src/axolotl/utils/logging.py | 18 ++-------- src/axolotl/utils/samplers/multipack.py | 34 ------------------- src/axolotl/utils/schemas/config.py | 15 -------- .../prompt_strategies/test_chat_templates.py | 6 ++-- .../test_chat_templates_advanced.py | 5 ++- .../test_chat_templates_thinking.py | 6 ++-- 11 files changed, 14 insertions(+), 145 deletions(-) diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index fe0c37eb1a..c6933aaf18 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -80,24 +80,12 @@ def load_datasets( preprocess_iterable=preprocess_iterable, ) -<<<<<<< Updated upstream - if ( # pylint: disable=too-many-boolean-expressions - cli_args - and ( - cli_args.debug - or cfg.debug - or cli_args.debug_text_only - or int(cli_args.debug_num_examples) > 0 - ) - ) or debug: -======= if cli_args and ( cli_args.debug or cfg.debug or cli_args.debug_text_only or int(cli_args.debug_num_examples) > 0 ): ->>>>>>> Stashed changes LOG.info("check_dataset_labels...") num_examples = cli_args.debug_num_examples if cli_args else 1 diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 73c96543ae..d5c39f8f42 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -2,12 +2,8 @@ import importlib import inspect -<<<<<<< Updated upstream -import logging from typing import Any -======= from axolotl.utils.logging import get_logger ->>>>>>> Stashed changes from trl.trainer.grpo_trainer import RewardFunc @@ -19,11 +15,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.schemas.trl import TRLConfig -<<<<<<< Updated upstream -LOG = logging.getLogger(__name__) -======= LOG = get_logger(__name__) ->>>>>>> Stashed changes class GRPOStrategy: @@ -52,8 +44,10 @@ def set_training_args_kwargs(cls, cfg: DictDefault) -> dict[str, Any]: if trl.use_vllm: grpo_args_kwargs["use_vllm"] = trl.use_vllm - grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host # type: ignore[attr-defined] - grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port # type: ignore[attr-defined] + # type: ignore[attr-defined] + grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host + # type: ignore[attr-defined] + grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port if trl.vllm_server_timeout: grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout if trl.vllm_guided_decoding_regex: diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py index 0f2a8859c4..0e63f7bfc2 100644 --- a/src/axolotl/core/trainers/mixins/sequence_parallel.py +++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py @@ -1,22 +1,18 @@ """Module for Axolotl trainer sequence parallelism mixin""" -<<<<<<< Updated upstream -======= import functools from axolotl.utils.logging import get_logger import torch ->>>>>>> Stashed changes import torch.distributed as dist from datasets import Dataset from torch.utils.data import DistributedSampler, Sampler +from axolotl.utils.schemas.enums import RingAttnFunc from axolotl.monkeypatch.attention.ring_attn import ( get_ring_attn_group, ) -<<<<<<< Updated upstream -======= LOG = get_logger(__name__) @@ -79,7 +75,6 @@ def apply_sequence_parallelism( return batch ->>>>>>> Stashed changes class SequenceParallelMixin: """ diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index cff6437089..96d973c154 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -20,11 +20,7 @@ from axolotl.utils.models import load_tokenizer from axolotl.utils.schemas.enums import RLType -<<<<<<< Updated upstream -LOG = logging.getLogger(__name__) -======= LOG = get_logger(__name__) ->>>>>>> Stashed changes def _get_path(ds_hash, cfg): @@ -214,24 +210,6 @@ def load_split(dataset_cfgs, _cfg): # ensure we end up with the same fingerprint by doing rank0 first and being able to cache to_hash_train = ( -<<<<<<< Updated upstream - train_dataset._fingerprint # pylint: disable=protected-access - + "|" - + str(cfg.val_set_size) - + "|" - + "train" - + "|" - + str(seed) - ) - to_hash_test = ( - train_dataset._fingerprint # pylint: disable=protected-access - + "|" - + str(cfg.val_set_size) - + "|" - + "test" - + "|" - + str(seed) -======= train_dataset._fingerprint + # pylint: disable=protected-access "|" + str(cfg.val_set_size) + @@ -248,7 +226,6 @@ def load_split(dataset_cfgs, _cfg): "test" + "|" + str(cfg.seed or 42) ->>>>>>> Stashed changes ) train_fingerprint = md5(to_hash_train) test_fingerprint = md5(to_hash_test) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index f20a4f1f1d..b835b1f1c4 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -463,24 +463,6 @@ def load_prepare_datasets( # ensure we end up with the same fingerprint by doing rank0 first and being able to cache to_hash_train = ( -<<<<<<< Updated upstream - dataset._fingerprint # pylint: disable=protected-access - + "|" - + str(val_set_size) - + "|" - + "train" - + "|" - + str(seed) - ) - to_hash_test = ( - dataset._fingerprint # pylint: disable=protected-access - + "|" - + str(val_set_size) - + "|" - + "test" - + "|" - + str(seed) -======= dataset._fingerprint + # pylint: disable=protected-access "|" + str(val_set_size) + @@ -497,7 +479,6 @@ def load_prepare_datasets( "test" + "|" + str(cfg.seed or 42) ->>>>>>> Stashed changes ) train_fingerprint = md5(to_hash_train) test_fingerprint = md5(to_hash_test) diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index 58234547d5..55d616712e 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -2,23 +2,9 @@ logging helpers to only log on main process """ -from axolotl.utils.logging import get_logger import os -from functools import partial - from axolotl.utils.distributed import is_main_process - - -def log_rank_zero(log: logging.Logger, message: str, level: str = "info"): - if is_main_process(use_environ=True): - getattr(log, level.lower())(message) - - -log_info_rank_zero = partial(log_rank_zero, level="info") -log_debug_rank_zero = partial(log_rank_zero, level="debug") -log_warning_rank_zero = partial(log_rank_zero, level="warning") -log_error_rank_zero = partial(log_rank_zero, level="error") - +import logging # Adapted from Accelerate # https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py @@ -46,7 +32,7 @@ def log(self, level, msg, *args, **kwargs): def get_logger(name: str, log_level: str | None = None): if log_level is None: log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None) - logger = get_logger(name) + logger = logging.getLogger(name) if log_level is not None: logger.setLevel(log_level.upper()) logger.root.setLevel(log_level.upper()) diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 833940ad51..88b1ad9fa5 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -2,12 +2,8 @@ Multipack Batch Sampler - An efficient batch sampler for packing variable-length sequences into fixed-capacity batches to optimize memory usage and training throughput. """ -<<<<<<< Updated upstream - import logging -======= from axolotl.utils.logging import get_logger ->>>>>>> Stashed changes import math from concurrent.futures import ProcessPoolExecutor from multiprocessing import cpu_count, get_context @@ -19,14 +15,8 @@ from axolotl.utils.distributed import reduce_and_broadcast -<<<<<<< Updated upstream -LOG = logging.getLogger(__name__) -======= LOG = get_logger(__name__) ->>>>>>> Stashed changes -LOG.setLevel(logging.INFO) - @numba.njit def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int): @@ -164,23 +154,6 @@ def pack_parallel( max_bins = len(group_lengths) # Allow as many bins as items in the group tasks.append((group_lengths, i, bin_capacity, max_bins, bin_size, safe_mode)) -<<<<<<< Updated upstream - # Process groups in parallel - all_bins = [] - - mp_ctx = None - if mp_start_method: - try: - mp_ctx = get_context(mp_start_method) - except ValueError: - LOG.warning( - f"Failed to get multiprocessing context '{mp_start_method}'. " - f"Falling back to default. Available: {get_context().get_all_start_methods()}" - ) - mp_ctx = ( - None # Fallback to default context if specified one is not available - ) -======= while right - left > 1: mid = (left + right) // 2 if ffd_check(lengths[start_index: start_index + mid], c, n): @@ -195,7 +168,6 @@ def pack_parallel( assert len(batch) <= n if len(batch) < n: break ->>>>>>> Stashed changes if num_processes == 1: LOG.debug("Using single process for pack_parallel, running sequentially.") @@ -343,18 +315,12 @@ def generate_batches(self, set_stats=False): if self._batches is not None: return self._batches -<<<<<<< Updated upstream - # Get indices from the sampler - indices = [ # pylint: disable=unnecessary-comprehension - idx for idx in self.sampler -======= batches = [ [ [indices[b_idx] for b_idx in batch] for batch in batches[i: i + self.batch_size] ] for i in range(0, len(batches), self.batch_size) ->>>>>>> Stashed changes ] # Get lengths of the selected sequences diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 683a900d9e..b1002dd6dd 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -459,18 +459,10 @@ def check_chat_template_config(cls, data): @classmethod def check_sample_packing_wo_flash(cls, data): if ( -<<<<<<< Updated upstream - data.get("sample_packing") - and not data.get("flash_attention") - and not data.get("sdp_attention") - and not data.get("flex_attention") - and not data.get("xformers_attention") -======= data.get("sample_packing") and not data.get("flash_attention") and not data.get("sdp_attention") and not data.get("flex_attention") ->>>>>>> Stashed changes ): LOG.warning( "sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination." @@ -1173,17 +1165,10 @@ def check_kto_config(cls, data): @classmethod def check_grpo_liger_sequence_parallel(cls, data): if ( -<<<<<<< Updated upstream - data.get("rl") == "grpo" - and data.get("trl", {}) - and data.get("trl").get("use_liger_loss") - and data.get("sequence_parallel_degree", 1) > 1 -======= data.get("rl") == "grpo" and data.get("trl", {}) and data.get("trl").get("use_liger_loss") and data.get("adapter") ->>>>>>> Stashed changes ): raise ValueError("GRPO + SP + Liger not currently supported") return data diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index 68772b56b3..667320b5fa 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -2,7 +2,6 @@ tests for chat_template prompt strategy """ -import logging import unittest from axolotl.prompt_strategies.chat_template import ( @@ -14,8 +13,9 @@ from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.dict import DictDefault -logging.basicConfig(level=logging.DEBUG) -LOG = logging.getLogger("axolotl") +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) class TestAssistantChatTemplateLlama3: diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py index 38a5b6c432..7f011f9543 100644 --- a/tests/prompt_strategies/test_chat_templates_advanced.py +++ b/tests/prompt_strategies/test_chat_templates_advanced.py @@ -4,7 +4,6 @@ # pylint: disable=too-many-lines -import logging from copy import deepcopy import pytest @@ -18,11 +17,11 @@ ) from axolotl.prompters import IGNORE_TOKEN_ID from axolotl.utils.chat_templates import get_chat_template +from axolotl.utils.logging import get_logger from tests.hf_offline_utils import enable_hf_offline -logging.basicConfig(level=logging.DEBUG) -LOG = logging.getLogger("axolotl") +LOG = get_logger(__name__) PARAMETRIZE_KEYS = "tokenizer, chat_template, chat_template_jinja, eos_token" PARAMETRIZE_PARAMS = [ diff --git a/tests/prompt_strategies/test_chat_templates_thinking.py b/tests/prompt_strategies/test_chat_templates_thinking.py index 9fe292317d..51495bdb1f 100644 --- a/tests/prompt_strategies/test_chat_templates_thinking.py +++ b/tests/prompt_strategies/test_chat_templates_thinking.py @@ -2,8 +2,6 @@ Tests for splitting reasoning/thinking from content into separate field """ -import logging - import pytest from datasets import Dataset from transformers import AutoTokenizer @@ -14,9 +12,9 @@ from axolotl.utils.dict import DictDefault from tests.hf_offline_utils import enable_hf_offline +from axolotl.utils.logging import get_logger -logging.basicConfig(level=logging.DEBUG) -LOG = logging.getLogger("axolotl") +LOG = get_logger(__name__) @pytest.fixture(name="messages_w_reasoning") From d0a30f1cb2e955abd87759c7a215213606e68507 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Mon, 19 May 2025 16:43:20 +0000 Subject: [PATCH 09/29] updating --- examples/llama-3/qlora-1b-kto.yaml | 3 +- src/axolotl/cli/main.py | 6 +- src/axolotl/core/chat/messages.py | 6 +- src/axolotl/integrations/base.py | 11 +- src/axolotl/integrations/liger/__init__.py | 8 +- src/axolotl/integrations/spectrum/__init__.py | 10 +- .../monkeypatch/llama_attn_hijack_xformers.py | 7 +- src/axolotl/monkeypatch/lora_kernels.py | 23 ++-- src/axolotl/monkeypatch/peft/utils.py | 4 +- .../monkeypatch/stablelm_attn_hijack_flash.py | 8 +- src/axolotl/monkeypatch/unsloth_.py | 14 +- .../bradley_terry/chat_template.py | 4 +- .../prompt_strategies/chat_template.py | 50 +++---- src/axolotl/prompt_strategies/llama2_chat.py | 8 +- src/axolotl/train.py | 3 +- src/axolotl/utils/data/sft.py | 128 ++++++++---------- src/axolotl/utils/data/utils.py | 4 - src/axolotl/utils/distributed.py | 5 +- .../gradient_checkpointing/offload_disk.py | 5 +- src/axolotl/utils/models.py | 17 +-- src/axolotl/utils/samplers/multipack.py | 57 +++----- src/axolotl/utils/trainer.py | 63 +++++---- tests/e2e/multigpu/solo/test_flex.py | 4 +- tests/e2e/multigpu/test_eval.py | 4 +- tests/e2e/multigpu/test_gemma3.py | 4 +- tests/e2e/multigpu/test_llama.py | 4 +- tests/e2e/multigpu/test_qwen2.py | 4 +- tests/e2e/multigpu/test_ray.py | 4 +- tests/e2e/patched/test_4d_multipack_llama.py | 4 +- tests/e2e/patched/test_fa_xentropy.py | 4 +- tests/e2e/patched/test_falcon_samplepack.py | 4 +- tests/e2e/patched/test_fused_llama.py | 4 +- tests/e2e/patched/test_llama_s2_attention.py | 4 +- .../e2e/patched/test_lora_llama_multipack.py | 4 +- tests/e2e/patched/test_mistral_samplepack.py | 4 +- tests/e2e/patched/test_mixtral_samplepack.py | 4 +- tests/e2e/patched/test_phi_multipack.py | 4 +- tests/e2e/patched/test_resume.py | 4 +- tests/e2e/patched/test_unsloth_qlora.py | 4 +- tests/e2e/solo/test_flex.py | 4 +- tests/e2e/solo/test_relora_llama.py | 4 +- tests/e2e/test_deepseekv3.py | 4 +- tests/e2e/test_dpo.py | 4 +- tests/e2e/test_embeddings_lr.py | 4 +- tests/e2e/test_falcon.py | 4 +- tests/e2e/test_gemma2.py | 4 +- tests/e2e/test_gemma3_text.py | 4 +- tests/e2e/test_llama.py | 4 +- tests/e2e/test_llama_pretrain.py | 4 +- tests/e2e/test_llama_vision.py | 4 +- tests/e2e/test_lora_llama.py | 4 +- tests/e2e/test_mamba.py | 4 +- tests/e2e/test_mistral.py | 4 +- tests/e2e/test_mixtral.py | 4 +- tests/e2e/test_optimizers.py | 4 +- tests/e2e/test_packing_loss.py | 4 +- tests/e2e/test_phi.py | 4 +- .../e2e/test_process_reward_model_smollm2.py | 4 +- tests/e2e/test_qwen.py | 4 +- tests/e2e/test_reward_model_smollm2.py | 4 +- tests/e2e/test_schedulers.py | 4 +- tests/integrations/test_liger.py | 13 +- tests/patched/test_validation.py | 38 +++--- tests/prompt_strategies/messages/test_chat.py | 21 ++- .../test_jinja_template_analyzer.py | 23 ++-- tests/test_prompt_tokenizers.py | 43 +++--- update_logging.py | 106 +++++++++++++++ 67 files changed, 462 insertions(+), 382 deletions(-) create mode 100644 update_logging.py diff --git a/examples/llama-3/qlora-1b-kto.yaml b/examples/llama-3/qlora-1b-kto.yaml index 89a51ea68f..aa52a62ef2 100644 --- a/examples/llama-3/qlora-1b-kto.yaml +++ b/examples/llama-3/qlora-1b-kto.yaml @@ -40,7 +40,8 @@ wandb_log_model: gradient_accumulation_steps: 1 micro_batch_size: 2 -num_epochs: 1 +# num_epochs: 1 +max_steps: 2 optimizer: adamw_8bit lr_scheduler: cosine learning_rate: 0.0002 diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index cafb2bee97..f0690ffbea 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -2,7 +2,6 @@ # pylint: disable=redefined-outer-name -from axolotl.utils.logging import get_logger import os import subprocess # nosec B404 import tempfile @@ -30,8 +29,11 @@ ) from axolotl.integrations.lm_eval.cli import lm_eval from axolotl.utils import patch_optimized_env +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.config import AxolotlInputConfig +LOG = get_logger(__name__) + @click.group() @click.version_option(version=axolotl.__version__, prog_name="axolotl") @@ -176,7 +178,7 @@ def iter_configs(): do_cli(config=cfg_file, **kwargs) except subprocess.CalledProcessError as exc: - logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}") + LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}") if not sweep: raise exc diff --git a/src/axolotl/core/chat/messages.py b/src/axolotl/core/chat/messages.py index 88ff2b7ad0..655e4ce93c 100644 --- a/src/axolotl/core/chat/messages.py +++ b/src/axolotl/core/chat/messages.py @@ -9,6 +9,10 @@ from pydantic import BaseModel from transformers import PreTrainedTokenizer +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + class MessageRoles(str, Enum): """ @@ -156,7 +160,7 @@ def tokenized( len(input_ids) : len(input_ids) + len(pending_input_ids) ] if new_pending_inputs != pending_input_ids: - # logging.warning("tokenization mismatch from concatenation.") + # LOG.warning("tokenization mismatch from concatenation.") pending_input_ids = new_pending_inputs input_ids.extend(pending_input_ids) if pending_weight: diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index fa0ad99f14..1427244226 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -20,13 +20,16 @@ """ import collections import importlib -from axolotl.utils.logging import get_logger from typing import OrderedDict import torch from torch.optim.lr_scheduler import LRScheduler +from transformers.trainer_utils import SchedulerType from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) class BasePlugin: @@ -345,12 +348,12 @@ def register(self, plugin_name: str): ImportError: If the plugin module cannot be imported. """ try: - logging.info(f"Attempting to load plugin: {plugin_name}") + LOG.info(f"Attempting to load plugin: {plugin_name}") plugin = load_plugin(plugin_name) self.plugins[plugin_name] = plugin - logging.info(f"Plugin loaded successfully: {plugin_name}") + LOG.info(f"Plugin loaded successfully: {plugin_name}") except ImportError: - logging.error(f"Failed to load plugin: {plugin_name}") + LOG.error(f"Failed to load plugin: {plugin_name}") def get_input_args(self): """ diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 4e0addcbe0..a5c4588693 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -19,11 +19,11 @@ It is designed to be performant, correct, and light-weight. """ import inspect -from axolotl.utils.logging import get_logger import sys from axolotl.integrations.base import BasePlugin from axolotl.utils.distributed import is_main_process +from axolotl.utils.logging import get_logger from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 from .utils import patch_with_compile_disable @@ -124,9 +124,9 @@ def pre_model_load(self, cfg): if cfg.liger_rope: # The DeepseekV2 version of RoPE is different than upstream LLaMA. # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528 - logging.warning("Fused liger_rope is not supported for DeepseekV2.") + LOG.warning("Fused liger_rope is not supported for DeepseekV2.") if cfg.liger_glu_activation: - logging.warning("liger_glu_activation is not supported for DeepseekV2.") + LOG.warning("liger_glu_activation is not supported for DeepseekV2.") if cfg.liger_rms_norm: modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm if cfg.liger_glu_activation: @@ -176,6 +176,6 @@ def pre_model_load(self, cfg): layer_norm=cfg.liger_layer_norm, ) else: - logging.warning( + LOG.warning( f"Unsupported model config type: {cfg.model_config_type}. Liger not applied." ) diff --git a/src/axolotl/integrations/spectrum/__init__.py b/src/axolotl/integrations/spectrum/__init__.py index 91c943ec78..9f66aef97f 100644 --- a/src/axolotl/integrations/spectrum/__init__.py +++ b/src/axolotl/integrations/spectrum/__init__.py @@ -17,14 +17,16 @@ """ import json -from axolotl.utils.logging import get_logger import requests from axolotl.integrations.base import BasePlugin +from axolotl.utils.logging import get_logger from .args import SpectrumArgs # pylint: disable=unused-import. # noqa: F401 +LOG = get_logger(__name__) + def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5): unfrozen_parameters = {} @@ -83,17 +85,17 @@ def pre_model_load(self, cfg): except FileNotFoundError: pass except Exception as exc: # pylint: disable=broad-exception-caught - logging.warning(f"Failed to read SNR data from {snr_path}: {exc}") + LOG.warning(f"Failed to read SNR data from {snr_path}: {exc}") if not snr_data: try: snr_data = requests.get(snr_url, timeout=60).json() except requests.exceptions.RequestException as exc: - logging.warning(f"Failed to fetch SNR data from {snr_url}: {exc}") + LOG.warning(f"Failed to fetch SNR data from {snr_url}: {exc}") return # also catch json parsing errors except json.JSONDecodeError as exc: - logging.warning(f"Failed to parse SNR data from {snr_url}: {exc}") + LOG.warning(f"Failed to parse SNR data from {snr_url}: {exc}") return unfrozen_parameters = _generate_unfrozen_params_yaml( diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index 5782814529..28223eee36 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -2,7 +2,6 @@ Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments """ -from axolotl.utils.logging import get_logger import warnings from typing import Optional, Tuple @@ -11,10 +10,14 @@ import transformers.models.llama.modeling_llama from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + try: import xformers.ops except ImportError: - logging.error("xformers not found! Please install it before trying to use it.") + LOG.error("xformers not found! Please install it before trying to use it.") def hijack_llama_attention(): diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index be127a6556..34797a759f 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -2,12 +2,10 @@ import importlib import inspect -from axolotl.utils.logging import get_logger import types from typing import Generator, Tuple, Type import torch -from accelerate.logging import get_logger from peft import PeftModelForCausalLM from torch import nn from transformers import AutoConfig @@ -20,6 +18,7 @@ ) from axolotl.monkeypatch.utils import detab_code from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) @@ -318,7 +317,7 @@ def apply_lora_kernel_patches( # This needs to be reset after patching original_level = LOG.getEffectiveLevel() - LOG.setLevel(logging.INFO) + LOG.setLevel("INFO") # Choose activation based on model type activation = None @@ -366,9 +365,9 @@ def apply_lora_kernel_patches( for linear_proj in ["q_proj", "k_proj", "v_proj"] ] can_patch_qkv = all( - hasattr(module, "lora_A") and - getattr(module, "base_layer", module).bias is None and - len(getattr(module, "lora_magnitude_vector", []) or []) == 0 + hasattr(module, "lora_A") + and getattr(module, "base_layer", module).bias is None + and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 for module in layer_modules ) @@ -385,9 +384,9 @@ def apply_lora_kernel_patches( getattr(self_attn, linear_proj) for linear_proj in ["o_proj"] ] can_patch_o = all( - hasattr(module, "lora_A") and - getattr(module, "base_layer", module).bias is None and - len(getattr(module, "lora_magnitude_vector", []) or []) == 0 + hasattr(module, "lora_A") + and getattr(module, "base_layer", module).bias is None + and len(getattr(module, "lora_magnitude_vector", []) or []) == 0 for module in layer_modules ) @@ -401,9 +400,9 @@ def apply_lora_kernel_patches( if cfg.lora_mlp_kernel: # MLP patching can_patch_mlp = all( - hasattr(proj, "lora_A") and - getattr(proj, "base_layer", proj).bias is None and - len(getattr(proj, "lora_magnitude_vector", []) or []) == 0 + hasattr(proj, "lora_A") + and getattr(proj, "base_layer", proj).bias is None + and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0 for proj in (gate_proj, up_proj, down_proj) ) diff --git a/src/axolotl/monkeypatch/peft/utils.py b/src/axolotl/monkeypatch/peft/utils.py index b3703d398f..5ea5a5e2c6 100644 --- a/src/axolotl/monkeypatch/peft/utils.py +++ b/src/axolotl/monkeypatch/peft/utils.py @@ -3,14 +3,14 @@ """ import inspect -import logging import peft import axolotl from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) ORIGINAL_PREPARE_CODE = """ for param in model.parameters(): diff --git a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py index 136c519915..85454fe2e3 100644 --- a/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/stablelm_attn_hijack_flash.py @@ -32,11 +32,11 @@ from torch import nn from transformers import AutoConfig, AutoModelForCausalLM from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.utils from axolotl.utils.logging import get_logger from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids +from axolotl.utils.logging import get_logger -logger = logging.get_logger(__name__) +logger = get_logger(__name__) def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"): @@ -121,9 +121,9 @@ def flashattn_attn( ).transpose(1, 2) query_rot = query_states[..., : self.rotary_ndims] - query_pass = query_states[..., self.rotary_ndims:] + query_pass = query_states[..., self.rotary_ndims :] key_rot = key_states[..., : self.rotary_ndims] - key_pass = key_states[..., self.rotary_ndims:] + key_pass = key_states[..., self.rotary_ndims :] kv_seq_len = key_states.shape[-2] if past_key_value is not None: diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py index f7d6abf6d4..3566be7253 100644 --- a/src/axolotl/monkeypatch/unsloth_.py +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -126,14 +126,16 @@ def patch_self_attn_lora(): items_to_import.append(item) exec( # pylint: disable=exec-used # nosec B102 - "from transformers.models.llama.modeling_llama import (" + - ", ".join(x for x in items_to_import) + - ")", + "from transformers.models.llama.modeling_llama import (" + + ", ".join(x for x in items_to_import) + + ")", globals(), ) exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 self_attn_lora_patched = True - LOG.info("patching unsloth attn lora", main_process_only=True) + LOG.info( + "patching unsloth attn lora", + ) LlamaFlashAttention2.forward = ( unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821 ) @@ -153,7 +155,9 @@ def apply_rotary_pos_emb( # pylint: disable=unused-argument ): return fast_rope_embedding(q, k, cos, sin) - LOG.info("patching unsloth RoPE embeddings", main_process_only=True) + LOG.info( + "patching unsloth RoPE embeddings", + ) transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py index 503fc4bb5b..1f693f1b4c 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py +++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py @@ -2,7 +2,6 @@ Bradley-Terry model with chat template prompt strategy. """ -from axolotl.utils.logging import get_logger from typing import Any, Dict, Optional from axolotl.prompt_strategies.chat_template import ( @@ -10,10 +9,11 @@ ChatTemplateStrategy, ) from axolotl.utils.chat_templates import get_chat_template_from_config +from axolotl.utils.logging import get_logger # Configure the logger LOG = get_logger(__name__) -LOG.setLevel(logging.INFO) +LOG.setLevel("INFO") class BTChatTemplateStrategy(ChatTemplateStrategy): diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 3d7a75fa01..59f1c759b3 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -2,7 +2,6 @@ HF Chat Templates prompt strategy """ -from axolotl.utils.logging import get_logger from collections import defaultdict from typing import Any, Dict, List, Set, Union @@ -13,11 +12,12 @@ from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompters import IGNORE_TOKEN_ID, Prompter from axolotl.utils.chat_templates import get_chat_template_from_config +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.datasets import DatasetConfig # Configure the logger LOG = get_logger(__name__) -LOG.setLevel(logging.INFO) +LOG.setLevel("INFO") class ChatTemplatePrompter(Prompter): @@ -179,8 +179,8 @@ def adjust_train_details( -1, ) if ( - end_token < len(token_offsets) - 1 and - token_offsets[end_token + 1][0] < end_offset + end_token < len(token_offsets) - 1 + and token_offsets[end_token + 1][0] < end_offset ): end_token += 1 @@ -277,8 +277,8 @@ def _validate_eot_and_eos_tokens(self): # Check if the eos_token is in the chat_template or as a variable `eos_token` # Note: we check for `eos_token` in the string, but it could possibly not be a variable if ( - self.tokenizer.eos_token not in self.prompter.chat_template and - "eos_token" not in self.prompter.chat_template + self.tokenizer.eos_token not in self.prompter.chat_template + and "eos_token" not in self.prompter.chat_template ): LOG.warning( f"EOS token '{self.tokenizer.eos_token}' not found in chat_template. Please check if your template/EOS token is correct." @@ -314,8 +314,8 @@ def _validate_eot_and_eos_tokens(self): # If eos_token is in eot_tokens and conflict between train_on_eos and train_on_eot, raise an error if ( - self.tokenizer.eos_token in self.eot_tokens and - self.train_on_eos != self.train_on_eot + self.tokenizer.eos_token in self.eot_tokens + and self.train_on_eos != self.train_on_eot ): raise ValueError( "Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot" @@ -365,11 +365,11 @@ def tokenize_prompt(self, prompt: dict[str, Any]): def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]: # Old simple legacy behavior that works reliably. if ( - not self.roles_to_train and - not self.train_on_eos and - not self.train_on_eot and - not self.prompter.message_field_training and # type: ignore - not self.prompter.message_field_training_detail # type: ignore + not self.roles_to_train + and not self.train_on_eos + and not self.train_on_eot + and not self.prompter.message_field_training # type: ignore + and not self.prompter.message_field_training_detail # type: ignore ): turns = self.get_conversation_thread(prompt) images = self.get_images(prompt) @@ -379,10 +379,11 @@ def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]: images=images, ) tokenized_res = self.prompter.build_prompt( - turns, images=images) # type: ignore + turns, images=images + ) # type: ignore tokenized_prompt = {} if isinstance(tokenized_res, list): - input_ids = prompt_ids + tokenized_res[len(prompt_ids):] + input_ids = prompt_ids + tokenized_res[len(prompt_ids) :] tokenized_prompt["input_ids"] = input_ids tokenized_prompt["attention_mask"] = [1] * len(input_ids) else: @@ -538,10 +539,11 @@ def find_turn(self, turns: list[dict], turn_idx: int): # mistral/gemma3 does not output message if it contains only system message if ( - turn_idx == 0 and - turns[0].get("role") == "system" and - ( - "mistral" in self.tokenizer.name_or_path.lower() or + turn_idx == 0 + and turns[0].get("role") == "system" + and ( + "mistral" in self.tokenizer.name_or_path.lower() + or # gemma3 uses gemma tokenizer "gemma" in self.tokenizer.name_or_path.lower() ) @@ -618,8 +620,8 @@ def get_conversation_thread(self, prompt): prompt[self.prompter.field_messages][0] ) if ( - possible_sys_turn["role"] != "system" and - self.prompter.field_system in prompt + possible_sys_turn["role"] != "system" + and self.prompter.field_system in prompt ): turn = {"role": "system", "content": prompt[self.prompter.field_system]} turns.append(turn) @@ -677,12 +679,12 @@ def transform_message(self, message): t_end_idx = content.find(tpair[1]) # get the thinking content - thinking_content = content[t_start_idx + len(tpair[0]): t_end_idx] + thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx] transformed_message["reasoning_content"] = thinking_content.strip() # take remainder of the content # strip whitespace from beginning of the remainder (thinking tokens) - remainder = content[t_end_idx + len(tpair[1]):].lstrip() + remainder = content[t_end_idx + len(tpair[1]) :].lstrip() # check if the content pair is in the remainder cpair_found = False @@ -694,7 +696,7 @@ def transform_message(self, message): # get the content content content_content = remainder[ - c_start_idx + len(cpair[0]): c_end_idx + c_start_idx + len(cpair[0]) : c_end_idx ] transformed_message["content"] = content_content.strip() cpair_found = True diff --git a/src/axolotl/prompt_strategies/llama2_chat.py b/src/axolotl/prompt_strategies/llama2_chat.py index d7b80d483d..eef2e1d4d3 100644 --- a/src/axolotl/prompt_strategies/llama2_chat.py +++ b/src/axolotl/prompt_strategies/llama2_chat.py @@ -24,12 +24,14 @@ Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing! """ -from axolotl.utils.logging import get_logger from dataclasses import dataclass, field from typing import Generator, List, Sequence from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) @dataclass @@ -121,7 +123,7 @@ def tokenize_prompt(self, prompt): instruction_len = len(self.tokenizer(parts[0]).input_ids) - 1 # Ignore the user instructions - target[cur_len - 1: cur_len + instruction_len] = IGNORE_TOKEN_ID + target[cur_len - 1 : cur_len + instruction_len] = IGNORE_TOKEN_ID cur_len += turn_len + 2 # due to length of role token target[cur_len:] = IGNORE_TOKEN_ID @@ -129,7 +131,7 @@ def tokenize_prompt(self, prompt): if cur_len < self.sequence_len: if cur_len != total_len: target[:] = IGNORE_TOKEN_ID - logging.warning( + LOG.warning( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)" ) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 9d80fde1b5..81d7596c07 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -2,7 +2,6 @@ import importlib import inspect -import logging import os import signal import sys @@ -42,7 +41,7 @@ except ImportError: BetterTransformer = None -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) def setup_model_and_tokenizer( diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index b835b1f1c4..8f490d4a73 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -1,7 +1,6 @@ """data handling specific to SFT""" import functools -from axolotl.utils.logging import get_logger import os import tempfile from pathlib import Path @@ -54,11 +53,7 @@ ) from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_local_main_process, zero_first -from axolotl.utils.logging import ( - log_debug_rank_zero, - log_info_rank_zero, - log_warning_rank_zero, -) +from axolotl.utils.logging import get_logger from axolotl.utils.trainer import ( calculate_total_num_steps, process_datasets_for_packing, @@ -126,9 +121,9 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): # when letting accelerator dispatch batches from the main process, we don't need to load the dataset from # other ranks, we just need to present a fake dataset if ( - cfg.accelerator_config and - cfg.accelerator_config.dispatch_batches and - not is_local_main_process() + cfg.accelerator_config + and cfg.accelerator_config.dispatch_batches + and not is_local_main_process() ): with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f: f.write("text\n") @@ -172,8 +167,7 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): ) if cfg.dataset_exact_deduplication: - log_info_rank_zero( - LOG, + LOG.info( "Deduplication not available for pretrained datasets", ) @@ -192,8 +186,7 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): ) else: total_num_steps = calculate_total_num_steps(cfg, train_dataset) - log_info_rank_zero( - LOG, + LOG.info( f"Maximum number of steps set at {total_num_steps}", ) return train_dataset, eval_dataset, total_num_steps, prompters @@ -213,25 +206,25 @@ def load_tokenized_prepared_datasets( ds_hash = str( md5( ( - str(cfg.sequence_len) + - "@" + - str(cfg.sample_packing) + - "@" + - str(cfg.eval_sample_packing) + - "@" + - str(cfg.group_by_length) + - "@" + - str(cfg.kd_temperature or 1.0) + - "|".join( + str(cfg.sequence_len) + + "@" + + str(cfg.sample_packing) + + "@" + + str(cfg.eval_sample_packing) + + "@" + + str(cfg.group_by_length) + + "@" + + str(cfg.kd_temperature or 1.0) + + "|".join( sorted( [ f"{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}" for d in cfg_datasets ] ) - ) + - "|" + - tokenizer_name + ) + + "|" + + tokenizer_name ) ) ) @@ -245,8 +238,7 @@ def load_tokenized_prepared_datasets( use_auth_token = cfg.hf_use_auth_token try: if cfg.push_dataset_to_hub: - log_info_rank_zero( - LOG, + LOG.info( f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})...", ) dataset = load_dataset( @@ -263,51 +255,43 @@ def load_tokenized_prepared_datasets( # This is for the case where we already loaded a pretokenized dataset from the hub ... elif ( - cfg.dataset_prepared_path and - any(prepared_ds_path.glob("*")) and - not cfg.is_preprocess and - not cfg.skip_prepare_dataset + cfg.dataset_prepared_path + and any(prepared_ds_path.glob("*")) + and not cfg.is_preprocess + and not cfg.skip_prepare_dataset ): - log_info_rank_zero( - LOG, + LOG.info( f"Loading prepared dataset from disk at {prepared_ds_path}...", ) dataset = load_from_disk(str(prepared_ds_path)) - log_info_rank_zero( - LOG, + LOG.info( "Prepared dataset loaded from disk...", ) else: if cfg.push_dataset_to_hub: - log_info_rank_zero( - LOG, + LOG.info( "Unable to find prepared dataset in Huggingface hub", ) if cfg.is_preprocess: - log_info_rank_zero( - LOG, + LOG.info( f"Skipping prepared dataset in {prepared_ds_path} for pre-processing...", ) else: - log_info_rank_zero( - LOG, + LOG.info( f"Unable to find prepared dataset in {prepared_ds_path}", ) - log_info_rank_zero( - LOG, + LOG.info( "Loading raw datasets...", ) if not cfg.is_preprocess: - log_warning_rank_zero( - LOG, + LOG.warning( "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset.", ) if cfg.seed: seed = cfg.seed else: - log_info_rank_zero( - LOG, + LOG.info( "No seed provided, using default seed of 42", ) seed = 42 @@ -362,22 +346,19 @@ def load_tokenized_prepared_datasets( if len(datasets) == 1: dataset = datasets[0] else: - log_info_rank_zero( - LOG, + LOG.info( "Merging datasets...", ) dataset = concatenate_datasets(datasets) if len(datasets) > 1: if cfg.shuffle_merged_datasets: - log_debug_rank_zero( - LOG, + LOG.debug( "Shuffling merged datasets...", ) dataset = dataset.shuffle(seed=seed) else: - log_debug_rank_zero( - LOG, + LOG.debug( "NOT shuffling merged datasets", ) @@ -388,8 +369,7 @@ def load_tokenized_prepared_datasets( dataset, _ = process_datasets_for_packing(cfg, dataset, None) if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: - log_info_rank_zero( - LOG, + LOG.info( f"Saving merged prepared dataset to disk... {prepared_ds_path}", ) if isinstance(dataset, IterableDataset): @@ -463,22 +443,22 @@ def load_prepare_datasets( # ensure we end up with the same fingerprint by doing rank0 first and being able to cache to_hash_train = ( - dataset._fingerprint + # pylint: disable=protected-access - "|" + - str(val_set_size) + - "|" + - "train" + - "|" + - str(cfg.seed or 42) + dataset._fingerprint # pylint: disable=protected-access + + "|" + + str(val_set_size) + + "|" + + "train" + + "|" + + str(cfg.seed or 42) ) to_hash_test = ( - dataset._fingerprint + # pylint: disable=protected-access - "|" + - str(val_set_size) + - "|" + - "test" + - "|" + - str(cfg.seed or 42) + dataset._fingerprint # pylint: disable=protected-access + + "|" + + str(val_set_size) + + "|" + + "test" + + "|" + + str(cfg.seed or 42) ) train_fingerprint = md5(to_hash_train) test_fingerprint = md5(to_hash_test) @@ -531,10 +511,10 @@ def get_dataset_wrapper( ) if ( - isinstance(dataset, Dataset) and - "input_ids" in dataset.features and - "attention_mask" in dataset.features and - "labels" in dataset.features + isinstance(dataset, Dataset) + and "input_ids" in dataset.features + and "attention_mask" in dataset.features + and "labels" in dataset.features ): # dataset is already tokenized, just drop it straight in dataset_prompter = UnsupportedPrompter() diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index 7bc375e5a2..6202d336d6 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -162,7 +162,6 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): if "input_ids" not in dataset.column_names: LOG.warning( "Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling.", - main_process_only=True, ) return dataset @@ -177,12 +176,10 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): min_input_len = np.min(ds_lengths) LOG.info( f"min_input_len: {min_input_len}", - main_process_only=True, ) max_input_len = np.max(ds_lengths) LOG.info( f"max_input_len: {max_input_len}", - main_process_only=True, ) except AttributeError: pass @@ -213,7 +210,6 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): if dropped: LOG.warning( f"Dropped {dropped} long samples from dataset", - main_process_only=True, ) return dataset diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py index 8c52102c89..adf6fa33e1 100644 --- a/src/axolotl/utils/distributed.py +++ b/src/axolotl/utils/distributed.py @@ -80,10 +80,11 @@ def is_main_process(use_environ=False): Returns: - bool: `True` if the current process is the main process, `False` otherwise. """ - if use_environ: - return os.environ.get("LOCAL_RANK", "0") == "0" if not is_distributed(): return True + if use_environ: + return os.environ.get("LOCAL_RANK", "0") == "0" + return dist.get_rank() == 0 diff --git a/src/axolotl/utils/gradient_checkpointing/offload_disk.py b/src/axolotl/utils/gradient_checkpointing/offload_disk.py index 90e70f504a..792d3c6efc 100644 --- a/src/axolotl/utils/gradient_checkpointing/offload_disk.py +++ b/src/axolotl/utils/gradient_checkpointing/offload_disk.py @@ -18,7 +18,6 @@ import atexit import concurrent.futures -import logging import os import queue import shutil @@ -32,11 +31,13 @@ import torch +from axolotl.utils.logging import get_logger + torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda") torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda") # Setup logger -logger = logging.getLogger(__name__) +logger = get_logger(__name__) class DiskOffloadManager: diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 3890e3f834..2035afc64b 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -140,7 +140,6 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig): cfg.image_size = model_config.vision_config.image_size LOG.debug( f"Loaded image size: {cfg.image_size} from model config", - main_process_only=True, ) quant_config_exists = ( @@ -159,7 +158,6 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig): LOG.warning( "Found `config_groups` in a compressed-tensors config. " "QAT integration with llmcompressor is not tested.", - main_process_only=True, ) # Skip further quant checks for compressed-tensors return @@ -527,7 +525,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): LOG.debug( f"Loaded image size: {cfg.image_size} from processor", - main_process_only=True, ) return processor @@ -771,7 +768,6 @@ def patch_llama_derived_model(self): if self.cfg.device not in ["mps", "cpu"] and not self.inference: LOG.info( "patching with flash attention for sample packing", - main_process_only=True, ) replace_llama_attn_with_flash_attn( packed=True, @@ -781,7 +777,6 @@ def patch_llama_derived_model(self): elif self.cfg.s2_attention: LOG.info( "patching w/ flash-enabled, shifted-sparse attention", - main_process_only=True, ) replace_llama_attn_with_flash_attn( packed=False, @@ -809,7 +804,6 @@ def patch_llama_derived_model(self): LOG.info( "patching llama _prepare_4d_causal_attention_mask*", - main_process_only=True, ) hijack_llama_prepare_4d_mask() elif self.cfg.s2_attention: @@ -893,7 +887,6 @@ def set_quantization_config(self) -> None: if not hasattr(self.model_config, "quantization_config"): LOG.warning( "model config does not contain quantization_config information", - main_process_only=True, ) else: if self.cfg.gptq_disable_exllama is not None: @@ -1183,7 +1176,6 @@ def adjust_model_config(self) -> None: ): LOG.warning( f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}", - main_process_only=True, ) self.model.config.max_position_embeddings = self.cfg.sequence_len @@ -1249,7 +1241,6 @@ def prepare_model(self, qlora_fsdp: bool) -> None: ): LOG.info( "converting PEFT model w/ prepare_model_for_kbit_training", - main_process_only=True, ) self.model = prepare_model_for_kbit_training( self.model, use_gradient_checkpointing=self.cfg.gradient_checkpointing @@ -1382,7 +1373,9 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: (needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp ) - or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass + or + # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass + self.cfg.cut_cross_entropy ) if should_convert: @@ -1443,7 +1436,6 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: if len(requires_grad) == 0: LOG.warning( "there are no parameters that require gradient updates", - main_process_only=True, ) if self.cfg.flash_optimum: @@ -1587,7 +1579,6 @@ def load_lora(model, cfg, inference=False, config_only=False): linear_names = find_all_linear_names(model) LOG.info( f"found linear modules: {repr(sorted(linear_names))}", - main_process_only=True, ) lora_target_modules_as_list = ( lora_target_modules @@ -1607,7 +1598,6 @@ def load_lora(model, cfg, inference=False, config_only=False): lora_config_kwargs["use_dora"] = cfg.peft_use_dora LOG.info( "Initializing LoRA weights using dora. This might take longer.", - main_process_only=True, ) if cfg.peft_use_rslora: lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora @@ -1663,7 +1653,6 @@ def load_lora(model, cfg, inference=False, config_only=False): LOG.warning( "Exception caught during model.print_trainable_parameters(): %s", exc, - main_process_only=True, ) elif ( cfg.fsdp diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 88b1ad9fa5..feccb19806 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -2,8 +2,8 @@ Multipack Batch Sampler - An efficient batch sampler for packing variable-length sequences into fixed-capacity batches to optimize memory usage and training throughput. """ + import logging -from axolotl.utils.logging import get_logger import math from concurrent.futures import ProcessPoolExecutor from multiprocessing import cpu_count, get_context @@ -14,8 +14,10 @@ from torch.utils.data import BatchSampler, Sampler, SequentialSampler from axolotl.utils.distributed import reduce_and_broadcast +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) +LOG.setLevel(logging.INFO) @numba.njit @@ -77,11 +79,15 @@ def pack_group( Returns: List of bins, where each bin contains indices of sequences assigned to it """ + # Get sorting indices and sort lengths in descending order + indices = np.argsort(sequence_lengths)[::-1] + sorted_lengths = sequence_lengths[indices] + bins_remaining_space: list = [] # Tracks remaining capacity in each bin bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin - for seq_id, size in enumerate(sequence_lengths): - global_idx = seq_id + group_offset + for seq_id, size in enumerate(sorted_lengths): + global_idx = indices[seq_id] + group_offset # Try to place sequence in existing bins add_new_bin = True @@ -125,7 +131,6 @@ def pack_parallel( bin_size: int, num_processes: int | None = None, safe_mode: bool = True, - mp_start_method: str | None = "spawn", ): """ Pack sequences into bins using parallel processing @@ -137,9 +142,7 @@ def pack_parallel( bin_size: Maximum number of bins to use num_processes: Number of parallel processes to use safe_mode: If True, use a more conservative packing approach - mp_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver'). - 'spawn' is often safer with Numba/PyTorch. - Set to None to use system default. + Returns: List of bins, where each bin contains indices of sequences assigned to it """ @@ -154,34 +157,11 @@ def pack_parallel( max_bins = len(group_lengths) # Allow as many bins as items in the group tasks.append((group_lengths, i, bin_capacity, max_bins, bin_size, safe_mode)) - while right - left > 1: - mid = (left + right) // 2 - if ffd_check(lengths[start_index: start_index + mid], c, n): - left = mid - else: - right = mid - - # use length l - batch = ffd_with_result( - lengths[start_index: start_index + left], c, start_index - ) - assert len(batch) <= n - if len(batch) < n: - break - - if num_processes == 1: - LOG.debug("Using single process for pack_parallel, running sequentially.") - for task_args in tasks: - group_bins = _process_group(task_args) + # Process groups in parallel + all_bins = [] + with ProcessPoolExecutor(max_workers=num_processes) as executor: + for group_bins in executor.map(_process_group, tasks): all_bins.extend(group_bins) - else: - # Use ProcessPoolExecutor only if num_processes > 1 - # Pass mp_context if available - with ProcessPoolExecutor( - max_workers=num_processes, mp_context=mp_ctx - ) as executor: - for group_bins in executor.map(_process_group, tasks): - all_bins.extend(group_bins) return all_bins @@ -315,12 +295,9 @@ def generate_batches(self, set_stats=False): if self._batches is not None: return self._batches - batches = [ - [ - [indices[b_idx] for b_idx in batch] - for batch in batches[i: i + self.batch_size] - ] - for i in range(0, len(batches), self.batch_size) + # Get indices from the sampler + indices = [ # pylint: disable=unnecessary-comprehension + idx for idx in self.sampler ] # Get lengths of the selected sequences diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 3d9fc90e14..6cfefe97f0 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -392,9 +392,9 @@ def process_pretraining_datasets_for_packing( def calculate_total_num_steps(cfg, train_dataset, update=True): if ( - not cfg.total_num_tokens and - not cfg.skip_prepare_dataset and - not cfg.reward_model + not cfg.total_num_tokens + and not cfg.skip_prepare_dataset + and not cfg.reward_model ): total_num_tokens = np.sum( train_dataset.select_columns("input_ids") @@ -402,17 +402,19 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): .apply(len) .values ) - LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True) + LOG.debug( + f"total_num_tokens: {total_num_tokens:_}", + ) if update: cfg.total_num_tokens = total_num_tokens skip_estimates = cfg.model_config_type == "mamba" if ( - not skip_estimates and - not cfg.total_supervised_tokens and - not cfg.skip_prepare_dataset and - not cfg.reward_model + not skip_estimates + and not cfg.total_supervised_tokens + and not cfg.skip_prepare_dataset + and not cfg.reward_model ): total_supervised_tokens = ( train_dataset.data.column("labels") @@ -422,7 +424,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): ) LOG.debug( f"`total_supervised_tokens: {total_supervised_tokens:_}`", - main_process_only=True, ) if update: cfg.total_supervised_tokens = total_supervised_tokens @@ -436,20 +437,19 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): # match count to len est in dataloader int( math.floor( - 0.99 * - cfg.total_num_tokens / - cfg.sample_packing_eff_est / - cfg.sequence_len // - cfg.batch_size - ) - - 1 - ) * - cfg.num_epochs * - cfg.sequence_parallel_degree + 0.99 + * cfg.total_num_tokens + / cfg.sample_packing_eff_est + / cfg.sequence_len + // cfg.batch_size + ) + - 1 + ) + * cfg.num_epochs + * cfg.sequence_parallel_degree ) LOG.debug( f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}", - main_process_only=True, ) else: if cfg.flash_attention and not cfg.multipack_real_batches: @@ -478,7 +478,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): batch_sampler=sampler, ) data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size - LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True) + LOG.debug( + f"data_loader_len: {data_loader_len}", + ) # FIXME: is there a bug here somewhere? the total num steps depends # on the agreed on value for sample_packing_eff_est total_num_steps = int( @@ -502,18 +504,19 @@ def calc_sample_packing_eff_est(estimates: List[float]): cfg.sample_packing_eff_est = sample_packing_eff_est LOG.debug( f"sample_packing_eff_est: {cfg.sample_packing_eff_est}", - main_process_only=True, ) else: total_num_steps = int( math.ceil( - len(train_dataset) * - cfg.num_epochs * - cfg.sequence_parallel_degree / - cfg.batch_size + len(train_dataset) + * cfg.num_epochs + * cfg.sequence_parallel_degree + / cfg.batch_size ) ) - LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True) + LOG.debug( + f"total_num_steps: {total_num_steps}", + ) return total_num_steps @@ -637,9 +640,9 @@ def setup_trainer( on the provided parameters. """ if ( - cfg.torch_compile and - cfg.fsdp_config and - str(cfg.fsdp_config.fsdp_version) == "2" + cfg.torch_compile + and cfg.fsdp_config + and str(cfg.fsdp_config.fsdp_version) == "2" ): patch_evaluation_loop_for_fsdp2() if cfg.rl: diff --git a/tests/e2e/multigpu/solo/test_flex.py b/tests/e2e/multigpu/solo/test_flex.py index 471b112c10..080ea4c97a 100644 --- a/tests/e2e/multigpu/solo/test_flex.py +++ b/tests/e2e/multigpu/solo/test_flex.py @@ -2,7 +2,6 @@ E2E tests for multigpu lora tinyllama """ -import logging import os from pathlib import Path @@ -14,10 +13,11 @@ from transformers.utils import is_torch_bf16_gpu_available from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from tests.e2e.utils import check_tensorboard, require_torch_2_6_0 -LOG = logging.getLogger("axolotl.tests.e2e.multigpu") +LOG = get_logger("axolotl.tests.e2e.multigpu") os.environ["WANDB_DISABLED"] = "true" AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent diff --git a/tests/e2e/multigpu/test_eval.py b/tests/e2e/multigpu/test_eval.py index 4989b81df7..45a961b7a4 100644 --- a/tests/e2e/multigpu/test_eval.py +++ b/tests/e2e/multigpu/test_eval.py @@ -2,7 +2,6 @@ E2E tests for multigpu eval """ -import logging import os from pathlib import Path @@ -11,10 +10,11 @@ from transformers.testing_utils import get_torch_dist_unique_port from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_tensorboard -LOG = logging.getLogger("axolotl.tests.e2e.multigpu") +LOG = get_logger("axolotl.tests.e2e.multigpu") os.environ["WANDB_DISABLED"] = "true" AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent diff --git a/tests/e2e/multigpu/test_gemma3.py b/tests/e2e/multigpu/test_gemma3.py index 9de3ed82f8..8540ec91fb 100644 --- a/tests/e2e/multigpu/test_gemma3.py +++ b/tests/e2e/multigpu/test_gemma3.py @@ -2,7 +2,6 @@ E2E tests for multigpu lora tinyllama """ -import logging import os from pathlib import Path @@ -13,10 +12,11 @@ from transformers.testing_utils import get_torch_dist_unique_port from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from tests.e2e.utils import check_tensorboard -LOG = logging.getLogger("axolotl.tests.e2e.multigpu") +LOG = get_logger("axolotl.tests.e2e.multigpu") os.environ["WANDB_DISABLED"] = "true" AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index 38e6e741a1..e383c54413 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -2,7 +2,6 @@ E2E tests for multigpu lora tinyllama """ -import logging import os from pathlib import Path @@ -15,10 +14,11 @@ from transformers.testing_utils import get_torch_dist_unique_port from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from tests.e2e.utils import check_tensorboard, require_torch_2_6_0 -LOG = logging.getLogger("axolotl.tests.e2e.multigpu") +LOG = get_logger("axolotl.tests.e2e.multigpu") os.environ["WANDB_DISABLED"] = "true" AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent diff --git a/tests/e2e/multigpu/test_qwen2.py b/tests/e2e/multigpu/test_qwen2.py index 9599c3abf8..23650b10dc 100644 --- a/tests/e2e/multigpu/test_qwen2.py +++ b/tests/e2e/multigpu/test_qwen2.py @@ -2,7 +2,6 @@ E2E tests for multigpu qwen2 """ -import logging import os from pathlib import Path @@ -12,8 +11,9 @@ from transformers.testing_utils import get_torch_dist_unique_port from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl.tests.e2e.multigpu") +LOG = get_logger("axolotl.tests.e2e.multigpu") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py index 843adac912..64c2d501ff 100644 --- a/tests/e2e/multigpu/test_ray.py +++ b/tests/e2e/multigpu/test_ray.py @@ -2,7 +2,6 @@ E2E tests for multigpu post-training use Ray Train """ -import logging import os from pathlib import Path @@ -11,10 +10,11 @@ from accelerate.test_utils import execute_subprocess_async from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0 -LOG = logging.getLogger(__name__) +LOG = get_logger(__name__) os.environ["WANDB_DISABLED"] = "true" AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py index 12dd51c134..27b2b2ca04 100644 --- a/tests/e2e/patched/test_4d_multipack_llama.py +++ b/tests/e2e/patched/test_4d_multipack_llama.py @@ -2,7 +2,6 @@ E2E tests for multipack fft llama using 4d attention masks """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py index f71e4fb4af..2581d39a6e 100644 --- a/tests/e2e/patched/test_fa_xentropy.py +++ b/tests/e2e/patched/test_fa_xentropy.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import pytest @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, check_tensorboard -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py index 667b62ffba..61689ca1fc 100644 --- a/tests/e2e/patched/test_falcon_samplepack.py +++ b/tests/e2e/patched/test_falcon_samplepack.py @@ -2,7 +2,6 @@ E2E tests for falcon """ -import logging import os import unittest @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py index 7725e095d1..20fd2acb53 100644 --- a/tests/e2e/patched/test_fused_llama.py +++ b/tests/e2e/patched/test_fused_llama.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -14,10 +13,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py index 3cf43ba9d3..3c81a274a7 100644 --- a/tests/e2e/patched/test_llama_s2_attention.py +++ b/tests/e2e/patched/test_llama_s2_attention.py @@ -2,7 +2,6 @@ E2E tests for llama w/ S2 attn """ -import logging import os import unittest @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py index ca989f241e..894742a7e8 100644 --- a/tests/e2e/patched/test_lora_llama_multipack.py +++ b/tests/e2e/patched/test_lora_llama_multipack.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -14,10 +13,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py index fe8fafb19d..5ae5a6dc5a 100644 --- a/tests/e2e/patched/test_mistral_samplepack.py +++ b/tests/e2e/patched/test_mistral_samplepack.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py index ebc2ba0927..38a5d6b658 100644 --- a/tests/e2e/patched/test_mixtral_samplepack.py +++ b/tests/e2e/patched/test_mixtral_samplepack.py @@ -2,7 +2,6 @@ E2E tests for mixtral """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py index d8130d1190..54cac15dcd 100644 --- a/tests/e2e/patched/test_phi_multipack.py +++ b/tests/e2e/patched/test_phi_multipack.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py index 61e4a0e03d..8ba6b7c540 100644 --- a/tests/e2e/patched/test_resume.py +++ b/tests/e2e/patched/test_resume.py @@ -2,7 +2,6 @@ E2E tests for resuming training """ -import logging import os import re import subprocess @@ -14,10 +13,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, most_recent_subdir, require_torch_2_6_0 -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py index 5f8fde6b4d..3b429279f5 100644 --- a/tests/e2e/patched/test_unsloth_qlora.py +++ b/tests/e2e/patched/test_unsloth_qlora.py @@ -2,7 +2,6 @@ e2e tests for unsloth qlora """ -import logging import os import pytest @@ -12,10 +11,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, check_tensorboard -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/solo/test_flex.py b/tests/e2e/solo/test_flex.py index 71da795f89..431afd55ba 100644 --- a/tests/e2e/solo/test_flex.py +++ b/tests/e2e/solo/test_flex.py @@ -2,7 +2,6 @@ E2E tests for packed training w/ flex attention """ -import logging import os import unittest @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_tensorboard, require_torch_2_6_0, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/solo/test_relora_llama.py b/tests/e2e/solo/test_relora_llama.py index 504466b90c..6e9f403d0e 100644 --- a/tests/e2e/solo/test_relora_llama.py +++ b/tests/e2e/solo/test_relora_llama.py @@ -2,7 +2,6 @@ E2E tests for relora llama """ -import logging import os import unittest from pathlib import Path @@ -12,10 +11,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_deepseekv3.py b/tests/e2e/test_deepseekv3.py index 2afda640f1..0a228aa052 100644 --- a/tests/e2e/test_deepseekv3.py +++ b/tests/e2e/test_deepseekv3.py @@ -2,7 +2,6 @@ E2E tests for deepseekv3 """ -import logging import os from pathlib import Path @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from tests.hf_offline_utils import enable_hf_offline -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index 84d723ec08..b039893849 100644 --- a/tests/e2e/test_dpo.py +++ b/tests/e2e/test_dpo.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest from pathlib import Path @@ -14,10 +13,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py index 82b822ad60..fe6a507449 100644 --- a/tests/e2e/test_embeddings_lr.py +++ b/tests/e2e/test_embeddings_lr.py @@ -2,7 +2,6 @@ E2E tests for llama pretrain """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, check_tensorboard, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py index 24afab0b3f..4f15867caf 100644 --- a/tests/e2e/test_falcon.py +++ b/tests/e2e/test_falcon.py @@ -2,7 +2,6 @@ E2E tests for falcon """ -import logging import os import unittest @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_gemma2.py b/tests/e2e/test_gemma2.py index 68dc4855d9..8b9b0d11d4 100644 --- a/tests/e2e/test_gemma2.py +++ b/tests/e2e/test_gemma2.py @@ -2,7 +2,6 @@ E2E tests for gemma2 """ -import logging import os from pathlib import Path @@ -13,8 +12,9 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_gemma3_text.py b/tests/e2e/test_gemma3_text.py index 5cbde04d10..9873de6279 100644 --- a/tests/e2e/test_gemma3_text.py +++ b/tests/e2e/test_gemma3_text.py @@ -2,7 +2,6 @@ E2E tests for gemma3_text """ -import logging import os from pathlib import Path @@ -13,8 +12,9 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py index d3e37fb3fc..352372e1ec 100644 --- a/tests/e2e/test_llama.py +++ b/tests/e2e/test_llama.py @@ -2,7 +2,6 @@ E2E tests for llama """ -import logging import os from axolotl.cli.args import TrainerCliArgs @@ -10,10 +9,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from tests.e2e.utils import check_model_output_exists -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py index 647285e464..9d0e4d7a6f 100644 --- a/tests/e2e/test_llama_pretrain.py +++ b/tests/e2e/test_llama_pretrain.py @@ -2,7 +2,6 @@ E2E tests for llama pretrain """ -import logging import os import pytest @@ -12,10 +11,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, check_tensorboard -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py index e1e496ccf8..890f275698 100644 --- a/tests/e2e/test_llama_vision.py +++ b/tests/e2e/test_llama_vision.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py index b02fe3d447..02d2868dac 100644 --- a/tests/e2e/test_lora_llama.py +++ b/tests/e2e/test_lora_llama.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py index f49b53987d..92397ab88f 100644 --- a/tests/e2e/test_mamba.py +++ b/tests/e2e/test_mamba.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py index ba8cf28962..ac57848435 100644 --- a/tests/e2e/test_mistral.py +++ b/tests/e2e/test_mistral.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py index 4e0693b949..329428473f 100644 --- a/tests/e2e/test_mixtral.py +++ b/tests/e2e/test_mixtral.py @@ -2,7 +2,6 @@ E2E tests for mixtral """ -import logging import os import unittest @@ -14,10 +13,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py index 91f45b762f..291ed3d6a1 100644 --- a/tests/e2e/test_optimizers.py +++ b/tests/e2e/test_optimizers.py @@ -2,7 +2,6 @@ E2E tests for custom optimizers using Llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py index 73716f44bb..52e27a2c17 100644 --- a/tests/e2e/test_packing_loss.py +++ b/tests/e2e/test_packing_loss.py @@ -2,7 +2,6 @@ E2E tests for packed training """ -import logging import os import unittest @@ -13,10 +12,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_tensorboard, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py index f531a17c50..349ae9efba 100644 --- a/tests/e2e/test_phi.py +++ b/tests/e2e/test_phi.py @@ -2,7 +2,6 @@ E2E tests for lora llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_process_reward_model_smollm2.py b/tests/e2e/test_process_reward_model_smollm2.py index 446facdb0d..0673409ab2 100644 --- a/tests/e2e/test_process_reward_model_smollm2.py +++ b/tests/e2e/test_process_reward_model_smollm2.py @@ -2,7 +2,6 @@ E2E tests for process reward model w/ lora llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, check_tensorboard, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_qwen.py b/tests/e2e/test_qwen.py index 39d55603f5..1f57c6ae18 100644 --- a/tests/e2e/test_qwen.py +++ b/tests/e2e/test_qwen.py @@ -2,7 +2,6 @@ E2E tests for qwen """ -import logging import os from pathlib import Path @@ -12,8 +11,9 @@ from transformers.testing_utils import get_torch_dist_unique_port from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl.tests.qwen") +LOG = get_logger("axolotl.tests.qwen") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_reward_model_smollm2.py b/tests/e2e/test_reward_model_smollm2.py index 240c4b3924..31938ea589 100644 --- a/tests/e2e/test_reward_model_smollm2.py +++ b/tests/e2e/test_reward_model_smollm2.py @@ -2,7 +2,6 @@ E2E tests for reward model lora llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, check_tensorboard, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/e2e/test_schedulers.py b/tests/e2e/test_schedulers.py index 694bb21e81..12783cfb7d 100644 --- a/tests/e2e/test_schedulers.py +++ b/tests/e2e/test_schedulers.py @@ -2,7 +2,6 @@ E2E tests for custom schedulers using Llama """ -import logging import os import unittest @@ -11,10 +10,11 @@ from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from .utils import check_model_output_exists, with_temp_dir -LOG = logging.getLogger("axolotl.tests.e2e") +LOG = get_logger("axolotl.tests.e2e") os.environ["WANDB_DISABLED"] = "true" diff --git a/tests/integrations/test_liger.py b/tests/integrations/test_liger.py index cbe1408b81..2d6abe311b 100644 --- a/tests/integrations/test_liger.py +++ b/tests/integrations/test_liger.py @@ -2,8 +2,6 @@ config validation tests for swiglu args """ -# pylint: disable=duplicate-code -import logging from typing import Optional import pytest @@ -11,6 +9,11 @@ from axolotl.utils.config import prepare_plugins, validate_config from axolotl.utils.dict import DictDefault +# pylint: disable=duplicate-code +from axolotl.utils.logging import get_logger + +LOG = get_logger("axolotl.integrations.test_liger") + @pytest.fixture(name="minimal_liger_cfg") def fixture_cfg(): @@ -41,7 +44,7 @@ class TestValidation: @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): - caplog.set_level(logging.WARNING) + caplog.set_level("WARNING") self._caplog = caplog def test_deprecated_swiglu(self, minimal_liger_cfg): @@ -52,9 +55,7 @@ def test_deprecated_swiglu(self, minimal_liger_cfg): | minimal_liger_cfg ) - with self._caplog.at_level( - logging.WARNING, logger="axolotl.integrations.liger.args" - ): + with self._caplog.at_level("WARNING", logger="axolotl.integrations.liger.args"): prepare_plugins(test_cfg) updated_cfg = validate_config(test_cfg) # TODO this test is brittle in CI diff --git a/tests/patched/test_validation.py b/tests/patched/test_validation.py index 683db61b28..ec75027940 100644 --- a/tests/patched/test_validation.py +++ b/tests/patched/test_validation.py @@ -1,7 +1,6 @@ # pylint: disable=too-many-lines """Module for testing the validation module""" -import logging import os import warnings from typing import Optional @@ -12,6 +11,7 @@ from axolotl.utils import is_comet_available from axolotl.utils.config import validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.models import check_model_config from axolotl.utils.schemas.config import AxolotlConfigWCapabilities @@ -19,6 +19,8 @@ warnings.filterwarnings("error") +LOG = get_logger(__name__) + @pytest.fixture(name="minimal_cfg") def fixture_cfg(): @@ -80,7 +82,7 @@ def test_zero3_qlora_use_reentrant_false(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(test_cfg) assert ( "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values" @@ -218,7 +220,7 @@ def test_batch_size_unused_warning(self): } ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert "batch_size is not recommended" in self._caplog.records[0].message @@ -513,7 +515,7 @@ def test_flash_optimum(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert any( "BetterTransformers probably doesn't work with PEFT adapters" @@ -531,7 +533,7 @@ def test_flash_optimum(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert any( "probably set bfloat16 or float16" in record.message @@ -577,7 +579,7 @@ def test_adamw_hyperparams(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert any( "adamw hyperparameters found, but no adamw optimizer set" @@ -595,7 +597,7 @@ def test_adamw_hyperparams(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert any( "adamw hyperparameters found, but no adamw optimizer set" @@ -654,7 +656,7 @@ def test_packing(self, minimal_cfg): ) | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert any( "`pad_to_sequence_len: true` is recommended when using sample_packing" @@ -673,7 +675,7 @@ def test_packing_autoset(self, minimal_cfg): ) | minimal_cfg ) - with self._caplog.at_level(logging.INFO): + with self._caplog.at_level("INFO"): cfg = validate_config(cfg) assert any( "Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing" @@ -1109,7 +1111,7 @@ def test_unfrozen_parameters_w_peft_layers_to_transform(self, minimal_cfg): def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg): cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert len(self._caplog.records) == 1 @@ -1118,7 +1120,7 @@ def test_hub_model_id_save_value_warns_random_value(self, minimal_cfg): DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert len(self._caplog.records) == 1 @@ -1128,7 +1130,7 @@ def test_hub_model_id_save_value_steps(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert len(self._caplog.records) == 0 @@ -1138,28 +1140,28 @@ def test_hub_model_id_save_value_epochs(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert len(self._caplog.records) == 0 def test_hub_model_id_save_value_none(self, minimal_cfg): cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert len(self._caplog.records) == 0 def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg): cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): validate_config(cfg) assert len(self._caplog.records) == 0 def test_dpo_beta_deprecation(self, minimal_cfg): cfg = DictDefault({"dpo_beta": 0.2}) | minimal_cfg - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): new_cfg = validate_config(cfg) assert new_cfg["rl_beta"] == 0.2 assert new_cfg["dpo_beta"] is None @@ -1175,7 +1177,7 @@ def test_eval_strategy_remap(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): new_cfg = validate_config(cfg) assert new_cfg.eval_strategy == "steps" assert ( @@ -1441,7 +1443,7 @@ def test_wandb_set_run_id_to_name(self, minimal_cfg): | minimal_cfg ) - with self._caplog.at_level(logging.WARNING): + with self._caplog.at_level("WARNING"): new_cfg = validate_config(cfg) assert any( "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead." diff --git a/tests/prompt_strategies/messages/test_chat.py b/tests/prompt_strategies/messages/test_chat.py index 2681bb7431..5d7a4f18b6 100644 --- a/tests/prompt_strategies/messages/test_chat.py +++ b/tests/prompt_strategies/messages/test_chat.py @@ -1,16 +1,23 @@ -""" -tests for chat_template prompt strategy -""" +"""Module for testing chat message internals.""" -# pylint: disable=duplicate-code -import logging +import os import unittest +from transformers import AutoTokenizer + +from axolotl.core.chat.messages import ( + ChatFormattedChats, + Chats, + MessageContents, + MessageContentTypes, + MessageRoles, + Messages, +) from axolotl.prompt_strategies.messages.chat import load from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger -logging.basicConfig(level=logging.DEBUG) -LOG = logging.getLogger("axolotl") +LOG = get_logger("axolotl") class TestMessagesChatLlama3: diff --git a/tests/prompt_strategies/test_jinja_template_analyzer.py b/tests/prompt_strategies/test_jinja_template_analyzer.py index f666c738c9..497e1f390a 100644 --- a/tests/prompt_strategies/test_jinja_template_analyzer.py +++ b/tests/prompt_strategies/test_jinja_template_analyzer.py @@ -1,15 +1,16 @@ -""" -tests for jinja_template_analyzer -""" +"""Module for testing jinja template analyzer.""" -import logging +import os import pytest -from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer +from axolotl.prompt_strategies.jinja_template_analyzer import ( + PromptComponentStatus, + PromptTemplateAnalyzer, +) +from axolotl.utils.logging import get_logger -logging.basicConfig(level=logging.DEBUG) -LOG = logging.getLogger("axolotl") +LOG = get_logger("axolotl") class TestJinjaTemplateAnalyzer: @@ -80,7 +81,7 @@ def test_nested_property_access(self): LOG.info("Testing nested property access") template = """{{ user.profile.name }}{{ user.settings['preference'] }}""" - analyzer = JinjaTemplateAnalyzer(template) + analyzer = PromptTemplateAnalyzer(template) variables = analyzer.get_template_variables() assert "user" in variables @@ -99,7 +100,7 @@ def test_loop_variable_handling(self): {% endfor %} {% endfor %} """ - analyzer = JinjaTemplateAnalyzer(template) + analyzer = PromptTemplateAnalyzer(template) analysis = analyzer.analyze_template() assert analysis["items"]["is_iterated"] @@ -115,7 +116,7 @@ def test_conditional_variable_usage(self): {{ debug_info }} {% endif %} """ - analyzer = JinjaTemplateAnalyzer(template) + analyzer = PromptTemplateAnalyzer(template) analysis = analyzer.analyze_template() assert analysis["user"]["is_conditional"] @@ -132,7 +133,7 @@ def test_complex_expressions(self): {{ messages | length > 0 and messages[0].content }} {{ data['key'].nested['value'] }} """ - analyzer = JinjaTemplateAnalyzer(template) + analyzer = PromptTemplateAnalyzer(template) variables = analyzer.get_template_variables() assert "user" in variables diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index 3f16bc9177..be2d734d09 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -1,26 +1,21 @@ -"""Module for testing prompt tokenizers.""" +"""Testing for prompt_tokenizers.py""" -import json -import logging -from pathlib import Path +import unittest -from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter -from axolotl.prompt_strategies.alpaca_w_system import ( - InstructionWSystemPromptTokenizingStrategy, - SystemDataPrompter, -) -from axolotl.prompt_strategies.llama2_chat import ( - Llama2ChatPrompter, - LLama2ChatTokenizingStrategy, -) -from axolotl.prompt_strategies.orpo.chat_template import load -from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy -from axolotl.prompters import AlpacaPrompter, PromptStyle -from axolotl.utils.dict import DictDefault +import pytest +from transformers import AutoTokenizer -from tests.hf_offline_utils import enable_hf_offline +from axolotl.prompt_strategies.alpaca import AlpacaPrompter +from axolotl.prompt_tokenizers import ( + AlpacaPromptTokenizingStrategy, + InstructionPromptTokenizingStrategy, + PromptTokenizingStrategy, + ShareGPTPromptTokenizingStrategy, +) +from axolotl.prompters import AlpacaInstructionPrompter, PromptStyle, ShareGPTPrompter +from axolotl.utils.logging import get_logger -LOG = logging.getLogger("axolotl") +LOG = get_logger("axolotl") test_data = { "multi_turn_sys": { @@ -61,7 +56,7 @@ class TestPromptTokenizationStrategies: Test class for prompt tokenization strategies. """ - @enable_hf_offline + @pytest.mark.enable_hf_offline def test_no_sys_prompt(self, tokenizer_huggyllama_w_special_tokens): """ tests the interface between the user and assistant parts @@ -83,7 +78,7 @@ def test_no_sys_prompt(self, tokenizer_huggyllama_w_special_tokens): assert example["labels"][world_idx] == 3186 assert example["labels"][world_idx - 1] == -100 - @enable_hf_offline + @pytest.mark.enable_hf_offline def test_alpaca(self, tokenizer_huggyllama_w_special_tokens): """ tests the interface between the user and assistant parts @@ -108,7 +103,7 @@ class TestInstructionWSystemPromptTokenizingStrategy: Test class for prompt tokenization strategies with sys prompt from the dataset """ - @enable_hf_offline + @pytest.mark.enable_hf_offline def test_system_alpaca(self, tokenizer_huggyllama_w_special_tokens): prompter = SystemDataPrompter(PromptStyle.CHAT.value) strat = InstructionWSystemPromptTokenizingStrategy( @@ -139,7 +134,7 @@ class Llama2ChatTokenizationTest: Test class for prompt tokenization strategies with sys prompt from the dataset """ - @enable_hf_offline + @pytest.mark.enable_hf_offline def test_llama2_chat_integration(self, tokenizer_llama2_7b): with open( Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8" @@ -213,7 +208,7 @@ def compare_with_transformers_integration(self, tokenizer_llama2_7b): class OrpoTokenizationTest: """test case for the ORPO tokenization""" - @enable_hf_offline + @pytest.mark.enable_hf_offline def test_orpo_integration( self, tokenizer_mistral_7b_instruct_chatml, diff --git a/update_logging.py b/update_logging.py new file mode 100644 index 0000000000..461ab9a849 --- /dev/null +++ b/update_logging.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +""" +Script to update all test files to use the standardized logging approach. +""" + +import os +import re +import sys +from pathlib import Path + + +def update_file(file_path, dry_run=False): + """Update a file to use the standardized logging approach.""" + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + + # Keep track of changes + changes_made = False + + # Replace the import if it's a standalone import + import_pattern = r"import\s+logging\s*\n" + if re.search(import_pattern, content): + new_content = re.sub( + import_pattern, "from axolotl.utils.logging import get_logger\n", content + ) + changes_made = new_content != content + content = new_content + + # Replace the logger initialization + logger_pattern = r'LOG\s*=\s*logging\.getLogger\([\'"]([^\'"]+)[\'"]\)' + if re.search(logger_pattern, content): + new_content = re.sub(logger_pattern, r'LOG = get_logger("\1")', content) + changes_made = changes_made or (new_content != content) + content = new_content + + # Remove logging.basicConfig if present + basicconfig_pattern = r"logging\.basicConfig\([^\)]+\)\s*\n" + if re.search(basicconfig_pattern, content): + new_content = re.sub(basicconfig_pattern, "", content) + changes_made = changes_made or (new_content != content) + content = new_content + + if changes_made and not dry_run: + with open(file_path, "w", encoding="utf-8") as f: + f.write(content) + + return changes_made + + +def find_and_update_files(base_dir, dry_run=False): + """Find and update all test files that use logging.""" + updated_files = [] + skipped_files = [] + + for root, _, files in os.walk(base_dir): + for file in files: + if file.endswith(".py"): + file_path = os.path.join(root, file) + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + + if ( + "import logging" in content + or "logging.getLogger" in content + or "logging.basicConfig" in content + ): + if "from axolotl.utils.logging import get_logger" in content: + if "import logging" in content: + # Both imports present, probably needs manual inspection + skipped_files.append(file_path) + else: + # Already using the standardized logger + pass + else: + if update_file(file_path, dry_run): + updated_files.append(file_path) + + return updated_files, skipped_files + + +if __name__ == "__main__": + dry_run = "--dry-run" in sys.argv + if dry_run: + sys.argv.remove("--dry-run") + + if len(sys.argv) > 1: + base_dir = sys.argv[1] + else: + base_dir = "tests" + + updated_files, skipped_files = find_and_update_files(base_dir, dry_run) + + if dry_run: + print(f"DRY RUN: Would update {len(updated_files)} files:") + else: + print(f"Updated {len(updated_files)} files:") + + for file in updated_files: + rel_path = os.path.relpath(file, os.getcwd()) + print(f" - {rel_path}") + + if skipped_files: + print(f"\nSkipped {len(skipped_files)} files (need manual inspection):") + for file in skipped_files: + rel_path = os.path.relpath(file, os.getcwd()) + print(f" - {rel_path}") From 07ae9958629bd41b6426cafd7831386fa3365fde Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Mon, 19 May 2025 17:09:41 +0000 Subject: [PATCH 10/29] seems to be working --- examples/llama-3/lora-1b.yml | 3 +- .../core/trainers/mixins/sequence_parallel.py | 67 ------------------- src/axolotl/utils/logging.py | 6 +- src/axolotl/utils/models.py | 33 ++++++--- 4 files changed, 29 insertions(+), 80 deletions(-) diff --git a/examples/llama-3/lora-1b.yml b/examples/llama-3/lora-1b.yml index c31a9f39a4..acc17e21f2 100644 --- a/examples/llama-3/lora-1b.yml +++ b/examples/llama-3/lora-1b.yml @@ -5,7 +5,7 @@ base_model: NousResearch/Llama-3.2-1B datasets: - path: teknium/GPT4-LLM-Cleaned type: alpaca -dataset_prepared_path: last_run_prepared + val_set_size: 0.1 output_dir: ./outputs/lora-out @@ -38,6 +38,7 @@ wandb_log_model: gradient_accumulation_steps: 2 micro_batch_size: 2 num_epochs: 1 + optimizer: adamw_8bit lr_scheduler: cosine learning_rate: 0.0002 diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py index 0e63f7bfc2..0f30458cdc 100644 --- a/src/axolotl/core/trainers/mixins/sequence_parallel.py +++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py @@ -1,80 +1,13 @@ """Module for Axolotl trainer sequence parallelism mixin""" -import functools -from axolotl.utils.logging import get_logger - -import torch import torch.distributed as dist from datasets import Dataset from torch.utils.data import DistributedSampler, Sampler -from axolotl.utils.schemas.enums import RingAttnFunc from axolotl.monkeypatch.attention.ring_attn import ( get_ring_attn_group, ) -LOG = get_logger(__name__) - - -def apply_sequence_parallelism( - batch: dict[str, torch.Tensor], - local_rank: int, - local_world_size: int, - ring_attn_func: RingAttnFunc, -) -> dict[str, torch.Tensor]: - """ - Apply sequence parallelism slicing to a batch. - - Args: - batch: Batch dictionary (e.g., input_ids, attention_mask, etc.) - local_rank: Local rank in the sequence parallel group - local_world_size: World size of the sequence parallel group - ring_attn_func: The ring attention function to use - - Returns: - Sliced batch dictionary. - """ - # Update ring attention params if needed - if batch.get("position_ids") is not None: - update_ring_attn_params(position_ids=batch["position_ids"]) - - # Slice batch for sequence parallel processing - total_seq_len = batch["input_ids"].size(1) - for key in batch: - if ( - key in batch and - isinstance(batch[key], torch.Tensor) and - batch[key].dim() > 1 and - batch[key].size(1) == total_seq_len - ): - - if ring_attn_func in [ - RingAttnFunc.VARLEN_LLAMA3, - RingAttnFunc.BATCH_RING, - ]: - # Split in sequential fashion and grab this rank's chunk - batch[key] = ( - batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous() - ) - elif ring_attn_func is RingAttnFunc.BATCH_ZIGZAG: - chunks = batch[key].chunk(2 * local_world_size, dim=1) - - # Take rank's chunk and opposing chunk for zigzag pattern - selected_chunks = [ - chunks[local_rank], - chunks[2 * local_world_size - local_rank - 1], - ] - batch[key] = torch.cat(selected_chunks, dim=1).contiguous() - elif ring_attn_func is RingAttnFunc.BATCH_STRIPE: - # Split into striped data and stack - tensor = torch.stack( - batch[key].split(local_world_size, dim=1), - dim=1, - ).transpose(1, 2) - batch[key] = tensor[:, local_rank].contiguous() - - return batch - class SequenceParallelMixin: """ diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index 55d616712e..ab004ddc31 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -2,9 +2,11 @@ logging helpers to only log on main process """ +import logging import os + from axolotl.utils.distributed import is_main_process -import logging + # Adapted from Accelerate # https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py @@ -16,7 +18,7 @@ class MultiProcessAdapter(logging.LoggerAdapter): @staticmethod def _should_log(main_process_only): return not main_process_only or ( - main_process_only and is_main_process(use_environ=True) + main_process_only and is_main_process(use_environ=False) ) def log(self, level, msg, *args, **kwargs): diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 2035afc64b..c545b8bcb7 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -459,13 +459,13 @@ def load_tokenizer(cfg): ) LOG.debug( - f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}", main_process_only=True + f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}", ) LOG.debug( - f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}", main_process_only=True + f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}", ) LOG.debug( - f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}", main_process_only=True + f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}", ) if cfg.chat_template: @@ -795,7 +795,9 @@ def patch_llama_derived_model(self): hijack_llama_attention, ) - LOG.info("patching with xformers attention", main_process_only=True) + LOG.info( + "patching with xformers attention", + ) hijack_llama_attention() elif self.cfg.sample_packing: from axolotl.monkeypatch.llama_patch_multipack import ( @@ -1098,11 +1100,15 @@ def _configure_zero3_memory_efficient_loading(): ) if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): - LOG.info("patching with SwiGLU", main_process_only=True) + LOG.info( + "patching with SwiGLU", + ) replace_llama_mlp_with_swiglu(self.model) if self.cfg.flash_attn_fuse_qkv: - LOG.info("patching with fused QKV", main_process_only=True) + LOG.info( + "patching with fused QKV", + ) replace_llama_qkv_with_fused(self.model) elif self.model_type == "MambaLMHeadModel": # FIXME this is janky at best and hacked together to make it work @@ -1303,7 +1309,9 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: skip_move_to_device = self.build_model(qlora_fsdp) PLUGIN_MANAGER.post_model_build(self.cfg, self.model) except Exception as err: # pylint: disable=broad-exception-caught - LOG.exception(err, main_process_only=True) + LOG.exception( + err, + ) raise err if isinstance(self.model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp: @@ -1380,7 +1388,8 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: if should_convert: LOG.info( - "Converting modules to %s", self.cfg.torch_dtype, main_process_only=True + "Converting modules to %s", + self.cfg.torch_dtype, ) self.convert_embedding_modules_dtype( embedding_modules=embedding_modules, @@ -1510,7 +1519,9 @@ def load_llama_adapter(model, cfg): ) if cfg.lora_model_dir: - LOG.debug("Loading pretrained PEFT - llama_adapter", main_process_only=True) + LOG.debug( + "Loading pretrained PEFT - llama_adapter", + ) model = PeftModel.from_pretrained( model, cfg.lora_model_dir, @@ -1632,7 +1643,9 @@ def load_lora(model, cfg, inference=False, config_only=False): setup_quantized_meta_for_peft(model) if cfg.lora_model_dir: - LOG.debug("Loading pretrained PEFT - LoRA", main_process_only=True) + LOG.debug( + "Loading pretrained PEFT - LoRA", + ) model_kwargs: Any = {} if cfg.lora_on_cpu: model_kwargs["max_memory"] = {"cpu": "256GiB"} From 589367d56ff54e6c6fded225dba7352b693c48da Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 22 May 2025 11:23:45 +0100 Subject: [PATCH 11/29] comments --- src/axolotl/cli/checks.py | 3 +- src/axolotl/cli/config.py | 2 +- src/axolotl/cli/evaluate.py | 2 +- src/axolotl/cli/inference.py | 2 +- src/axolotl/cli/merge_lora.py | 2 +- src/axolotl/cli/merge_sharded_fsdp_weights.py | 2 +- src/axolotl/cli/preprocess.py | 16 +- src/axolotl/cli/utils.py | 2 +- src/axolotl/common/datasets.py | 18 +- src/axolotl/core/chat/messages.py | 5 - src/axolotl/core/trainer_builder.py | 44 +-- src/axolotl/core/trainers/base.py | 16 +- src/axolotl/core/trainers/grpo/__init__.py | 2 +- src/axolotl/core/trainers/mixins/optimizer.py | 19 +- .../core/trainers/mixins/rng_state_loader.py | 3 +- src/axolotl/core/trainers/mixins/scheduler.py | 11 +- src/axolotl/datasets.py | 7 +- src/axolotl/integrations/base.py | 1 - .../cut_cross_entropy/__init__.py | 2 +- .../integrations/cut_cross_entropy/args.py | 3 +- src/axolotl/integrations/grokfast/__init__.py | 4 +- src/axolotl/integrations/liger/args.py | 3 +- .../integrations/llm_compressor/plugin.py | 2 +- src/axolotl/monkeypatch/accelerate/fsdp2.py | 3 +- .../monkeypatch/btlm_attn_hijack_flash.py | 3 +- .../monkeypatch/llama_attn_hijack_flash.py | 20 +- .../monkeypatch/mistral_attn_hijack_flash.py | 28 +- src/axolotl/monkeypatch/relora.py | 14 +- src/axolotl/monkeypatch/trainer/lr.py | 4 +- .../monkeypatch/trainer_accelerator_args.py | 8 +- src/axolotl/monkeypatch/trainer_eval_guard.py | 8 +- src/axolotl/monkeypatch/trainer_fsdp_optim.py | 8 +- .../monkeypatch/transformers_fa_utils.py | 3 +- src/axolotl/prompt_strategies/__init__.py | 2 +- src/axolotl/prompt_strategies/base.py | 1 + .../bradley_terry/__init__.py | 2 +- .../prompt_strategies/messages/__init__.py | 1 + src/axolotl/prompt_strategies/metharme.py | 2 +- src/axolotl/prompt_strategies/pygmalion.py | 4 +- src/axolotl/prompt_tokenizers.py | 20 +- src/axolotl/prompters.py | 11 +- src/axolotl/utils/callbacks/__init__.py | 29 +- src/axolotl/utils/callbacks/comet_.py | 2 +- src/axolotl/utils/callbacks/lisa.py | 3 +- src/axolotl/utils/callbacks/mlflow_.py | 2 +- src/axolotl/utils/chat_templates.py | 5 +- src/axolotl/utils/comet_.py | 2 +- src/axolotl/utils/config/__init__.py | 64 ++--- src/axolotl/utils/data/pretraining.py | 2 +- src/axolotl/utils/data/rl.py | 36 +-- src/axolotl/utils/data/sft.py | 2 - src/axolotl/utils/schemas/config.py | 263 +++++++++--------- src/axolotl/utils/schemas/deprecated.py | 3 +- src/axolotl/utils/schemas/integrations.py | 3 +- src/axolotl/utils/schemas/model.py | 4 +- src/axolotl/utils/schemas/training.py | 2 +- src/axolotl/utils/schemas/utils.py | 10 +- src/axolotl/utils/tokenization.py | 4 +- tests/prompt_strategies/messages/test_chat.py | 18 +- .../prompt_strategies/test_chat_templates.py | 1 - .../test_chat_templates_thinking.py | 2 +- .../test_jinja_template_analyzer.py | 21 +- tests/test_prompt_tokenizers.py | 41 +-- update_logging.py | 106 ------- 64 files changed, 420 insertions(+), 518 deletions(-) delete mode 100644 update_logging.py diff --git a/src/axolotl/cli/checks.py b/src/axolotl/cli/checks.py index ddec6e761d..10086c2a4c 100644 --- a/src/axolotl/cli/checks.py +++ b/src/axolotl/cli/checks.py @@ -1,6 +1,5 @@ """Various checks for Axolotl CLI.""" -from axolotl.utils.logging import get_logger import os from pathlib import Path @@ -8,6 +7,8 @@ from huggingface_hub import HfApi from huggingface_hub.utils import LocalTokenNotFoundError +from axolotl.utils.logging import get_logger + LOG = get_logger(__name__) diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 0cc72cc91f..9718d52541 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -1,7 +1,6 @@ """Configuration loading and processing.""" import json -from axolotl.utils.logging import get_logger import os import tempfile from pathlib import Path @@ -22,6 +21,7 @@ validate_config, ) from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env from axolotl.utils.wandb_ import setup_wandb_env_vars diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py index 2df6417919..f131f70830 100644 --- a/src/axolotl/cli/evaluate.py +++ b/src/axolotl/cli/evaluate.py @@ -1,6 +1,5 @@ """CLI to run evaluation on a model.""" -from axolotl.utils.logging import get_logger import os from pathlib import Path from typing import Union @@ -17,6 +16,7 @@ from axolotl.evaluate import evaluate from axolotl.utils import patch_optimized_env from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index 4ac0779106..b5bc158fa1 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -1,7 +1,6 @@ """CLI to run inference on a trained model.""" import importlib -from axolotl.utils.logging import get_logger import sys from pathlib import Path from threading import Thread @@ -22,6 +21,7 @@ get_chat_template_from_config, ) from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index a5da077d7b..2e59d25374 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -1,6 +1,5 @@ """CLI to merge a trained LoRA into a base model.""" -from axolotl.utils.logging import get_logger from pathlib import Path from typing import Union @@ -13,6 +12,7 @@ from axolotl.cli.config import load_cfg from axolotl.cli.utils import load_model_and_tokenizer from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py index 490a8fe163..297d7946e4 100644 --- a/src/axolotl/cli/merge_sharded_fsdp_weights.py +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py @@ -1,7 +1,6 @@ """CLI to merge sharded FSDP model checkpoints into a single combined checkpoint.""" import json -from axolotl.utils.logging import get_logger import os import shutil from pathlib import Path @@ -27,6 +26,7 @@ from axolotl.cli.args import TrainerCliArgs from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.config import load_cfg +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index 8440718841..9f96f5cc17 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -1,6 +1,5 @@ """CLI to run preprocessing of a dataset.""" -from axolotl.utils.logging import get_logger import warnings from pathlib import Path from typing import Union @@ -20,6 +19,7 @@ from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.integrations.base import PluginManager from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from axolotl.utils.trainer import disable_datasets_caching LOG = get_logger(__name__) @@ -39,10 +39,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: if not cfg.dataset_prepared_path: msg = ( - Fore.RED + - "preprocess CLI called without dataset_prepared_path set, " + - f"using default path: {DEFAULT_DATASET_PREPARED_PATH}" + - Fore.RESET + Fore.RED + + "preprocess CLI called without dataset_prepared_path set, " + + f"using default path: {DEFAULT_DATASET_PREPARED_PATH}" + + Fore.RESET ) LOG.warning(msg) cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH @@ -73,9 +73,9 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: # fmt: on LOG.info( - Fore.GREEN + - f"Success! Preprocessed data path: `dataset_prepared_path: {cfg.dataset_prepared_path}`" + - Fore.RESET + Fore.GREEN + + f"Success! Preprocessed data path: `dataset_prepared_path: {cfg.dataset_prepared_path}`" + + Fore.RESET ) diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py index 9871dd827f..8ad188433c 100644 --- a/src/axolotl/cli/utils.py +++ b/src/axolotl/cli/utils.py @@ -4,7 +4,6 @@ import dataclasses import hashlib import json -from axolotl.utils.logging import get_logger from functools import wraps from pathlib import Path from types import NoneType @@ -21,6 +20,7 @@ ) from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from axolotl.utils.models import load_model, load_processor, load_tokenizer LOG = get_logger(__name__) diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index c6933aaf18..e7929831e1 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -1,6 +1,5 @@ """Dataset loading utilities.""" -from axolotl.utils.logging import get_logger import math import random from dataclasses import dataclass @@ -13,6 +12,7 @@ from axolotl.utils.data import prepare_dataset from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from axolotl.utils.models import load_processor, load_tokenizer from axolotl.utils.schemas.enums import RLType from axolotl.utils.tokenization import check_dataset_labels @@ -67,10 +67,10 @@ def load_datasets( tokenizer = load_tokenizer(cfg) processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None preprocess_iterable = ( - cli_args and - hasattr(cli_args, "iterable") and - cli_args.iterable is not None and - cli_args.iterable + cli_args + and hasattr(cli_args, "iterable") + and cli_args.iterable is not None + and cli_args.iterable ) train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( @@ -81,10 +81,10 @@ def load_datasets( ) if cli_args and ( - cli_args.debug or - cfg.debug or - cli_args.debug_text_only or - int(cli_args.debug_num_examples) > 0 + cli_args.debug + or cfg.debug + or cli_args.debug_text_only + or int(cli_args.debug_num_examples) > 0 ): LOG.info("check_dataset_labels...") diff --git a/src/axolotl/core/chat/messages.py b/src/axolotl/core/chat/messages.py index 655e4ce93c..923b177c1f 100644 --- a/src/axolotl/core/chat/messages.py +++ b/src/axolotl/core/chat/messages.py @@ -9,10 +9,6 @@ from pydantic import BaseModel from transformers import PreTrainedTokenizer -from axolotl.utils.logging import get_logger - -LOG = get_logger(__name__) - class MessageRoles(str, Enum): """ @@ -160,7 +156,6 @@ def tokenized( len(input_ids) : len(input_ids) + len(pending_input_ids) ] if new_pending_inputs != pending_input_ids: - # LOG.warning("tokenization mismatch from concatenation.") pending_input_ids = new_pending_inputs input_ids.extend(pending_input_ids) if pending_weight: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 2a1f1cb540..f95e3f2a4f 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -19,7 +19,6 @@ import importlib import importlib.util import inspect -from axolotl.utils.logging import get_logger import math import os import sys @@ -86,6 +85,7 @@ V2BatchSamplerDataCollatorForSeq2Seq, ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator +from axolotl.utils.logging import get_logger from axolotl.utils.models import ensure_dtype from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType @@ -246,8 +246,8 @@ def get_callbacks(self): callbacks.append(ReLoRACallback(self.cfg)) if ( - hasattr(self.model, "use_bettertransformer") and - self.model.use_bettertransformer is True + hasattr(self.model, "use_bettertransformer") + and self.model.use_bettertransformer is True ): callbacks.append(SaveBetterTransformerModelCallback()) @@ -264,9 +264,9 @@ def get_post_trainer_create_callbacks(self, trainer): ) callbacks.append(LogPredictionCallback(self.cfg)) if ( - self.cfg.use_mlflow and - is_mlflow_available() and - self.cfg.eval_table_size > 0 + self.cfg.use_mlflow + and is_mlflow_available() + and self.cfg.eval_table_size > 0 ): LogPredictionCallback = log_prediction_callback_factory( trainer, self.tokenizer, "mlflow" @@ -526,16 +526,16 @@ def build(self, total_num_steps): ) training_arguments_kwargs["load_best_model_at_end"] = ( ( - self.cfg.load_best_model_at_end is not False or - self.cfg.early_stopping_patience - ) and - ( - (not self.cfg.test_datasets and self.cfg.val_set_size > 0) or - (self.cfg.test_datasets and self.cfg.val_set_size == 0) - ) and - self.cfg.save_steps and - self.cfg.eval_steps and - self.cfg.save_steps % self.cfg.eval_steps == 0 + self.cfg.load_best_model_at_end is not False + or self.cfg.early_stopping_patience + ) + and ( + (not self.cfg.test_datasets and self.cfg.val_set_size > 0) + or (self.cfg.test_datasets and self.cfg.val_set_size == 0) + ) + and self.cfg.save_steps + and self.cfg.eval_steps + and self.cfg.save_steps % self.cfg.eval_steps == 0 ) or False # handle ddp @@ -857,8 +857,8 @@ def build(self, total_num_steps): else: trainer_kwargs["tokenizer"] = self.tokenizer if ( - not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer]) and - self.cfg.datasets is not None + not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer]) + and self.cfg.datasets is not None ): trainer_kwargs["dataset_tags"] = [ d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir() @@ -888,8 +888,8 @@ def build_collator( ): if training_args.pretraining: if ( - self.cfg.pretraining_sample_concatenation is False or - self.cfg.micro_batch_size > 1 + self.cfg.pretraining_sample_concatenation is False + or self.cfg.micro_batch_size > 1 ): return DataCollatorForSeq2Seq(self.tokenizer, **kwargs) return None @@ -923,8 +923,8 @@ def build_collator( elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES: collator = V2BatchSamplerDataCollatorForSeq2Seq elif ( - self.cfg.model_config_type in ["llama"] and - self.cfg.flash_attention is not True + self.cfg.model_config_type in ["llama"] + and self.cfg.flash_attention is not True ): collator = V2BatchSamplerDataCollatorForSeq2Seq else: diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index d0b952640e..b07626a895 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -4,7 +4,6 @@ from __future__ import annotations -from axolotl.utils.logging import get_logger import os from collections import defaultdict from functools import wraps @@ -35,6 +34,7 @@ sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_tagging, ) +from axolotl.utils.logging import get_logger from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths LOG = get_logger(__name__) @@ -231,8 +231,8 @@ def _prepare_dataloader( dataloader = DataLoader(dataset, **dataloader_params) if self.args.sample_packing and ( - (not is_eval and not self.args.pretraining) or - (is_eval and self.args.eval_sample_packing is not False) + (not is_eval and not self.args.pretraining) + or (is_eval and self.args.eval_sample_packing is not False) ): self.accelerator.even_batches = False @@ -289,9 +289,9 @@ def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader # Handle sample packing or sequence parallelism if ( - self.args.sample_packing and - self.args.eval_sample_packing is not False or - self.args.sequence_parallel_degree > 1 + self.args.sample_packing + and self.args.eval_sample_packing is not False + or self.args.sequence_parallel_degree > 1 ): # Get appropriate data collator self.data_collator = ( # pylint: disable=attribute-defined-outside-init @@ -560,8 +560,8 @@ def create_accelerator_and_postprocess(self): if self.is_fsdp_enabled: if ( - "limit_all_gathers" in self.args.fsdp_config and - self.args.fsdp_config["limit_all_gathers"] + "limit_all_gathers" in self.args.fsdp_config + and self.args.fsdp_config["limit_all_gathers"] ): self.accelerator.state.fsdp_plugin.limit_all_gathers = True diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index d5c39f8f42..f7b5004798 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -3,7 +3,6 @@ import importlib import inspect from typing import Any -from axolotl.utils.logging import get_logger from trl.trainer.grpo_trainer import RewardFunc @@ -13,6 +12,7 @@ AxolotlGRPOTrainer, ) from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.trl import TRLConfig LOG = get_logger(__name__) diff --git a/src/axolotl/core/trainers/mixins/optimizer.py b/src/axolotl/core/trainers/mixins/optimizer.py index 1a05083505..abb662706a 100644 --- a/src/axolotl/core/trainers/mixins/optimizer.py +++ b/src/axolotl/core/trainers/mixins/optimizer.py @@ -1,13 +1,12 @@ """Module for Axolotl trainer optimizer mixin""" -from axolotl.utils.logging import get_logger - from peft.optimizers import create_loraplus_optimizer from torch import nn from transformers.trainer import Trainer from transformers.utils import is_sagemaker_mp_enabled from axolotl.integrations.base import BaseOptimizerFactory +from axolotl.utils.logging import get_logger if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp @@ -107,20 +106,20 @@ def create_optimizer_grouped_parameters( def create_optimizer(self): if ( - self.args.loraplus_lr_ratio is None and - self.args.embedding_lr_scale is None and - self.args.embedding_lr is None and - self.args.lr_groups is None and - self.optimizer_cls_and_kwargs is None + self.args.loraplus_lr_ratio is None + and self.args.embedding_lr_scale is None + and self.args.embedding_lr is None + and self.args.lr_groups is None + and self.optimizer_cls_and_kwargs is None ): return super().create_optimizer() opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model if ( - not self.optimizer and - self.optimizer_cls_and_kwargs is not None and - issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory) + not self.optimizer + and self.optimizer_cls_and_kwargs is not None + and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory) ): optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs self.optimizer = optimizer_factory_cls()( diff --git a/src/axolotl/core/trainers/mixins/rng_state_loader.py b/src/axolotl/core/trainers/mixins/rng_state_loader.py index cd70267a08..f248394b2e 100644 --- a/src/axolotl/core/trainers/mixins/rng_state_loader.py +++ b/src/axolotl/core/trainers/mixins/rng_state_loader.py @@ -6,7 +6,6 @@ TODO: Remove when upstream added PR to release """ -from axolotl.utils.logging import get_logger import os import random @@ -17,6 +16,8 @@ from transformers.trainer_pt_utils import set_rng_state_for_device from transformers.training_args import ParallelMode +from axolotl.utils.logging import get_logger + LOG = get_logger(__name__) diff --git a/src/axolotl/core/trainers/mixins/scheduler.py b/src/axolotl/core/trainers/mixins/scheduler.py index b79b935a49..90070ab78a 100644 --- a/src/axolotl/core/trainers/mixins/scheduler.py +++ b/src/axolotl/core/trainers/mixins/scheduler.py @@ -1,12 +1,11 @@ """Module for Axolotl trainer scheduler mixin""" -from axolotl.utils.logging import get_logger - import torch from torch.optim.lr_scheduler import LRScheduler, OneCycleLR from transformers.trainer import Trainer from axolotl.integrations.base import PluginManager +from axolotl.utils.logging import get_logger from axolotl.utils.schedulers import ( RexLR, get_cosine_schedule_with_min_lr, @@ -36,13 +35,13 @@ def create_scheduler( optimizer (torch.optim.Optimizer): The training optimizer """ use_cosine_quadratic = ( - self.args.lr_scheduler_type == "cosine" and - self.args.lr_quadratic_warmup is True + self.args.lr_scheduler_type == "cosine" + and self.args.lr_quadratic_warmup is True ) use_cosine_min_lr = ( - self.args.lr_scheduler_type == "cosine" and - self.args.cosine_min_lr_ratio is not None + self.args.lr_scheduler_type == "cosine" + and self.args.cosine_min_lr_ratio is not None ) # fmt: off diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index f6ec664d3a..9f1d9500d6 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -1,12 +1,13 @@ """Module containing Dataset functionality""" -from axolotl.utils.logging import get_logger import os from typing import List, Optional, Union import torch from datasets import Dataset, IterableDataset +from axolotl.utils.logging import get_logger + from .prompt_tokenizers import PromptTokenizingStrategy # We want this to be a wrapper for an existing dataset that we have loaded @@ -54,8 +55,8 @@ def process(self, dataset): map_kwargs["batch_size"] = 1_000 if ( - hasattr(self.prompt_tokenizer, "filter_rows") and - self.prompt_tokenizer.filter_rows + hasattr(self.prompt_tokenizer, "filter_rows") + and self.prompt_tokenizer.filter_rows ): dataset = dataset.filter( self.prompt_tokenizer.filter_rows, diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index 1427244226..c38d359197 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -24,7 +24,6 @@ import torch from torch.optim.lr_scheduler import LRScheduler -from transformers.trainer_utils import SchedulerType from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index b0414653d4..a05426044a 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -19,13 +19,13 @@ from Apple's ML team. """ import importlib -from axolotl.utils.logging import get_logger import torch from axolotl.integrations.base import BasePlugin from axolotl.utils import get_pytorch_version from axolotl.utils.distributed import is_main_process +from axolotl.utils.logging import get_logger from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401 diff --git a/src/axolotl/integrations/cut_cross_entropy/args.py b/src/axolotl/integrations/cut_cross_entropy/args.py index 81207e3cc2..2729ebe2e3 100644 --- a/src/axolotl/integrations/cut_cross_entropy/args.py +++ b/src/axolotl/integrations/cut_cross_entropy/args.py @@ -15,11 +15,12 @@ """ Module for handling Cut Cross Entropy input arguments. """ -from axolotl.utils.logging import get_logger from typing import Optional from pydantic import BaseModel, model_validator +from axolotl.utils.logging import get_logger + LOG = get_logger(__name__) diff --git a/src/axolotl/integrations/grokfast/__init__.py b/src/axolotl/integrations/grokfast/__init__.py index fe0b4ebc07..234d27226a 100644 --- a/src/axolotl/integrations/grokfast/__init__.py +++ b/src/axolotl/integrations/grokfast/__init__.py @@ -2,10 +2,10 @@ Grokfast plugin for Axolotl """ -from axolotl.utils.logging import get_logger - from transformers.trainer_callback import TrainerCallback +from axolotl.utils.logging import get_logger + from ..base import BasePlugin from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401 from .optimizer import gradfilter_ema diff --git a/src/axolotl/integrations/liger/args.py b/src/axolotl/integrations/liger/args.py index d85c651f23..7c9eb23d56 100644 --- a/src/axolotl/integrations/liger/args.py +++ b/src/axolotl/integrations/liger/args.py @@ -15,11 +15,12 @@ """ Module for handling LIGER input arguments. """ -from axolotl.utils.logging import get_logger from typing import Optional from pydantic import BaseModel, model_validator +from axolotl.utils.logging import get_logger + LOG = get_logger(__name__) diff --git a/src/axolotl/integrations/llm_compressor/plugin.py b/src/axolotl/integrations/llm_compressor/plugin.py index 595062436b..57d506a573 100644 --- a/src/axolotl/integrations/llm_compressor/plugin.py +++ b/src/axolotl/integrations/llm_compressor/plugin.py @@ -3,7 +3,6 @@ by maintaining masks for zero weights during training. """ -from axolotl.utils.logging import get_logger from functools import wraps from typing import Any, Callable, Concatenate, ParamSpec, TypeVar @@ -16,6 +15,7 @@ from transformers.training_args import TrainingArguments from axolotl.integrations.base import BasePlugin +from axolotl.utils.logging import get_logger P = ParamSpec("P") # Params for generic function signatures R = TypeVar("R") # Return type for generic function signatures diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py index 2f12fc1575..d7b769a86b 100644 --- a/src/axolotl/monkeypatch/accelerate/fsdp2.py +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -2,11 +2,12 @@ monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation """ -from axolotl.utils.logging import get_logger import sys import torch +from axolotl.utils.logging import get_logger + LOG = get_logger(__name__) diff --git a/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py b/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py index 6d0451f16b..589980c8b9 100644 --- a/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py @@ -3,7 +3,6 @@ """ import importlib -from axolotl.utils.logging import get_logger from typing import Optional, Tuple import torch @@ -11,6 +10,8 @@ from flash_attn.flash_attn_interface import flash_attn_func from transformers import AutoConfig, AutoModelForCausalLM +from axolotl.utils.logging import get_logger + LOG = get_logger(__name__) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 15d9628c9a..70e36714c8 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -2,7 +2,6 @@ # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py -from axolotl.utils.logging import get_logger import warnings from typing import List, Optional, Tuple, Union @@ -25,6 +24,7 @@ ) from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name +from axolotl.utils.logging import get_logger try: from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports @@ -493,7 +493,7 @@ def flashattn_forward( # the attention_mask should be the same as the key_padding_mask key_padding_mask=attention_mask, query_padding_mask=( - attention_mask[:, -query_states.size(1):] + attention_mask[:, -query_states.size(1) :] if attention_mask is not None else None ), @@ -536,7 +536,7 @@ def flashattn_forward( kvpacked=True, key_padding_mask=attention_mask, query_padding_mask=( - attention_mask[:, -query_states.size(1):] + attention_mask[:, -query_states.size(1) :] if attention_mask is not None else None ), @@ -612,9 +612,10 @@ def generate_qkv( q, query_padding_mask ) - def output_pad_fn(output_unpad): return pad_input( # noqa: E731 - output_unpad, indices_q, batch_size, seqlen_q - ) + def output_pad_fn(output_unpad): + return pad_input( # noqa: E731 + output_unpad, indices_q, batch_size, seqlen_q + ) else: q_unpad = rearrange(q, "b s h d -> (b s) h d") @@ -627,9 +628,10 @@ def output_pad_fn(output_unpad): return pad_input( # noqa: E731 ) max_seqlen_q = seqlen_q - def output_pad_fn(output_unpad): return rearrange( # noqa: E731 - output_unpad, "(b s) h d -> b s h d", b=batch_size - ) + def output_pad_fn(output_unpad): + return rearrange( # noqa: E731 + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) if key_padding_mask is not None: k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py index bcb5802f99..3fc22917fb 100644 --- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py @@ -2,7 +2,6 @@ # pylint: disable=duplicate-code -from axolotl.utils.logging import get_logger from functools import partial from typing import List, Optional, Tuple, Union @@ -28,6 +27,7 @@ ) from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) @@ -165,8 +165,8 @@ def flashattn_forward( ) use_sliding_windows = ( - getattr(self.config, "sliding_window") is not None and - kv_seq_len > self.config.sliding_window + getattr(self.config, "sliding_window") is not None + and kv_seq_len > self.config.sliding_window ) if use_sliding_windows: @@ -177,8 +177,8 @@ def flashattn_forward( if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute if ( - hasattr(self.config, "sliding_window") and - kv_seq_len > self.config.sliding_window + hasattr(self.config, "sliding_window") + and kv_seq_len > self.config.sliding_window ): slicing_tokens = kv_seq_len - self.config.sliding_window @@ -248,7 +248,7 @@ def flashattn_forward( # the attention_mask should be the same as the key_padding_mask key_padding_mask=attention_mask, query_padding_mask=( - attention_mask[:, -query_states.size(1):] + attention_mask[:, -query_states.size(1) :] if attention_mask is not None else None ), @@ -293,7 +293,7 @@ def flashattn_forward( kvpacked=True, key_padding_mask=attention_mask, query_padding_mask=( - attention_mask[:, -query_states.size(1):] + attention_mask[:, -query_states.size(1) :] if attention_mask is not None else None ), @@ -359,9 +359,10 @@ def generate_qkv( q, query_padding_mask ) - def output_pad_fn(output_unpad): return pad_input( # noqa: E731 - output_unpad, indices_q, batch_size, seqlen_q - ) + def output_pad_fn(output_unpad): + return pad_input( # noqa: E731 + output_unpad, indices_q, batch_size, seqlen_q + ) else: q_unpad = rearrange(q, "b s h d -> (b s) h d") @@ -374,9 +375,10 @@ def output_pad_fn(output_unpad): return pad_input( # noqa: E731 ) max_seqlen_q = seqlen_q - def output_pad_fn(output_unpad): return rearrange( # noqa: E731 - output_unpad, "(b s) h d -> b s h d", b=batch_size - ) + def output_pad_fn(output_unpad): + return rearrange( # noqa: E731 + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) if key_padding_mask is not None: k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py index 22913532b4..5b7418e39d 100644 --- a/src/axolotl/monkeypatch/relora.py +++ b/src/axolotl/monkeypatch/relora.py @@ -2,7 +2,6 @@ import glob import json -from axolotl.utils.logging import get_logger import os.path import shutil from functools import partial @@ -27,6 +26,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import barrier, is_main_process +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) @@ -194,8 +194,8 @@ def on_save( args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora" ) if ( - state.global_step >= self.relora_steps and - state.global_step % self.relora_steps != 0 + state.global_step >= self.relora_steps + and state.global_step % self.relora_steps != 0 ): if self.quantized: if is_main_process() and self.last_full_model != checkpoint_folder: @@ -327,8 +327,8 @@ def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor layer.lora_B[adapter].weight.detach().to(device) @ layer.lora_A[adapter].weight.detach().to(device), getattr(layer, "fan_in_fan_out", False), - ) * - layer.scaling[adapter] + ) + * layer.scaling[adapter] ) raise ValueError("unhandled lora layer type") @@ -441,8 +441,8 @@ def merge_and_save( out_shard_name = shard_path if out_shard_name.startswith("pytorch_model"): out_shard_name = ( - out_shard_name.replace("pytorch_model", "model").rstrip(".bin") + - ".safetensors" + out_shard_name.replace("pytorch_model", "model").rstrip(".bin") + + ".safetensors" ) for module_name in in_tensors: diff --git a/src/axolotl/monkeypatch/trainer/lr.py b/src/axolotl/monkeypatch/trainer/lr.py index 52f934e51d..9afc23c466 100644 --- a/src/axolotl/monkeypatch/trainer/lr.py +++ b/src/axolotl/monkeypatch/trainer/lr.py @@ -2,10 +2,10 @@ monkeypatch for Trainer _get_learning_rate method """ -from axolotl.utils.logging import get_logger - import torch +from axolotl.utils.logging import get_logger + LOG = get_logger(__name__) diff --git a/src/axolotl/monkeypatch/trainer_accelerator_args.py b/src/axolotl/monkeypatch/trainer_accelerator_args.py index 11c46be38c..0a5b27c13e 100644 --- a/src/axolotl/monkeypatch/trainer_accelerator_args.py +++ b/src/axolotl/monkeypatch/trainer_accelerator_args.py @@ -3,11 +3,11 @@ """ import inspect -from axolotl.utils.logging import get_logger from transformers import Trainer from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) @@ -70,9 +70,9 @@ def patch_create_accelerate_code_for_fp8(): items_to_import.append(item) exec( # pylint: disable=exec-used # nosec B102 - "from transformers.trainer import (" + - ", ".join(x for x in items_to_import) + - ")", + "from transformers.trainer import (" + + ", ".join(x for x in items_to_import) + + ")", globals(), ) exec(create_code, globals()) # pylint: disable=exec-used # nosec B102 diff --git a/src/axolotl/monkeypatch/trainer_eval_guard.py b/src/axolotl/monkeypatch/trainer_eval_guard.py index 9e750a6b06..8488a16df9 100644 --- a/src/axolotl/monkeypatch/trainer_eval_guard.py +++ b/src/axolotl/monkeypatch/trainer_eval_guard.py @@ -3,11 +3,11 @@ """ import inspect -from axolotl.utils.logging import get_logger from transformers import Trainer from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) @@ -66,9 +66,9 @@ def patch_evaluation_loop_for_fsdp2(): items_to_import.append(item) exec( # pylint: disable=exec-used # nosec B102 - "from transformers.trainer import (" + - ", ".join(x for x in items_to_import) + - ")", + "from transformers.trainer import (" + + ", ".join(x for x in items_to_import) + + ")", globals(), ) exec(evaluation_loop, globals()) # pylint: disable=exec-used # nosec B102 diff --git a/src/axolotl/monkeypatch/trainer_fsdp_optim.py b/src/axolotl/monkeypatch/trainer_fsdp_optim.py index 219ceff669..4ce5b8ecd3 100644 --- a/src/axolotl/monkeypatch/trainer_fsdp_optim.py +++ b/src/axolotl/monkeypatch/trainer_fsdp_optim.py @@ -3,11 +3,11 @@ """ import inspect -from axolotl.utils.logging import get_logger from transformers import Trainer from axolotl.monkeypatch.utils import detab_code +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) @@ -69,9 +69,9 @@ def patch_training_loop_for_fsdp(): items_to_import.append(item) exec( # pylint: disable=exec-used # nosec B102 - "from transformers.trainer import (" + - ", ".join(x for x in items_to_import) + - ")", + "from transformers.trainer import (" + + ", ".join(x for x in items_to_import) + + ")", globals(), ) exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102 diff --git a/src/axolotl/monkeypatch/transformers_fa_utils.py b/src/axolotl/monkeypatch/transformers_fa_utils.py index ed2439412e..e372dc3f85 100644 --- a/src/axolotl/monkeypatch/transformers_fa_utils.py +++ b/src/axolotl/monkeypatch/transformers_fa_utils.py @@ -2,12 +2,13 @@ see https://github.com/huggingface/transformers/pull/35834 """ -from axolotl.utils.logging import get_logger from functools import partial from typing import Optional import torch +from axolotl.utils.logging import get_logger + logger = get_logger(__name__) diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py index 28d09e5c6a..3cdbbb6f33 100644 --- a/src/axolotl/prompt_strategies/__init__.py +++ b/src/axolotl/prompt_strategies/__init__.py @@ -2,9 +2,9 @@ import importlib import inspect -from axolotl.utils.logging import get_logger from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/prompt_strategies/base.py b/src/axolotl/prompt_strategies/base.py index 3a936bde63..370a51a95a 100644 --- a/src/axolotl/prompt_strategies/base.py +++ b/src/axolotl/prompt_strategies/base.py @@ -3,6 +3,7 @@ """ import importlib + from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/prompt_strategies/bradley_terry/__init__.py b/src/axolotl/prompt_strategies/bradley_terry/__init__.py index 4f8a144488..7530aee192 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/__init__.py +++ b/src/axolotl/prompt_strategies/bradley_terry/__init__.py @@ -2,9 +2,9 @@ import importlib import inspect -from axolotl.utils.logging import get_logger from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/prompt_strategies/messages/__init__.py b/src/axolotl/prompt_strategies/messages/__init__.py index eb51623e8b..cc7b84da18 100644 --- a/src/axolotl/prompt_strategies/messages/__init__.py +++ b/src/axolotl/prompt_strategies/messages/__init__.py @@ -2,6 +2,7 @@ import importlib import inspect + from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/prompt_strategies/metharme.py b/src/axolotl/prompt_strategies/metharme.py index 7914e346bf..66da723893 100644 --- a/src/axolotl/prompt_strategies/metharme.py +++ b/src/axolotl/prompt_strategies/metharme.py @@ -1,10 +1,10 @@ """Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class""" -from axolotl.utils.logging import get_logger from typing import Tuple from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy from axolotl.prompters import AlpacaPrompter +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/prompt_strategies/pygmalion.py b/src/axolotl/prompt_strategies/pygmalion.py index ea8413e22e..51f92f3970 100644 --- a/src/axolotl/prompt_strategies/pygmalion.py +++ b/src/axolotl/prompt_strategies/pygmalion.py @@ -1,7 +1,6 @@ """Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class""" import copy -from axolotl.utils.logging import get_logger from collections import defaultdict from typing import Generator, List, Tuple @@ -10,6 +9,7 @@ parse_tokenized_to_result, tokenize_prompt_default, ) +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) @@ -64,7 +64,7 @@ def tokenize_prompt(self, prompt): # make sure we create the labels first, otherwise we get incorrect lengths labels = [IGNORE_TOKEN_ID] * len(self.bot_prefix_token_ids) + [ *copy.deepcopy(res["input_ids"]) - ][len(self.bot_prefix_token_ids):] + ][len(self.bot_prefix_token_ids) :] else: LOG.warning(f"unknown role in conversation: {role}") res = defaultdict(lambda: []) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 6ffba1fca4..cb1a1ba4ed 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -1,12 +1,12 @@ """Module containing PromptTokenizingStrategy and Prompter classes""" import abc -from axolotl.utils.logging import get_logger from typing import Callable, Dict, List, Optional, Tuple, Union from transformers import BatchEncoding, PreTrainedTokenizer from axolotl.prompters import Prompter +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) @@ -79,9 +79,9 @@ def _tokenize( return empty if ( - result["input_ids"][-1] != self.tokenizer.eos_token_id and - len(result["input_ids"]) < self.max_length and - add_eos_token + result["input_ids"][-1] != self.tokenizer.eos_token_id + and len(result["input_ids"]) < self.max_length + and add_eos_token ): result["input_ids"].append(self.tokenizer.eos_token_id) result["attention_mask"].append(1) @@ -300,9 +300,9 @@ def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False): return_tensors=None, ) if ( - result["input_ids"][-1] != self.tokenizer.eos_token_id and - len(result["input_ids"]) < self.sequence_len and - add_eos_token + result["input_ids"][-1] != self.tokenizer.eos_token_id + and len(result["input_ids"]) < self.sequence_len + and add_eos_token ): result["input_ids"].append(self.tokenizer.eos_token_id) result["attention_mask"].append(1) @@ -353,11 +353,11 @@ def parse_tokenized_to_result( input_ids = res["input_ids"] input_len = len(input_ids) - result["input_ids"][current_len: current_len + input_len] = input_ids - result["attention_mask"][current_len: current_len + input_len] = [ + result["input_ids"][current_len : current_len + input_len] = input_ids + result["attention_mask"][current_len : current_len + input_len] = [ 1 if x != pad_token_id else 0 for x in input_ids ] - result["labels"][current_len: current_len + input_len] = labels + result["labels"][current_len : current_len + input_len] = labels current_len += input_len return result, current_len diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 4bd288db7f..d29da075e0 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -1,11 +1,12 @@ """Module containing prompters""" -from axolotl.utils.logging import get_logger from enum import Enum from typing import Generator, Optional, Union from colorama import Fore +from axolotl.utils.logging import get_logger + LOG = get_logger(__name__) IGNORE_TOKEN_ID = -100 REPR_TEMPLATE = "\n\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n\n" @@ -193,12 +194,12 @@ def __init__(self, prompt_style="instruct"): def match_prompt_style(self): if self.prompt_style == PromptStyle.INSTRUCT.value: self.prompt_input = ( - self.system_prompt + - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" + self.system_prompt + + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" ) self.prompt_no_input = ( - self.system_no_input_prompt + - "### Instruction:\n{instruction}\n\n### Response:\n" + self.system_no_input_prompt + + "### Instruction:\n{instruction}\n\n### Response:\n" ) self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}" self.response_split = "### Final Response:" diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index c0d26fb550..d94f4be74d 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -4,7 +4,6 @@ import gc import json -from axolotl.utils.logging import get_logger import os import traceback from shutil import copyfile @@ -43,6 +42,7 @@ is_main_process, zero_first, ) +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.config import AxolotlInputConfig if TYPE_CHECKING: @@ -86,9 +86,9 @@ def on_step_end( ): # Save if ( - args.save_strategy == IntervalStrategy.STEPS and - args.save_steps > 0 and - state.global_step % args.save_steps == 0 + args.save_strategy == IntervalStrategy.STEPS + and args.save_steps > 0 + and state.global_step % args.save_steps == 0 ): control.should_save = True @@ -508,8 +508,8 @@ def predict_with_generate(): if start == end: continue - input_ids = input_ids_all[start: end + 1] - labels = labels_all[start: end + 1] + input_ids = input_ids_all[start : end + 1] + labels = labels_all[start : end + 1] tokens_without_loss = labels == IGNORE_INDEX tokens_with_loss = labels != IGNORE_INDEX @@ -550,7 +550,7 @@ def predict_with_generate(): prompt_token_ids_list, prediction_all_tokens ): prediction_without_prompt_tokens = prediction_tokens[ - len(prompt_token_ids): + len(prompt_token_ids) : ] prediction_without_prompt_tokens_list.append( prediction_without_prompt_tokens @@ -679,8 +679,8 @@ def log_table_from_dataloader(name: str, table_dataloader): if start == end: continue - input_ids = input_ids_all[start: end + 1] - labels = labels_all[start: end + 1] + input_ids = input_ids_all[start : end + 1] + labels = labels_all[start : end + 1] tokens_without_loss = labels == IGNORE_INDEX tokens_with_loss = labels != IGNORE_INDEX @@ -696,7 +696,7 @@ def log_table_from_dataloader(name: str, table_dataloader): completion_token_ids_list.append(completion_token_ids) pred_step_token_ids = logits_to_tokens( - logits[start: end + 1] + logits[start : end + 1] )[tokens_with_loss] pred_step_token_ids_list.append(pred_step_token_ids) @@ -724,7 +724,7 @@ def log_table_from_dataloader(name: str, table_dataloader): prompt_token_ids_list, prediction_all_tokens ): prediction_without_prompt_tokens = prediction_tokens[ - len(prompt_token_ids): + len(prompt_token_ids) : ] prediction_without_prompt_tokens_list.append( prediction_without_prompt_tokens @@ -755,7 +755,12 @@ def log_table_from_dataloader(name: str, table_dataloader): if logger == "wandb": # type: ignore[attr-defined] wandb.run.log( - {f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) + { + f"{name} - Predictions vs Ground Truth": pd.DataFrame( + table_data + ) + } + ) elif logger == "mlflow" and is_mlflow_available(): import mlflow diff --git a/src/axolotl/utils/callbacks/comet_.py b/src/axolotl/utils/callbacks/comet_.py index af3a8b592b..b7e9034b0e 100644 --- a/src/axolotl/utils/callbacks/comet_.py +++ b/src/axolotl/utils/callbacks/comet_.py @@ -1,12 +1,12 @@ """Comet module for trainer callbacks""" -from axolotl.utils.logging import get_logger from typing import TYPE_CHECKING import comet_ml from transformers import TrainerCallback, TrainerControl, TrainerState from axolotl.utils.distributed import is_main_process +from axolotl.utils.logging import get_logger if TYPE_CHECKING: from axolotl.core.trainer_builder import AxolotlTrainingArguments diff --git a/src/axolotl/utils/callbacks/lisa.py b/src/axolotl/utils/callbacks/lisa.py index dbac72691e..ad7e23144a 100644 --- a/src/axolotl/utils/callbacks/lisa.py +++ b/src/axolotl/utils/callbacks/lisa.py @@ -6,13 +6,14 @@ License: Apache 2.0 """ -from axolotl.utils.logging import get_logger from functools import reduce from typing import TYPE_CHECKING import numpy as np from transformers import TrainerCallback +from axolotl.utils.logging import get_logger + if TYPE_CHECKING: from axolotl.core.trainer_builder import AxolotlTrainer diff --git a/src/axolotl/utils/callbacks/mlflow_.py b/src/axolotl/utils/callbacks/mlflow_.py index c4c08b7d5f..15f8ef0697 100644 --- a/src/axolotl/utils/callbacks/mlflow_.py +++ b/src/axolotl/utils/callbacks/mlflow_.py @@ -1,6 +1,5 @@ """MLFlow module for trainer callbacks""" -from axolotl.utils.logging import get_logger import os from shutil import copyfile from tempfile import NamedTemporaryFile @@ -10,6 +9,7 @@ from transformers import TrainerCallback, TrainerControl, TrainerState from axolotl.utils.distributed import is_main_process +from axolotl.utils.logging import get_logger if TYPE_CHECKING: from axolotl.core.trainer_builder import AxolotlTrainingArguments diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 7a958ea32c..c84ade4422 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -3,9 +3,10 @@ These templates are used for formatting messages in a conversation. """ -from axolotl.utils.logging import get_logger from typing import TYPE_CHECKING, Any, Dict, Optional +from axolotl.utils.logging import get_logger + if TYPE_CHECKING: from transformers import PreTrainedTokenizerBase @@ -93,7 +94,7 @@ def get_chat_template( return tokenizer.chat_template # type: ignore user_choice = user_choice[ - len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX): + len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) : ] LOG.warning( f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template." diff --git a/src/axolotl/utils/comet_.py b/src/axolotl/utils/comet_.py index 8f26645f95..9eeb6a2801 100644 --- a/src/axolotl/utils/comet_.py +++ b/src/axolotl/utils/comet_.py @@ -1,9 +1,9 @@ """Module for wandb utilities""" -from axolotl.utils.logging import get_logger import os from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 8183e4ec1e..bde499fd4b 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -1,7 +1,6 @@ """Module for working with config dicts""" import json -from axolotl.utils.logging import get_logger import os from typing import Optional @@ -13,6 +12,7 @@ from axolotl.integrations.config import merge_input_args from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from axolotl.utils.models import MULTIMODAL_AUTO_MODEL_MAPPING, load_model_config from axolotl.utils.schemas.config import ( AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, @@ -158,15 +158,15 @@ def normalize_config(cfg): ) cfg.is_multimodal = ( - hasattr(model_config, "model_type") and - model_config.model_type in MULTIMODAL_AUTO_MODEL_MAPPING or - any( + hasattr(model_config, "model_type") + and model_config.model_type in MULTIMODAL_AUTO_MODEL_MAPPING + or any( multimodal_name in cfg.base_model.lower() for multimodal_name in [ "pixtral", ] - ) or - cfg.is_multimodal + ) + or cfg.is_multimodal ) if cfg.is_multimodal: cfg.processor_config = ( @@ -178,46 +178,46 @@ def normalize_config(cfg): # figure out if the model is llama cfg.is_llama_derived_model = ( ( - hasattr(model_config, "model_type") and - model_config.model_type in ["llama", "mllama_text_model"] - ) or - cfg.is_llama_derived_model or - "llama" in cfg.base_model.lower() or - (cfg.type_of_model and "llama" in cfg.type_of_model.lower()) + hasattr(model_config, "model_type") + and model_config.model_type in ["llama", "mllama_text_model"] + ) + or cfg.is_llama_derived_model + or "llama" in cfg.base_model.lower() + or (cfg.type_of_model and "llama" in cfg.type_of_model.lower()) ) # figure out if the model is falcon cfg.is_falcon_derived_model = ( ( - hasattr(model_config, "model_type") and - model_config.model_type + hasattr(model_config, "model_type") + and model_config.model_type in [ "falcon", "RefinedWebModel", "RefinedWeb", ] - ) or - cfg.is_falcon_derived_model or - "falcon" in cfg.base_model.lower() or - (cfg.type_of_model and "rwforcausallm" in cfg.type_of_model.lower()) + ) + or cfg.is_falcon_derived_model + or "falcon" in cfg.base_model.lower() + or (cfg.type_of_model and "rwforcausallm" in cfg.type_of_model.lower()) ) cfg.is_mistral_derived_model = ( ( - hasattr(model_config, "model_type") and - model_config.model_type + hasattr(model_config, "model_type") + and model_config.model_type in [ "mistral", ] - ) or - cfg.is_mistral_derived_model or - "mistral" in cfg.base_model.lower().split("/")[-1] or - (cfg.type_of_model and "mistral" in cfg.type_of_model.lower()) + ) + or cfg.is_mistral_derived_model + or "mistral" in cfg.base_model.lower().split("/")[-1] + or (cfg.type_of_model and "mistral" in cfg.type_of_model.lower()) ) cfg.is_qwen_derived_model = ( - hasattr(model_config, "model_type") and - model_config.model_type + hasattr(model_config, "model_type") + and model_config.model_type in [ "qwen", ] @@ -227,10 +227,10 @@ def normalize_config(cfg): cfg.pretraining_dataset = [cfg.pretraining_dataset] if ( - cfg.gradient_checkpointing and - cfg.unfrozen_parameters is None and - cfg.gradient_checkpointing_kwargs is None and - cfg.rl is None + cfg.gradient_checkpointing + and cfg.unfrozen_parameters is None + and cfg.gradient_checkpointing_kwargs is None + and cfg.rl is None ): cfg.gradient_checkpointing_kwargs = {"use_reentrant": True} @@ -246,8 +246,8 @@ def normalize_cfg_datasets(cfg): if cfg.datasets: for idx, ds_cfg in enumerate(cfg.datasets): if ( - ds_cfg.type in ["orpo.chat_template", "chat_template"] and - not ds_cfg.chat_template + ds_cfg.type in ["orpo.chat_template", "chat_template"] + and not ds_cfg.chat_template ): LOG.info( f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template" diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py index bdf562759c..44d8d6fed0 100644 --- a/src/axolotl/utils/data/pretraining.py +++ b/src/axolotl/utils/data/pretraining.py @@ -1,7 +1,6 @@ """data handling specific to pretraining""" import functools -from axolotl.utils.logging import get_logger from collections import defaultdict from typing import Callable, Dict, List, Optional @@ -11,6 +10,7 @@ from transformers import PreTrainedTokenizerBase from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq +from axolotl.utils.logging import get_logger from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.trainer import process_pretraining_datasets_for_packing diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index 30b214f922..160d8d5346 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -1,7 +1,6 @@ """data handling specific to DPO""" import inspect -from axolotl.utils.logging import get_logger from functools import partial from pathlib import Path from typing import Any, List, Union @@ -17,6 +16,7 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first +from axolotl.utils.logging import get_logger from axolotl.utils.models import load_tokenizer from axolotl.utils.schemas.enums import RLType @@ -40,9 +40,9 @@ def _load_preprocessed_ds(cfg, sub_cfg): # pylint: disable=duplicate-code if ( - cfg.dataset_prepared_path and - any(prepared_ds_path.glob("*")) and - not cfg.is_preprocess + cfg.dataset_prepared_path + and any(prepared_ds_path.glob("*")) + and not cfg.is_preprocess ): LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") dataset = load_from_disk(str(prepared_ds_path)) @@ -211,22 +211,22 @@ def load_split(dataset_cfgs, _cfg): # ensure we end up with the same fingerprint by doing rank0 first and being able to cache to_hash_train = ( - train_dataset._fingerprint + # pylint: disable=protected-access - "|" + - str(cfg.val_set_size) + - "|" + - "train" + - "|" + - str(cfg.seed or 42) + train_dataset._fingerprint # pylint: disable=protected-access + + "|" + + str(cfg.val_set_size) + + "|" + + "train" + + "|" + + str(cfg.seed or 42) ) to_hash_test = ( - train_dataset._fingerprint + # pylint: disable=protected-access - "|" + - str(cfg.val_set_size) + - "|" + - "test" + - "|" + - str(cfg.seed or 42) + train_dataset._fingerprint # pylint: disable=protected-access + + "|" + + str(cfg.val_set_size) + + "|" + + "test" + + "|" + + str(cfg.seed or 42) ) train_fingerprint = md5(to_hash_train) test_fingerprint = md5(to_hash_test) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 6cbe6a3ed1..ebe9dba40c 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -508,8 +508,6 @@ def get_dataset_wrapper( LOG.info( f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}" - - ) if ( diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index b1002dd6dd..7ed5b45316 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -2,7 +2,6 @@ # pylint: disable=too-many-lines -from axolotl.utils.logging import get_logger import os from typing import Annotated, Any, Literal @@ -18,6 +17,7 @@ ) from transformers.utils.import_utils import is_torch_npu_available +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.datasets import ( DatasetConfig, DPODataset, @@ -103,16 +103,16 @@ class AxolotlInputConfig( Annotated[ list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset], MinLen(1), - ] | - None + ] + | None ) = None test_datasets: ( Annotated[ list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset], MinLen(1), - ] | - None + ] + | None ) = None shuffle_merged_datasets: bool | None = True dataset_prepared_path: str | None = None @@ -126,8 +126,9 @@ class AxolotlInputConfig( default=None, json_schema_extra={"description": "streaming dataset to use for pretraining"}, ) - dataset_processes: int | None = Field(default=min( - 32, os.cpu_count())) # type: ignore[type-var] + dataset_processes: int | None = Field( + default=min(32, os.cpu_count()) + ) # type: ignore[type-var] dataset_exact_deduplication: bool | None = None dataset_keep_in_memory: bool | None = None dataloader_pin_memory: bool | None = None @@ -307,8 +308,8 @@ class AxolotlInputConfig( low_cpu_mem_usage: bool | None = None chat_template: ( - ChatTemplate | - Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")] + ChatTemplate + | Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")] ) | None = None chat_template_jinja: str | None = None eot_tokens: list[str] | None = None @@ -431,9 +432,9 @@ def check_pretraining_split_batches_accelerate(cls, data): def check_gptq_w_revision(cls, data): if data.get("gptq") and data.get("revision_of_model"): raise ValueError( - "revision_of_model is not supported for GPTQ models. " + - "Please download the model from HuggingFace Hub manually for correct branch, " + - "point to its path, and remove revision_of_model from the config." + "revision_of_model is not supported for GPTQ models. " + + "Please download the model from HuggingFace Hub manually for correct branch, " + + "point to its path, and remove revision_of_model from the config." ) return data @@ -459,10 +460,10 @@ def check_chat_template_config(cls, data): @classmethod def check_sample_packing_wo_flash(cls, data): if ( - data.get("sample_packing") and - not data.get("flash_attention") and - not data.get("sdp_attention") and - not data.get("flex_attention") + data.get("sample_packing") + and not data.get("flash_attention") + and not data.get("sdp_attention") + and not data.get("flex_attention") ): LOG.warning( "sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination." @@ -483,10 +484,10 @@ def check_batch_flattening_fa(cls, data): LOG.warning("batch_flattening has no effect with micro_batch_size == 1") if ( - batch_flattening_auto and - data.get("flash_attention") and - not data.get("sample_packing") and - data.get("micro_batch_size") > 1 + batch_flattening_auto + and data.get("flash_attention") + and not data.get("sample_packing") + and data.get("micro_batch_size") > 1 ): data["batch_flattening"] = True elif batch_flattening_auto: @@ -541,9 +542,9 @@ def check_gas_bsz(cls, data): @classmethod def hint_eval_train_mbsz(cls, data): if ( - data.get("eval_batch_size") and - data.get("micro_batch_size") and - data.get("eval_batch_size") != data.get("micro_batch_size") + data.get("eval_batch_size") + and data.get("micro_batch_size") + and data.get("eval_batch_size") != data.get("micro_batch_size") ): LOG.warning( "eval_batch_size != micro_batch_size. This can lead to VRAM instability." @@ -554,8 +555,8 @@ def hint_eval_train_mbsz(cls, data): @classmethod def check_push_ds_auth(cls, data): if ( - data.get("push_dataset_to_hub") and - data.get("hf_use_auth_token") is not True + data.get("push_dataset_to_hub") + and data.get("hf_use_auth_token") is not True ): raise ValueError( "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub" @@ -620,9 +621,9 @@ def check_lr_groups(cls, data): @classmethod def check_saves(cls, data): if ( - data.get("save_strategy") and - data.get("save_steps") and - data.get("save_strategy") != "steps" + data.get("save_strategy") + and data.get("save_steps") + and data.get("save_strategy") != "steps" ): raise ValueError( "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps." @@ -648,19 +649,19 @@ def check_push_save(cls, data): @classmethod def check_evals(cls, data): if ( - data.get("eval_strategy") and - data.get("eval_steps") and - data.get("eval_strategy") != "steps" + data.get("eval_strategy") + and data.get("eval_steps") + and data.get("eval_strategy") != "steps" ): raise ValueError( "eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps." ) if ( - data.get("val_set_size") == 0 and - (data.get("eval_steps") or data.get("eval_strategy")) and - not data.get("test_datasets") and - data.get("eval_strategy") != "no" + data.get("val_set_size") == 0 + and (data.get("eval_steps") or data.get("eval_strategy")) + and not data.get("test_datasets") + and data.get("eval_strategy") != "no" ): raise ValueError( "eval_steps and eval_strategy are not supported with val_set_size == 0" @@ -670,9 +671,9 @@ def check_evals(cls, data): "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together." ) if ( - data.get("evals_per_epoch") and - data.get("eval_strategy") and - data.get("eval_strategy") != "steps" + data.get("evals_per_epoch") + and data.get("eval_strategy") + and data.get("eval_strategy") != "steps" ): raise ValueError( "eval_strategy must be empty or set to `steps` when used with evals_per_epoch." @@ -690,9 +691,9 @@ def check_evals(cls, data): @classmethod def check_test_datasets_bench(cls, data): if ( - data.get("do_bench_eval") and - not data.get("test_datasets") and - not data.get("val_set_size") + data.get("do_bench_eval") + and not data.get("test_datasets") + and not data.get("val_set_size") ): LOG.warning( "`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset." @@ -706,17 +707,17 @@ def check_eval_packing(cls, data): # TODO also should check test_datasets and val_set_size as we can skip # if there are no eval datasets/splits if ( - data.get("sample_packing") and - data.get("eval_table_size") and - data.get("eval_sample_packing") is not False + data.get("sample_packing") + and data.get("eval_table_size") + and data.get("eval_sample_packing") is not False ): raise ValueError( "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false." ) if ( - data.get("sample_packing") and - data.get("eval_sample_packing") is None and - not data.get("eval_table_size") + data.get("sample_packing") + and data.get("eval_sample_packing") is None + and not data.get("eval_table_size") ): LOG.info( "explicitly setting `eval_sample_packing` to match `sample_packing`" @@ -724,9 +725,9 @@ def check_eval_packing(cls, data): data["eval_sample_packing"] = True if ( - data.get("sample_packing") and - data.get("eval_sample_packing") is False and - data.get("remove_unused_columns") is None + data.get("sample_packing") + and data.get("eval_sample_packing") is False + and data.get("remove_unused_columns") is None ): LOG.info( "setting `remove_unused_columns: false` for when sample_packing and eval_sample_packing don't match" @@ -792,9 +793,9 @@ def check_simpo_warmup(self): @classmethod def check_frozen(cls, data): if ( - data.get("adapter") and - data.get("peft_layers_to_transform") and - data.get("unfrozen_parameters") + data.get("adapter") + and data.get("peft_layers_to_transform") + and data.get("unfrozen_parameters") ): raise ValueError( "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior." @@ -815,11 +816,11 @@ def check_peft_layers_pattern(cls, data): def check_fft_possible_bad_config(self): if ( # pylint: disable=too-many-boolean-expressions - not (self.bf16 or self.bfloat16) and - (self.fp16 or self.float16) and - not self.adapter and - not self.flash_attention and - self.sample_packing + not (self.bf16 or self.bfloat16) + and (self.fp16 or self.float16) + and not self.adapter + and not self.flash_attention + and self.sample_packing ): LOG.warning( "Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA." @@ -884,8 +885,8 @@ def check_relora(self): @classmethod def check_mem_mismatch(cls, data): if ( - data.get("max_memory") is not None and - data.get("gpu_memory_limit") is not None + data.get("max_memory") is not None + and data.get("gpu_memory_limit") is not None ): raise ValueError( "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together." @@ -896,9 +897,9 @@ def check_mem_mismatch(cls, data): @classmethod def check_use_reentrant_mismatch(cls, data): if ( - data.get("unfrozen_parameters") and - data.get("gradient_checkpointing_kwargs") and - data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") + data.get("unfrozen_parameters") + and data.get("gradient_checkpointing_kwargs") + and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is True ): # https://github.com/huggingface/transformers/issues/21381 @@ -911,12 +912,12 @@ def check_use_reentrant_mismatch(cls, data): @classmethod def warn_qlora_zero3_w_use_reentrant(cls, data): if ( - data.get("adapter") == "qlora" and - data.get("gradient_checkpointing_kwargs", {}) and - data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") - is False and - data.get("deepspeed", "") is not None and - "zero3" in data.get("deepspeed", "") + data.get("adapter") == "qlora" + and data.get("gradient_checkpointing_kwargs", {}) + and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") + is False + and data.get("deepspeed", "") is not None + and "zero3" in data.get("deepspeed", "") ): # may result in: # torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: @@ -940,8 +941,8 @@ def check_val_w_test_datasets(cls, data): @classmethod def check_eval_strategy(cls, data): if ( - data.get("evaluation_strategy") is not None and - data.get("eval_strategy") is None + data.get("evaluation_strategy") is not None + and data.get("eval_strategy") is None ): LOG.info( "explicitly setting `eval_strategy` from the `evaluation_strategy`" @@ -953,20 +954,20 @@ def check_eval_strategy(cls, data): @classmethod def check_fsdp_offload_w_8bit_optimizer(cls, data): if ( - data.get("fsdp") and - "8bit" in data.get("optimizer", "") and - data.get("fsdp_config") and - data["fsdp_config"].get("fsdp_offload_params") and - str(data["fsdp_config"].get("fsdp_version")) != "2" + data.get("fsdp") + and "8bit" in data.get("optimizer", "") + and data.get("fsdp_config") + and data["fsdp_config"].get("fsdp_offload_params") + and str(data["fsdp_config"].get("fsdp_version")) != "2" ): raise ValueError( f"FSDP Offload not compatible with {data.get('optimizer')}" ) if ( - data.get("fsdp") and - "8bit" in data.get("optimizer", "") and - data.get("fsdp_config") and - str(data["fsdp_config"].get("fsdp_version")) == "2" + data.get("fsdp") + and "8bit" in data.get("optimizer", "") + and data.get("fsdp_config") + and str(data["fsdp_config"].get("fsdp_version")) == "2" ): if data.get("optimizer", "") in ["adamw_8bit", "adamw_bnb_8bit"]: # CUDA ops errors with bnb 8bit optimizer + FSDP2 @@ -980,10 +981,10 @@ def check_fsdp_offload_w_8bit_optimizer(cls, data): @classmethod def check_fsdp_sharded_state_dict_w_safetensors(cls, data): if ( - data.get("fsdp") and - data.get("save_safetensors") and - data.get("fsdp_config") and - data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT" + data.get("fsdp") + and data.get("save_safetensors") + and data.get("fsdp_config") + and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT" ): raise ValueError( "FSDP SHARDED_STATE_DICT not compatible with save_safetensors" @@ -1030,9 +1031,9 @@ def check_xentropy_patch_conflicts(cls, data): @classmethod def check_qlora_unsloth(cls, data): if ( - data.get("unsloth_lora_mlp") or - data.get("unsloth_lora_qkv") or - data.get("unsloth_lora_o") + data.get("unsloth_lora_mlp") + or data.get("unsloth_lora_qkv") + or data.get("unsloth_lora_o") ): if data.get("adapter") == "lora" and data.get("load_in_8bit"): raise ValueError( @@ -1044,9 +1045,9 @@ def check_qlora_unsloth(cls, data): @classmethod def check_lora_8bit(cls, data): if ( - data.get("lora_mlp_kernel") or - data.get("lora_qkv_kernel") or - data.get("lora_o_kernel") + data.get("lora_mlp_kernel") + or data.get("lora_qkv_kernel") + or data.get("lora_o_kernel") ): if data.get("adapter") == "lora" and data.get("load_in_8bit"): raise ValueError( @@ -1121,14 +1122,14 @@ def check_rl_config_gradient_checkpointing(cls, data): # and use_reentrant = True is broken upstream in TRL # pylint: disable=too-many-boolean-expressions if ( - data.get("rl") and - data.get("gradient_checkpointing") and - data.get("gradient_checkpointing_kwargs") and - data.get("gradient_checkpointing_kwargs").get("use_reentrant") and - data.get("load_in_4bit") and - data.get("adapter") == "qlora" and - data.get("capabilities") and - data.get("capabilities").get("n_gpu", 1) > 1 + data.get("rl") + and data.get("gradient_checkpointing") + and data.get("gradient_checkpointing_kwargs") + and data.get("gradient_checkpointing_kwargs").get("use_reentrant") + and data.get("load_in_4bit") + and data.get("adapter") == "qlora" + and data.get("capabilities") + and data.get("capabilities").get("n_gpu", 1) > 1 ): raise ValueError( "The `use_reentrant: True` implementation of gradient checkpointing " @@ -1165,10 +1166,10 @@ def check_kto_config(cls, data): @classmethod def check_grpo_liger_sequence_parallel(cls, data): if ( - data.get("rl") == "grpo" and - data.get("trl", {}) and - data.get("trl").get("use_liger_loss") and - data.get("adapter") + data.get("rl") == "grpo" + and data.get("trl", {}) + and data.get("trl").get("use_liger_loss") + and data.get("adapter") ): raise ValueError("GRPO + SP + Liger not currently supported") return data @@ -1257,9 +1258,9 @@ def check_bf16(self): ) else: if ( - not self.merge_lora and - not self.is_preprocess and - (self.bf16 is True or self.bfloat16 is True) + not self.merge_lora + and not self.is_preprocess + and (self.bf16 is True or self.bfloat16 is True) ): raise ValueError( "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above." @@ -1270,14 +1271,14 @@ def check_bf16(self): @classmethod def check_sample_packing_w_sdpa_bf16(cls, data): is_sm_90: bool = ( - data["capabilities"] and - data["capabilities"].get("compute_capability") == "sm_90" + data["capabilities"] + and data["capabilities"].get("compute_capability") == "sm_90" ) if ( - data.get("sample_packing") and - data.get("sdp_attention") and - (data.get("bfloat16") or data.get("bf16")) and - not is_sm_90 + data.get("sample_packing") + and data.get("sdp_attention") + and (data.get("bfloat16") or data.get("bf16")) + and not is_sm_90 ): # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450 LOG.warning( @@ -1298,9 +1299,9 @@ def check_fsdp_deepspeed(cls, data): @classmethod def check_multigpu_unsloth(cls, data): if ( - data.get("unsloth_lora_mlp") or - data.get("unsloth_lora_qkv") or - data.get("unsloth_lora_o") + data.get("unsloth_lora_mlp") + or data.get("unsloth_lora_qkv") + or data.get("unsloth_lora_o") ): capabilities = data.get("capabilities") if capabilities and capabilities.get("n_gpu", 0) > 1: @@ -1313,15 +1314,15 @@ def check_multigpu_unsloth(cls, data): @classmethod def check_multigpu_lora_kernels(cls, data): if ( - data.get("lora_mlp_kernel") or - data.get("lora_qkv_kernel") or - data.get("lora_o_kernel") + data.get("lora_mlp_kernel") + or data.get("lora_qkv_kernel") + or data.get("lora_o_kernel") ): capabilities = data.get("capabilities") is_fsdp = data.get("fsdp") is not None is_fsdp2 = ( - data.get("fsdp_config") is not None and - str(data.get("fsdp_config").get("fsdp_version")) == "2" + data.get("fsdp_config") is not None + and str(data.get("fsdp_config").get("fsdp_version")) == "2" ) if capabilities and capabilities.get("n_gpu", 0) > 1 and not is_fsdp2: if is_fsdp: @@ -1342,10 +1343,10 @@ def check_auto_enable_lora_kernels(cls, data): unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"] kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"] if ( - any(data.get(k) is not None for k in kernel_fields) or - any(data.get(k) for k in unsloth_fields) or - data.get("adapter") == "lora" and - data.get("load_in_8bit") + any(data.get(k) is not None for k in kernel_fields) + or any(data.get(k) for k in unsloth_fields) + or data.get("adapter") == "lora" + and data.get("load_in_8bit") ): return data @@ -1358,14 +1359,14 @@ def check_auto_enable_lora_kernels(cls, data): is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1 is_fsdp = data.get("fsdp") is not None is_fsdp2 = ( - data.get("fsdp_config") is not None and - str(data.get("fsdp_config").get("fsdp_version")) == "2" + data.get("fsdp_config") is not None + and str(data.get("fsdp_config").get("fsdp_version")) == "2" ) if ( - not is_multi_gpu or - (is_multi_gpu and not is_fsdp) or - (is_multi_gpu and is_fsdp2) + not is_multi_gpu + or (is_multi_gpu and not is_fsdp) + or (is_multi_gpu and is_fsdp2) ): # Auto-enable kernels if not explicitly set by user if data.get("lora_mlp_kernel") is None: @@ -1378,9 +1379,9 @@ def check_auto_enable_lora_kernels(cls, data): data["lora_o_kernel"] = True LOG.warning( - "Auto-enabling LoRA kernel optimizations for faster training. " + - "Please explicitly set `lora_*_kernel` config values to `false` to disable. " + - "See https://docs.axolotl.ai/docs/lora_optims.html for more info." + "Auto-enabling LoRA kernel optimizations for faster training. " + + "Please explicitly set `lora_*_kernel` config values to `false` to disable. " + + "See https://docs.axolotl.ai/docs/lora_optims.html for more info." ) return data diff --git a/src/axolotl/utils/schemas/deprecated.py b/src/axolotl/utils/schemas/deprecated.py index 48b9fd5fce..b8904136e4 100644 --- a/src/axolotl/utils/schemas/deprecated.py +++ b/src/axolotl/utils/schemas/deprecated.py @@ -1,10 +1,11 @@ """Pydantic models for deprecated and remapped configuration parameters""" -from axolotl.utils.logging import get_logger from typing import Any from pydantic import BaseModel, Field, field_validator +from axolotl.utils.logging import get_logger + LOG = get_logger(__name__) diff --git a/src/axolotl/utils/schemas/integrations.py b/src/axolotl/utils/schemas/integrations.py index c505113fb8..4843e3592d 100644 --- a/src/axolotl/utils/schemas/integrations.py +++ b/src/axolotl/utils/schemas/integrations.py @@ -1,10 +1,11 @@ """Pydantic models for Axolotl integrations""" -from axolotl.utils.logging import get_logger from typing import Any from pydantic import BaseModel, Field, model_validator +from axolotl.utils.logging import get_logger + LOG = get_logger(__name__) diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py index b143803349..25a6ffd4f9 100644 --- a/src/axolotl/utils/schemas/model.py +++ b/src/axolotl/utils/schemas/model.py @@ -1,9 +1,9 @@ """Pydantic models for model input / output, etc. configuration""" -from axolotl.utils.logging import get_logger - from pydantic import BaseModel, Field, field_validator +from axolotl.utils.logging import get_logger + LOG = get_logger(__name__) diff --git a/src/axolotl/utils/schemas/training.py b/src/axolotl/utils/schemas/training.py index aa505738c9..ad7f899aca 100644 --- a/src/axolotl/utils/schemas/training.py +++ b/src/axolotl/utils/schemas/training.py @@ -1,12 +1,12 @@ """Pydantic models for training hyperparameters""" -from axolotl.utils.logging import get_logger from typing import Any, Literal from pydantic import BaseModel, Field, field_validator from transformers import SchedulerType from transformers.training_args import OptimizerNames +from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import CustomSupportedOptimizers LOG = get_logger(__name__) diff --git a/src/axolotl/utils/schemas/utils.py b/src/axolotl/utils/schemas/utils.py index f1e8d1f47b..b46c8f8475 100644 --- a/src/axolotl/utils/schemas/utils.py +++ b/src/axolotl/utils/schemas/utils.py @@ -40,8 +40,8 @@ def handle_legacy_message_fields_logic(data: dict) -> dict: f"Example: message_property_mappings: {{role: {data['message_field_role']}}}" ) if ( - "role" in data["message_property_mappings"] and - data["message_property_mappings"]["role"] != data["message_field_role"] + "role" in data["message_property_mappings"] + and data["message_property_mappings"]["role"] != data["message_field_role"] ): raise ValueError( f"Conflicting message role fields: message_field_role='{data['message_field_role']}' " @@ -60,9 +60,9 @@ def handle_legacy_message_fields_logic(data: dict) -> dict: f"Example: message_property_mappings: {{content: {data['message_field_content']}}}" ) if ( - "content" in data["message_property_mappings"] and - data["message_property_mappings"]["content"] != - data["message_field_content"] + "content" in data["message_property_mappings"] + and data["message_property_mappings"]["content"] + != data["message_field_content"] ): raise ValueError( f"Conflicting message content fields: message_field_content='{data['message_field_content']}' " diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py index 9f53f63856..3526bd5b58 100644 --- a/src/axolotl/utils/tokenization.py +++ b/src/axolotl/utils/tokenization.py @@ -1,9 +1,9 @@ """Module for tokenization utilities""" -from axolotl.utils.logging import get_logger - from termcolor import colored +from axolotl.utils.logging import get_logger + LOG = get_logger(__name__) diff --git a/tests/prompt_strategies/messages/test_chat.py b/tests/prompt_strategies/messages/test_chat.py index 5d7a4f18b6..a4c2ae67fd 100644 --- a/tests/prompt_strategies/messages/test_chat.py +++ b/tests/prompt_strategies/messages/test_chat.py @@ -1,23 +1,15 @@ -"""Module for testing chat message internals.""" +""" +tests for chat_template prompt strategy +""" -import os +# pylint: disable=duplicate-code import unittest -from transformers import AutoTokenizer - -from axolotl.core.chat.messages import ( - ChatFormattedChats, - Chats, - MessageContents, - MessageContentTypes, - MessageRoles, - Messages, -) from axolotl.prompt_strategies.messages.chat import load from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger -LOG = get_logger("axolotl") +LOG = get_logger(__name__, log_level="DEBUG") class TestMessagesChatLlama3: diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index 667320b5fa..371ccf6161 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -12,7 +12,6 @@ from axolotl.prompters import IGNORE_TOKEN_ID from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.dict import DictDefault - from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/tests/prompt_strategies/test_chat_templates_thinking.py b/tests/prompt_strategies/test_chat_templates_thinking.py index 51495bdb1f..21d8c4d5ea 100644 --- a/tests/prompt_strategies/test_chat_templates_thinking.py +++ b/tests/prompt_strategies/test_chat_templates_thinking.py @@ -10,9 +10,9 @@ load, ) from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger from tests.hf_offline_utils import enable_hf_offline -from axolotl.utils.logging import get_logger LOG = get_logger(__name__) diff --git a/tests/prompt_strategies/test_jinja_template_analyzer.py b/tests/prompt_strategies/test_jinja_template_analyzer.py index 497e1f390a..41b9a0203a 100644 --- a/tests/prompt_strategies/test_jinja_template_analyzer.py +++ b/tests/prompt_strategies/test_jinja_template_analyzer.py @@ -1,16 +1,13 @@ -"""Module for testing jinja template analyzer.""" - -import os +""" +tests for jinja_template_analyzer +""" import pytest -from axolotl.prompt_strategies.jinja_template_analyzer import ( - PromptComponentStatus, - PromptTemplateAnalyzer, -) +from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer from axolotl.utils.logging import get_logger -LOG = get_logger("axolotl") +LOG = get_logger(__name__, log_level="DEBUG") class TestJinjaTemplateAnalyzer: @@ -81,7 +78,7 @@ def test_nested_property_access(self): LOG.info("Testing nested property access") template = """{{ user.profile.name }}{{ user.settings['preference'] }}""" - analyzer = PromptTemplateAnalyzer(template) + analyzer = JinjaTemplateAnalyzer(template) variables = analyzer.get_template_variables() assert "user" in variables @@ -100,7 +97,7 @@ def test_loop_variable_handling(self): {% endfor %} {% endfor %} """ - analyzer = PromptTemplateAnalyzer(template) + analyzer = JinjaTemplateAnalyzer(template) analysis = analyzer.analyze_template() assert analysis["items"]["is_iterated"] @@ -116,7 +113,7 @@ def test_conditional_variable_usage(self): {{ debug_info }} {% endif %} """ - analyzer = PromptTemplateAnalyzer(template) + analyzer = JinjaTemplateAnalyzer(template) analysis = analyzer.analyze_template() assert analysis["user"]["is_conditional"] @@ -133,7 +130,7 @@ def test_complex_expressions(self): {{ messages | length > 0 and messages[0].content }} {{ data['key'].nested['value'] }} """ - analyzer = PromptTemplateAnalyzer(template) + analyzer = JinjaTemplateAnalyzer(template) variables = analyzer.get_template_variables() assert "user" in variables diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index be2d734d09..d34b774b36 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -1,21 +1,26 @@ -"""Testing for prompt_tokenizers.py""" +"""Module for testing prompt tokenizers.""" -import unittest +import json +from pathlib import Path -import pytest -from transformers import AutoTokenizer - -from axolotl.prompt_strategies.alpaca import AlpacaPrompter -from axolotl.prompt_tokenizers import ( - AlpacaPromptTokenizingStrategy, - InstructionPromptTokenizingStrategy, - PromptTokenizingStrategy, - ShareGPTPromptTokenizingStrategy, +from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter +from axolotl.prompt_strategies.alpaca_w_system import ( + InstructionWSystemPromptTokenizingStrategy, + SystemDataPrompter, +) +from axolotl.prompt_strategies.llama2_chat import ( + Llama2ChatPrompter, + LLama2ChatTokenizingStrategy, ) -from axolotl.prompters import AlpacaInstructionPrompter, PromptStyle, ShareGPTPrompter +from axolotl.prompt_strategies.orpo.chat_template import load +from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy +from axolotl.prompters import AlpacaPrompter, PromptStyle +from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger -LOG = get_logger("axolotl") +from tests.hf_offline_utils import enable_hf_offline + +LOG = get_logger(__name__) test_data = { "multi_turn_sys": { @@ -56,7 +61,7 @@ class TestPromptTokenizationStrategies: Test class for prompt tokenization strategies. """ - @pytest.mark.enable_hf_offline + @enable_hf_offline def test_no_sys_prompt(self, tokenizer_huggyllama_w_special_tokens): """ tests the interface between the user and assistant parts @@ -78,7 +83,7 @@ def test_no_sys_prompt(self, tokenizer_huggyllama_w_special_tokens): assert example["labels"][world_idx] == 3186 assert example["labels"][world_idx - 1] == -100 - @pytest.mark.enable_hf_offline + @enable_hf_offline def test_alpaca(self, tokenizer_huggyllama_w_special_tokens): """ tests the interface between the user and assistant parts @@ -103,7 +108,7 @@ class TestInstructionWSystemPromptTokenizingStrategy: Test class for prompt tokenization strategies with sys prompt from the dataset """ - @pytest.mark.enable_hf_offline + @enable_hf_offline def test_system_alpaca(self, tokenizer_huggyllama_w_special_tokens): prompter = SystemDataPrompter(PromptStyle.CHAT.value) strat = InstructionWSystemPromptTokenizingStrategy( @@ -134,7 +139,7 @@ class Llama2ChatTokenizationTest: Test class for prompt tokenization strategies with sys prompt from the dataset """ - @pytest.mark.enable_hf_offline + @enable_hf_offline def test_llama2_chat_integration(self, tokenizer_llama2_7b): with open( Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8" @@ -208,7 +213,7 @@ def compare_with_transformers_integration(self, tokenizer_llama2_7b): class OrpoTokenizationTest: """test case for the ORPO tokenization""" - @pytest.mark.enable_hf_offline + @enable_hf_offline def test_orpo_integration( self, tokenizer_mistral_7b_instruct_chatml, diff --git a/update_logging.py b/update_logging.py deleted file mode 100644 index 461ab9a849..0000000000 --- a/update_logging.py +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env python -""" -Script to update all test files to use the standardized logging approach. -""" - -import os -import re -import sys -from pathlib import Path - - -def update_file(file_path, dry_run=False): - """Update a file to use the standardized logging approach.""" - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - - # Keep track of changes - changes_made = False - - # Replace the import if it's a standalone import - import_pattern = r"import\s+logging\s*\n" - if re.search(import_pattern, content): - new_content = re.sub( - import_pattern, "from axolotl.utils.logging import get_logger\n", content - ) - changes_made = new_content != content - content = new_content - - # Replace the logger initialization - logger_pattern = r'LOG\s*=\s*logging\.getLogger\([\'"]([^\'"]+)[\'"]\)' - if re.search(logger_pattern, content): - new_content = re.sub(logger_pattern, r'LOG = get_logger("\1")', content) - changes_made = changes_made or (new_content != content) - content = new_content - - # Remove logging.basicConfig if present - basicconfig_pattern = r"logging\.basicConfig\([^\)]+\)\s*\n" - if re.search(basicconfig_pattern, content): - new_content = re.sub(basicconfig_pattern, "", content) - changes_made = changes_made or (new_content != content) - content = new_content - - if changes_made and not dry_run: - with open(file_path, "w", encoding="utf-8") as f: - f.write(content) - - return changes_made - - -def find_and_update_files(base_dir, dry_run=False): - """Find and update all test files that use logging.""" - updated_files = [] - skipped_files = [] - - for root, _, files in os.walk(base_dir): - for file in files: - if file.endswith(".py"): - file_path = os.path.join(root, file) - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - - if ( - "import logging" in content - or "logging.getLogger" in content - or "logging.basicConfig" in content - ): - if "from axolotl.utils.logging import get_logger" in content: - if "import logging" in content: - # Both imports present, probably needs manual inspection - skipped_files.append(file_path) - else: - # Already using the standardized logger - pass - else: - if update_file(file_path, dry_run): - updated_files.append(file_path) - - return updated_files, skipped_files - - -if __name__ == "__main__": - dry_run = "--dry-run" in sys.argv - if dry_run: - sys.argv.remove("--dry-run") - - if len(sys.argv) > 1: - base_dir = sys.argv[1] - else: - base_dir = "tests" - - updated_files, skipped_files = find_and_update_files(base_dir, dry_run) - - if dry_run: - print(f"DRY RUN: Would update {len(updated_files)} files:") - else: - print(f"Updated {len(updated_files)} files:") - - for file in updated_files: - rel_path = os.path.relpath(file, os.getcwd()) - print(f" - {rel_path}") - - if skipped_files: - print(f"\nSkipped {len(skipped_files)} files (need manual inspection):") - for file in skipped_files: - rel_path = os.path.relpath(file, os.getcwd()) - print(f" - {rel_path}") From f734d51a700f6cad93da76b87263b8fc418d3278 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 22 May 2025 12:11:39 +0100 Subject: [PATCH 12/29] linting --- .../prompt_strategies/chat_template.py | 5 +-- src/axolotl/utils/models.py | 5 +-- src/axolotl/utils/samplers/multipack.py | 43 ++++++++++++++----- 3 files changed, 36 insertions(+), 17 deletions(-) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 59f1c759b3..7b804084c0 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -543,9 +543,8 @@ def find_turn(self, turns: list[dict], turn_idx: int): and turns[0].get("role") == "system" and ( "mistral" in self.tokenizer.name_or_path.lower() - or - # gemma3 uses gemma tokenizer - "gemma" in self.tokenizer.name_or_path.lower() + or "gemma" + in self.tokenizer.name_or_path.lower() # gemma3 uses gemma tokenizer ) ): return -1, -1 diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index c545b8bcb7..4c44a3c9e5 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1377,13 +1377,12 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: should_convert = ( # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to # convert them back to fp16/bf16 for flash-attn compatibility. + # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass ( (needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention) and not qlora_fsdp ) - or - # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass - self.cfg.cut_cross_entropy + or self.cfg.cut_cross_entropy ) if should_convert: diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index feccb19806..222dfacb21 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -3,7 +3,6 @@ into fixed-capacity batches to optimize memory usage and training throughput. """ -import logging import math from concurrent.futures import ProcessPoolExecutor from multiprocessing import cpu_count, get_context @@ -17,7 +16,6 @@ from axolotl.utils.logging import get_logger LOG = get_logger(__name__) -LOG.setLevel(logging.INFO) @numba.njit @@ -79,15 +77,11 @@ def pack_group( Returns: List of bins, where each bin contains indices of sequences assigned to it """ - # Get sorting indices and sort lengths in descending order - indices = np.argsort(sequence_lengths)[::-1] - sorted_lengths = sequence_lengths[indices] - bins_remaining_space: list = [] # Tracks remaining capacity in each bin bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin - for seq_id, size in enumerate(sorted_lengths): - global_idx = indices[seq_id] + group_offset + for seq_id, size in enumerate(sequence_lengths): + global_idx = seq_id + group_offset # Try to place sequence in existing bins add_new_bin = True @@ -131,6 +125,7 @@ def pack_parallel( bin_size: int, num_processes: int | None = None, safe_mode: bool = True, + mp_start_method: str | None = "spawn", ): """ Pack sequences into bins using parallel processing @@ -142,7 +137,9 @@ def pack_parallel( bin_size: Maximum number of bins to use num_processes: Number of parallel processes to use safe_mode: If True, use a more conservative packing approach - + mp_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver'). + 'spawn' is often safer with Numba/PyTorch. + Set to None to use system default. Returns: List of bins, where each bin contains indices of sequences assigned to it """ @@ -159,9 +156,33 @@ def pack_parallel( # Process groups in parallel all_bins = [] - with ProcessPoolExecutor(max_workers=num_processes) as executor: - for group_bins in executor.map(_process_group, tasks): + + mp_ctx = None + if mp_start_method: + try: + mp_ctx = get_context(mp_start_method) + except ValueError: + LOG.warning( + f"Failed to get multiprocessing context '{mp_start_method}'. " + f"Falling back to default. Available: {get_context().get_all_start_methods()}" + ) + mp_ctx = ( + None # Fallback to default context if specified one is not available + ) + + if num_processes == 1: + LOG.debug("Using single process for pack_parallel, running sequentially.") + for task_args in tasks: + group_bins = _process_group(task_args) all_bins.extend(group_bins) + else: + # Use ProcessPoolExecutor only if num_processes > 1 + # Pass mp_context if available + with ProcessPoolExecutor( + max_workers=num_processes, mp_context=mp_ctx + ) as executor: + for group_bins in executor.map(_process_group, tasks): + all_bins.extend(group_bins) return all_bins From ff0857c3e93b9e48a77ea98c89c406ab8b395784 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 22 May 2025 12:23:10 +0100 Subject: [PATCH 13/29] linting --- src/axolotl/common/datasets.py | 15 +++++++++------ src/axolotl/monkeypatch/lora_kernels.py | 3 ++- src/axolotl/utils/logging.py | 15 ++++++++++++++- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index e7929831e1..ac10b22a25 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -80,12 +80,15 @@ def load_datasets( preprocess_iterable=preprocess_iterable, ) - if cli_args and ( - cli_args.debug - or cfg.debug - or cli_args.debug_text_only - or int(cli_args.debug_num_examples) > 0 - ): + if ( # pylint: disable=too-many-boolean-expressions + cli_args + and ( + cli_args.debug + or cfg.debug + or cli_args.debug_text_only + or int(cli_args.debug_num_examples) > 0 + ) + ) or debug: LOG.info("check_dataset_labels...") num_examples = cli_args.debug_num_examples if cli_args else 1 diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py index 34797a759f..11e0989cf5 100644 --- a/src/axolotl/monkeypatch/lora_kernels.py +++ b/src/axolotl/monkeypatch/lora_kernels.py @@ -2,6 +2,7 @@ import importlib import inspect +import logging import types from typing import Generator, Tuple, Type @@ -317,7 +318,7 @@ def apply_lora_kernel_patches( # This needs to be reset after patching original_level = LOG.getEffectiveLevel() - LOG.setLevel("INFO") + LOG.setLevel(logging.INFO) # Choose activation based on model type activation = None diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index ab004ddc31..9bae7c4401 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -2,14 +2,16 @@ logging helpers to only log on main process """ +import functools import logging import os from axolotl.utils.distributed import is_main_process - # Adapted from Accelerate # https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py + + class MultiProcessAdapter(logging.LoggerAdapter): """ logger adapter for distributed logging, specifically to only log on main process @@ -30,6 +32,17 @@ def log(self, level, msg, *args, **kwargs): msg, kwargs = self.process(msg, kwargs) self.logger.log(level, msg, *args, **kwargs) + @functools.lru_cache(None) + def warning_once(self, *args, **kwargs): + """ + This method is identical to `logger.warning()`, but will emit the warning with the same message only once + + Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the + cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to + switch to another type of cache that includes the caller frame information in the hashing function. + """ + self.warning(*args, **kwargs) + def get_logger(name: str, log_level: str | None = None): if log_level is None: From f01174259708b7528fa1727b50dbc69a6189b9f7 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 22 May 2025 12:35:33 +0100 Subject: [PATCH 14/29] linting --- src/axolotl/core/trainers/grpo/__init__.py | 6 ++---- src/axolotl/utils/logging.py | 2 +- src/axolotl/utils/schemas/config.py | 7 +++---- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index f7b5004798..196cdb56a5 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -44,10 +44,8 @@ def set_training_args_kwargs(cls, cfg: DictDefault) -> dict[str, Any]: if trl.use_vllm: grpo_args_kwargs["use_vllm"] = trl.use_vllm - # type: ignore[attr-defined] - grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host - # type: ignore[attr-defined] - grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port + grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host # type: ignore[attr-defined] + grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port # type: ignore[attr-defined] if trl.vllm_server_timeout: grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout if trl.vllm_guided_decoding_regex: diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index 9bae7c4401..16ba857ce3 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -32,7 +32,7 @@ def log(self, level, msg, *args, **kwargs): msg, kwargs = self.process(msg, kwargs) self.logger.log(level, msg, *args, **kwargs) - @functools.lru_cache(None) + @functools.lru_cache(maxsize=10) def warning_once(self, *args, **kwargs): """ This method is identical to `logger.warning()`, but will emit the warning with the same message only once diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 7ed5b45316..9e1fb48f32 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -126,9 +126,7 @@ class AxolotlInputConfig( default=None, json_schema_extra={"description": "streaming dataset to use for pretraining"}, ) - dataset_processes: int | None = Field( - default=min(32, os.cpu_count()) - ) # type: ignore[type-var] + dataset_processes: int | None = Field(default=min(32, os.cpu_count())) # type: ignore[type-var] dataset_exact_deduplication: bool | None = None dataset_keep_in_memory: bool | None = None dataloader_pin_memory: bool | None = None @@ -464,6 +462,7 @@ def check_sample_packing_wo_flash(cls, data): and not data.get("flash_attention") and not data.get("sdp_attention") and not data.get("flex_attention") + and not data.get("xformers_attention") ): LOG.warning( "sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination." @@ -1169,7 +1168,7 @@ def check_grpo_liger_sequence_parallel(cls, data): data.get("rl") == "grpo" and data.get("trl", {}) and data.get("trl").get("use_liger_loss") - and data.get("adapter") + and data.get("sequence_parallel_degree", 1) > 1 ): raise ValueError("GRPO + SP + Liger not currently supported") return data From 3dbe0f4ca682f3c6cd47b8fa2f84960426a6757a Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 22 May 2025 12:59:48 +0100 Subject: [PATCH 15/29] linting --- src/axolotl/utils/logging.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index 16ba857ce3..91b53d9dc3 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -27,10 +27,9 @@ def log(self, level, msg, *args, **kwargs): main_process_only = kwargs.pop("main_process_only", True) kwargs.setdefault("stacklevel", 2) - if self.isEnabledFor(level): - if self._should_log(main_process_only): - msg, kwargs = self.process(msg, kwargs) - self.logger.log(level, msg, *args, **kwargs) + if self.isEnabledFor(level) and self._should_log(main_process_only): + msg, kwargs = self.process(msg, kwargs) + self.logger.log(level, msg, *args, **kwargs) @functools.lru_cache(maxsize=10) def warning_once(self, *args, **kwargs): From faa1e032ddc9e02c5e57d57242a1725ebcaf4d5d Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 22 May 2025 14:20:55 +0100 Subject: [PATCH 16/29] linting --- src/axolotl/integrations/cut_cross_entropy/__init__.py | 9 ++++----- src/axolotl/integrations/liger/__init__.py | 9 ++++----- src/axolotl/utils/logging.py | 9 ++++++--- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index a05426044a..43c156ad2b 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -24,7 +24,6 @@ from axolotl.integrations.base import BasePlugin from axolotl.utils import get_pytorch_version -from axolotl.utils.distributed import is_main_process from axolotl.utils.logging import get_logger from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401 @@ -76,10 +75,10 @@ def pre_model_load(self, cfg): cce_patch, ) - if is_main_process(use_environ=True): - LOG.info( - f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}" - ) + LOG.info( + f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}", + use_environ=True, + ) # The patch checks model_type internally cce_patch(cfg.model_config_type) diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index a5c4588693..bf477cbee4 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -22,7 +22,6 @@ import sys from axolotl.integrations.base import BasePlugin -from axolotl.utils.distributed import is_main_process from axolotl.utils.logging import get_logger from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 @@ -85,10 +84,10 @@ def pre_model_load(self, cfg): kwargs["geglu"] = cfg.liger_glu_activation elif "swiglu" in liger_fn_sig.parameters: kwargs["swiglu"] = cfg.liger_glu_activation - if is_main_process(use_environ=True): - LOG.info( - f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}" - ) + LOG.info( + f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}", + use_environ=True, + ) apply_liger_fn(**kwargs) elif cfg.model_config_type == "jamba": from transformers.models.jamba import modeling_jamba diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index 91b53d9dc3..97cc861404 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -18,16 +18,19 @@ class MultiProcessAdapter(logging.LoggerAdapter): """ @staticmethod - def _should_log(main_process_only): + def _should_log(main_process_only, use_environ): return not main_process_only or ( - main_process_only and is_main_process(use_environ=False) + main_process_only and is_main_process(use_environ=use_environ) ) def log(self, level, msg, *args, **kwargs): + use_environ = kwargs.pop("use_environ", False) main_process_only = kwargs.pop("main_process_only", True) kwargs.setdefault("stacklevel", 2) - if self.isEnabledFor(level) and self._should_log(main_process_only): + if self.isEnabledFor(level) and self._should_log( + main_process_only, use_environ + ): msg, kwargs = self.process(msg, kwargs) self.logger.log(level, msg, *args, **kwargs) From 598fcf65b0aa86300b37983b1fab946c121387b7 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 22 May 2025 16:16:19 +0100 Subject: [PATCH 17/29] fixing logging --- src/axolotl/cli/config.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 9718d52541..579775c1c3 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -67,7 +67,8 @@ def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: # Log a warning but do not raise an error; JSON is technically valid YAML. # This can happen when you forget to point to a raw GitHub link. LOG.warning( - f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended." + f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended.", + use_environ=True, ) except json.JSONDecodeError: # If it's not valid JSON, verify it's valid YAML @@ -75,7 +76,8 @@ def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: yaml.safe_load(content) except yaml.YAMLError as err: raise ValueError( - f"Failed to parse the content at {config} as YAML: {err}" + f"Failed to parse the content at {config} as YAML: {err}", + use_environ=True, ) from err # Write the content to a file if it's valid YAML (or JSON treated as YAML) @@ -83,7 +85,8 @@ def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: with open(output_path, "wb") as file: file.write(content) LOG.info( - f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n" + f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n", + use_environ=True, ) return output_path @@ -119,12 +122,12 @@ def choose_config(path: Path) -> str: ) if len(yaml_files) == 1: - print(f"Using default YAML file '{yaml_files[0]}'") + LOG.info(f"Using default YAML file '{yaml_files[0]}'", use_environ=True) return str(yaml_files[0]) - print("Choose a YAML file:") + LOG.info("Choose a YAML file:") for idx, file in enumerate(yaml_files): - print(f"{idx + 1}. {file}") + LOG.info(f"{idx + 1}. {file}", use_environ=True) chosen_file = None while chosen_file is None: @@ -133,9 +136,12 @@ def choose_config(path: Path) -> str: if 1 <= choice <= len(yaml_files): chosen_file = str(yaml_files[choice - 1]) else: - print("Invalid choice. Please choose a number from the list.") + LOG.info( + "Invalid choice. Please choose a number from the list.", + use_environ=True, + ) except ValueError: - print("Invalid input. Please enter a number.") + LOG.info("Invalid input. Please enter a number.", use_environ=True) return chosen_file From 9dcb21df912eb957035fc567ea912cc22abd264d Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 22 May 2025 17:03:18 +0100 Subject: [PATCH 18/29] fixing logging --- src/axolotl/cli/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 579775c1c3..c9789a419d 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -125,7 +125,7 @@ def choose_config(path: Path) -> str: LOG.info(f"Using default YAML file '{yaml_files[0]}'", use_environ=True) return str(yaml_files[0]) - LOG.info("Choose a YAML file:") + LOG.info("Choose a YAML file:", use_environ=True) for idx, file in enumerate(yaml_files): LOG.info(f"{idx + 1}. {file}", use_environ=True) From a4f00f2f706c8195040be1f21f10fb0c2f4ae4ba Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 22 May 2025 17:06:02 +0100 Subject: [PATCH 19/29] debugging --- src/axolotl/utils/logging.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index 97cc861404..dd779ab527 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -32,6 +32,7 @@ def log(self, level, msg, *args, **kwargs): main_process_only, use_environ ): msg, kwargs = self.process(msg, kwargs) + self.logger.log(level, "BOOOO", *args, **kwargs) self.logger.log(level, msg, *args, **kwargs) @functools.lru_cache(maxsize=10) From e37ff32cc3ca2566ccc3c175beeba6e93b968676 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 22 May 2025 17:08:40 +0100 Subject: [PATCH 20/29] debugging --- src/axolotl/utils/schemas/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py index 25a6ffd4f9..6b52cbebde 100644 --- a/src/axolotl/utils/schemas/model.py +++ b/src/axolotl/utils/schemas/model.py @@ -31,7 +31,8 @@ class ModelInputConfig(BaseModel): def hint_trust_remote_code(cls, trust_remote_code): if trust_remote_code: LOG.warning( - "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model." + "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model.", + use_environ=True, ) return trust_remote_code From 7ce9baab61bb2a754010b8a48206f8b7dd9b508d Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 22 May 2025 17:10:15 +0100 Subject: [PATCH 21/29] debugging --- src/axolotl/utils/logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index dd779ab527..08cf1d6427 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -32,7 +32,7 @@ def log(self, level, msg, *args, **kwargs): main_process_only, use_environ ): msg, kwargs = self.process(msg, kwargs) - self.logger.log(level, "BOOOO", *args, **kwargs) + self.logger.log(level, f"use_environ: {use_environ}", *args, **kwargs) self.logger.log(level, msg, *args, **kwargs) @functools.lru_cache(maxsize=10) From 7f12bde31f3a8dbc2e465756b7a4629e9d0ae884 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 22 May 2025 17:11:32 +0100 Subject: [PATCH 22/29] debugging --- src/axolotl/utils/distributed.py | 5 ++--- src/axolotl/utils/logging.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py index adf6fa33e1..8c52102c89 100644 --- a/src/axolotl/utils/distributed.py +++ b/src/axolotl/utils/distributed.py @@ -80,11 +80,10 @@ def is_main_process(use_environ=False): Returns: - bool: `True` if the current process is the main process, `False` otherwise. """ - if not is_distributed(): - return True if use_environ: return os.environ.get("LOCAL_RANK", "0") == "0" - + if not is_distributed(): + return True return dist.get_rank() == 0 diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index 08cf1d6427..93e7e56a97 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -18,7 +18,7 @@ class MultiProcessAdapter(logging.LoggerAdapter): """ @staticmethod - def _should_log(main_process_only, use_environ): + def _should_log(main_process_only, use_environ=False): return not main_process_only or ( main_process_only and is_main_process(use_environ=use_environ) ) @@ -29,7 +29,7 @@ def log(self, level, msg, *args, **kwargs): kwargs.setdefault("stacklevel", 2) if self.isEnabledFor(level) and self._should_log( - main_process_only, use_environ + main_process_only, use_environ=use_environ ): msg, kwargs = self.process(msg, kwargs) self.logger.log(level, f"use_environ: {use_environ}", *args, **kwargs) From 42921f6c118e9df2e66e375db5c0e6952c2a6199 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Thu, 22 May 2025 17:17:28 +0100 Subject: [PATCH 23/29] fixed logging --- src/axolotl/cli/config.py | 3 +-- src/axolotl/utils/logging.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index c9789a419d..1698aeaebd 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -76,8 +76,7 @@ def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: yaml.safe_load(content) except yaml.YAMLError as err: raise ValueError( - f"Failed to parse the content at {config} as YAML: {err}", - use_environ=True, + f"Failed to parse the content at {config} as YAML: {err}" ) from err # Write the content to a file if it's valid YAML (or JSON treated as YAML) diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index 93e7e56a97..94703354d2 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -32,7 +32,6 @@ def log(self, level, msg, *args, **kwargs): main_process_only, use_environ=use_environ ): msg, kwargs = self.process(msg, kwargs) - self.logger.log(level, f"use_environ: {use_environ}", *args, **kwargs) self.logger.log(level, msg, *args, **kwargs) @functools.lru_cache(maxsize=10) From 9c8403dd530b5eacd55d0094f47f909920512176 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Fri, 23 May 2025 13:20:07 +0100 Subject: [PATCH 24/29] configuring use_environ with get_logger --- src/axolotl/cli/config.py | 13 +++++-------- .../integrations/cut_cross_entropy/__init__.py | 3 +-- src/axolotl/integrations/liger/__init__.py | 3 +-- src/axolotl/utils/logging.py | 10 +++++++--- src/axolotl/utils/schemas/config.py | 12 ++++++------ src/axolotl/utils/schemas/model.py | 3 +-- 6 files changed, 21 insertions(+), 23 deletions(-) diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 1698aeaebd..29bbba2b6a 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -26,7 +26,7 @@ from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env from axolotl.utils.wandb_ import setup_wandb_env_vars -LOG = get_logger(__name__) +LOG = get_logger(__name__, use_environ=True) def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: @@ -68,7 +68,6 @@ def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: # This can happen when you forget to point to a raw GitHub link. LOG.warning( f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended.", - use_environ=True, ) except json.JSONDecodeError: # If it's not valid JSON, verify it's valid YAML @@ -85,7 +84,6 @@ def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: file.write(content) LOG.info( f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n", - use_environ=True, ) return output_path @@ -121,12 +119,12 @@ def choose_config(path: Path) -> str: ) if len(yaml_files) == 1: - LOG.info(f"Using default YAML file '{yaml_files[0]}'", use_environ=True) + LOG.info(f"Using default YAML file '{yaml_files[0]}'") return str(yaml_files[0]) - LOG.info("Choose a YAML file:", use_environ=True) + LOG.info("Choose a YAML file:") for idx, file in enumerate(yaml_files): - LOG.info(f"{idx + 1}. {file}", use_environ=True) + LOG.info(f"{idx + 1}. {file}") chosen_file = None while chosen_file is None: @@ -137,10 +135,9 @@ def choose_config(path: Path) -> str: else: LOG.info( "Invalid choice. Please choose a number from the list.", - use_environ=True, ) except ValueError: - LOG.info("Invalid input. Please enter a number.", use_environ=True) + LOG.info("Invalid input. Please enter a number.") return chosen_file diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index 43c156ad2b..62e56f197c 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -28,7 +28,7 @@ from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401 -LOG = get_logger(__name__) +LOG = get_logger(__name__, use_environ=True) _CCE_INSTALL_MESSAGE = ( "Please install cut_cross_entropy with transformers support using " @@ -77,7 +77,6 @@ def pre_model_load(self, cfg): LOG.info( f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}", - use_environ=True, ) # The patch checks model_type internally diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index bf477cbee4..974b1de27a 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -27,7 +27,7 @@ from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 from .utils import patch_with_compile_disable -LOG = get_logger(__name__) +LOG = get_logger(__name__, use_environ=True) class LigerPlugin(BasePlugin): @@ -86,7 +86,6 @@ def pre_model_load(self, cfg): kwargs["swiglu"] = cfg.liger_glu_activation LOG.info( f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}", - use_environ=True, ) apply_liger_fn(**kwargs) elif cfg.model_config_type == "jamba": diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index 94703354d2..260cce83ac 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -17,6 +17,10 @@ class MultiProcessAdapter(logging.LoggerAdapter): logger adapter for distributed logging, specifically to only log on main process """ + def __init__(self, logger, use_environ=False, extra=None): + super().__init__(logger, extra) + self.use_environ = use_environ + @staticmethod def _should_log(main_process_only, use_environ=False): return not main_process_only or ( @@ -24,7 +28,7 @@ def _should_log(main_process_only, use_environ=False): ) def log(self, level, msg, *args, **kwargs): - use_environ = kwargs.pop("use_environ", False) + use_environ = kwargs.pop("use_environ", self.use_environ) main_process_only = kwargs.pop("main_process_only", True) kwargs.setdefault("stacklevel", 2) @@ -46,11 +50,11 @@ def warning_once(self, *args, **kwargs): self.warning(*args, **kwargs) -def get_logger(name: str, log_level: str | None = None): +def get_logger(name: str, log_level: str | None = None, use_environ: bool = False): if log_level is None: log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None) logger = logging.getLogger(name) if log_level is not None: logger.setLevel(log_level.upper()) logger.root.setLevel(log_level.upper()) - return MultiProcessAdapter(logger, {}) + return MultiProcessAdapter(logger, use_environ=use_environ, extra={}) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 9e1fb48f32..3952f149d3 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -48,7 +48,7 @@ from axolotl.utils.schemas.trl import TRLConfig from axolotl.utils.schemas.vllm import VllmConfig -LOG = get_logger(__name__) +LOG = get_logger(__name__, use_environ=True) SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} @@ -923,7 +923,7 @@ def warn_qlora_zero3_w_use_reentrant(cls, data): # Recomputed values for the following tensors have different metadata # than during the forward pass. LOG.warning( - "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values" + "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values", ) return data @@ -944,7 +944,7 @@ def check_eval_strategy(cls, data): and data.get("eval_strategy") is None ): LOG.info( - "explicitly setting `eval_strategy` from the `evaluation_strategy`" + "explicitly setting `eval_strategy` from the `evaluation_strategy`", ) data["eval_strategy"] = data.get("evaluation_strategy") return data @@ -1207,7 +1207,7 @@ def check_sequence_parallel_degree(self): "Please note that logged losses may differ slightly to the non-SP " "losses due to transformers Trainer implementation details. " "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " - "for more details." + "for more details.", ) return self @@ -1253,7 +1253,7 @@ def check_bf16(self): if self.capabilities.bf16: if not self.bf16 and not self.bfloat16: LOG.info( - "bf16 support detected, but not enabled for this configuration." + "bf16 support detected, but not enabled for this configuration.", ) else: if ( @@ -1282,7 +1282,7 @@ def check_sample_packing_w_sdpa_bf16(cls, data): # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450 LOG.warning( "sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. " - "This may work on H100s." + "This may work on H100s.", ) return data diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py index 6b52cbebde..0603fac076 100644 --- a/src/axolotl/utils/schemas/model.py +++ b/src/axolotl/utils/schemas/model.py @@ -4,7 +4,7 @@ from axolotl.utils.logging import get_logger -LOG = get_logger(__name__) +LOG = get_logger(__name__, use_environ=True) class ModelInputConfig(BaseModel): @@ -32,7 +32,6 @@ def hint_trust_remote_code(cls, trust_remote_code): if trust_remote_code: LOG.warning( "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model.", - use_environ=True, ) return trust_remote_code From 9dfd461a9d6c39a332497dabf0d3345fb90f2a6b Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Fri, 23 May 2025 17:31:27 +0100 Subject: [PATCH 25/29] comments --- examples/llama-3/qlora-1b-kto.yaml | 3 +- src/axolotl/cli/config.py | 4 +-- src/axolotl/monkeypatch/unsloth_.py | 8 ++--- src/axolotl/utils/data/sft.py | 44 +++++++------------------ src/axolotl/utils/logging.py | 4 ++- src/axolotl/utils/models.py | 51 ++++++++--------------------- src/axolotl/utils/schemas/config.py | 6 ++-- src/axolotl/utils/schemas/model.py | 2 +- 8 files changed, 36 insertions(+), 86 deletions(-) diff --git a/examples/llama-3/qlora-1b-kto.yaml b/examples/llama-3/qlora-1b-kto.yaml index aa52a62ef2..89a51ea68f 100644 --- a/examples/llama-3/qlora-1b-kto.yaml +++ b/examples/llama-3/qlora-1b-kto.yaml @@ -40,8 +40,7 @@ wandb_log_model: gradient_accumulation_steps: 1 micro_batch_size: 2 -# num_epochs: 1 -max_steps: 2 +num_epochs: 1 optimizer: adamw_8bit lr_scheduler: cosine learning_rate: 0.0002 diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 29bbba2b6a..421dca5dd5 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -133,9 +133,7 @@ def choose_config(path: Path) -> str: if 1 <= choice <= len(yaml_files): chosen_file = str(yaml_files[choice - 1]) else: - LOG.info( - "Invalid choice. Please choose a number from the list.", - ) + LOG.info("Invalid choice. Please choose a number from the list.") except ValueError: LOG.info("Invalid input. Please enter a number.") diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py index 3566be7253..bed780e7d6 100644 --- a/src/axolotl/monkeypatch/unsloth_.py +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -133,9 +133,7 @@ def patch_self_attn_lora(): ) exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 self_attn_lora_patched = True - LOG.info( - "patching unsloth attn lora", - ) + LOG.info("patching unsloth attn lora") LlamaFlashAttention2.forward = ( unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821 ) @@ -155,9 +153,7 @@ def apply_rotary_pos_emb( # pylint: disable=unused-argument ): return fast_rope_embedding(q, k, cos, sin) - LOG.info( - "patching unsloth RoPE embeddings", - ) + LOG.info("patching unsloth RoPE embeddings") transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index ebe9dba40c..8aa9654ecb 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -167,9 +167,7 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): ) if cfg.dataset_exact_deduplication: - LOG.info( - "Deduplication not available for pretrained datasets", - ) + LOG.info("Deduplication not available for pretrained datasets") return train_dataset, eval_dataset, cfg.max_steps, prompters @@ -264,36 +262,26 @@ def load_tokenized_prepared_datasets( f"Loading prepared dataset from disk at {prepared_ds_path}...", ) dataset = load_from_disk(str(prepared_ds_path)) - LOG.info( - "Prepared dataset loaded from disk...", - ) + LOG.info("Prepared dataset loaded from disk...") else: if cfg.push_dataset_to_hub: - LOG.info( - "Unable to find prepared dataset in Huggingface hub", - ) + LOG.info("Unable to find prepared dataset in Huggingface hub") if cfg.is_preprocess: LOG.info( - f"Skipping prepared dataset in {prepared_ds_path} for pre-processing...", + f"Skipping prepared dataset in {prepared_ds_path} for pre-processing..." ) else: - LOG.info( - f"Unable to find prepared dataset in {prepared_ds_path}", - ) - LOG.info( - "Loading raw datasets...", - ) + LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}") + LOG.info("Loading raw datasets...") if not cfg.is_preprocess: LOG.warning( - "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset.", + "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset." ) if cfg.seed: seed = cfg.seed else: - LOG.info( - "No seed provided, using default seed of 42", - ) + LOG.info("No seed provided, using default seed of 42") seed = 42 datasets = [] @@ -346,21 +334,15 @@ def load_tokenized_prepared_datasets( if len(datasets) == 1: dataset = datasets[0] else: - LOG.info( - "Merging datasets...", - ) + LOG.info("Merging datasets...") dataset = concatenate_datasets(datasets) if len(datasets) > 1: if cfg.shuffle_merged_datasets: - LOG.debug( - "Shuffling merged datasets...", - ) + LOG.debug("Shuffling merged datasets...") dataset = dataset.shuffle(seed=seed) else: - LOG.debug( - "NOT shuffling merged datasets", - ) + LOG.debug("NOT shuffling merged datasets") if not cfg.skip_prepare_dataset: dataset = drop_long_seq_in_dataset(dataset, cfg) @@ -369,9 +351,7 @@ def load_tokenized_prepared_datasets( dataset, _ = process_datasets_for_packing(cfg, dataset, None) if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: - LOG.info( - f"Saving merged prepared dataset to disk... {prepared_ds_path}", - ) + LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") if isinstance(dataset, IterableDataset): num_workers = cfg.dataset_processes diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index 260cce83ac..80daab4eaa 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -50,7 +50,9 @@ def warning_once(self, *args, **kwargs): self.warning(*args, **kwargs) -def get_logger(name: str, log_level: str | None = None, use_environ: bool = False): +def get_logger( + name: str, log_level: str | None = None, use_environ: bool = False +) -> MultiProcessAdapter: if log_level is None: log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None) logger = logging.getLogger(name) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index fc99371c50..bd345ab704 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -776,18 +776,14 @@ def patch_llama_derived_model(self): if self.cfg.sample_packing: if self.cfg.device not in ["mps", "cpu"] and not self.inference: - LOG.info( - "patching with flash attention for sample packing", - ) + LOG.info("patching with flash attention for sample packing") replace_llama_attn_with_flash_attn( packed=True, cross_entropy=self.cfg.flash_attn_cross_entropy, rms_norm=self.cfg.flash_attn_rms_norm, ) elif self.cfg.s2_attention: - LOG.info( - "patching w/ flash-enabled, shifted-sparse attention", - ) + LOG.info("patching w/ flash-enabled, shifted-sparse attention") replace_llama_attn_with_flash_attn( packed=False, cross_entropy=self.cfg.flash_attn_cross_entropy, @@ -805,18 +801,14 @@ def patch_llama_derived_model(self): hijack_llama_attention, ) - LOG.info( - "patching with xformers attention", - ) + LOG.info("patching with xformers attention") hijack_llama_attention() elif self.cfg.sample_packing: from axolotl.monkeypatch.llama_patch_multipack import ( hijack_llama_prepare_4d_mask, ) - LOG.info( - "patching llama _prepare_4d_causal_attention_mask*", - ) + LOG.info("patching llama _prepare_4d_causal_attention_mask*") hijack_llama_prepare_4d_mask() elif self.cfg.s2_attention: raise NotImplementedError( @@ -898,7 +890,7 @@ def set_quantization_config(self) -> None: if self.cfg.gptq: if not hasattr(self.model_config, "quantization_config"): LOG.warning( - "model config does not contain quantization_config information", + "model config does not contain quantization_config information" ) else: if self.cfg.gptq_disable_exllama is not None: @@ -1110,15 +1102,11 @@ def _configure_zero3_memory_efficient_loading(): ) if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): - LOG.info( - "patching with SwiGLU", - ) + LOG.info("patching with SwiGLU") replace_llama_mlp_with_swiglu(self.model) if self.cfg.flash_attn_fuse_qkv: - LOG.info( - "patching with fused QKV", - ) + LOG.info("patching with fused QKV") replace_llama_qkv_with_fused(self.model) elif self.model_type == "MambaLMHeadModel": # FIXME this is janky at best and hacked together to make it work @@ -1255,9 +1243,7 @@ def prepare_model(self, qlora_fsdp: bool) -> None: and self.cfg.adapter in ["lora", "qlora"] and (self.cfg.load_in_8bit or self.cfg.load_in_4bit) ): - LOG.info( - "converting PEFT model w/ prepare_model_for_kbit_training", - ) + LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") self.model = prepare_model_for_kbit_training( self.model, use_gradient_checkpointing=self.cfg.gradient_checkpointing ) @@ -1396,10 +1382,7 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: ) if should_convert: - LOG.info( - "Converting modules to %s", - self.cfg.torch_dtype, - ) + LOG.info(f"Converting modules to {self.cfg.torch_dtype}") self.convert_embedding_modules_dtype( embedding_modules=embedding_modules, dist_dtype=self.cfg.torch_dtype, @@ -1452,9 +1435,7 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]: if param.requires_grad: requires_grad.append(f"{name}: {param.requires_grad}") if len(requires_grad) == 0: - LOG.warning( - "there are no parameters that require gradient updates", - ) + LOG.warning("there are no parameters that require gradient updates") if self.cfg.flash_optimum: from optimum.bettertransformer import BetterTransformer @@ -1528,9 +1509,7 @@ def load_llama_adapter(model, cfg): ) if cfg.lora_model_dir: - LOG.debug( - "Loading pretrained PEFT - llama_adapter", - ) + LOG.debug("Loading pretrained PEFT - llama_adapter") model = PeftModel.from_pretrained( model, cfg.lora_model_dir, @@ -1616,9 +1595,7 @@ def load_lora(model, cfg, inference=False, config_only=False): lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights if cfg.peft_use_dora: lora_config_kwargs["use_dora"] = cfg.peft_use_dora - LOG.info( - "Initializing LoRA weights using dora. This might take longer.", - ) + LOG.info("Initializing LoRA weights using dora. This might take longer.") if cfg.peft_use_rslora: lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora if cfg.peft_layer_replication: @@ -1652,9 +1629,7 @@ def load_lora(model, cfg, inference=False, config_only=False): setup_quantized_meta_for_peft(model) if cfg.lora_model_dir: - LOG.debug( - "Loading pretrained PEFT - LoRA", - ) + LOG.debug("Loading pretrained PEFT - LoRA") model_kwargs: Any = {} if cfg.lora_on_cpu: model_kwargs["max_memory"] = {"cpu": "256GiB"} diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 3952f149d3..91337dcc68 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -923,7 +923,7 @@ def warn_qlora_zero3_w_use_reentrant(cls, data): # Recomputed values for the following tensors have different metadata # than during the forward pass. LOG.warning( - "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values", + "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values" ) return data @@ -944,7 +944,7 @@ def check_eval_strategy(cls, data): and data.get("eval_strategy") is None ): LOG.info( - "explicitly setting `eval_strategy` from the `evaluation_strategy`", + "explicitly setting `eval_strategy` from the `evaluation_strategy`" ) data["eval_strategy"] = data.get("evaluation_strategy") return data @@ -1253,7 +1253,7 @@ def check_bf16(self): if self.capabilities.bf16: if not self.bf16 and not self.bfloat16: LOG.info( - "bf16 support detected, but not enabled for this configuration.", + "bf16 support detected, but not enabled for this configuration." ) else: if ( diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py index 0603fac076..57f5ae309c 100644 --- a/src/axolotl/utils/schemas/model.py +++ b/src/axolotl/utils/schemas/model.py @@ -31,7 +31,7 @@ class ModelInputConfig(BaseModel): def hint_trust_remote_code(cls, trust_remote_code): if trust_remote_code: LOG.warning( - "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model.", + "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model." ) return trust_remote_code From dd76a13e4f888eb2fe90b244ef68e803b08abd74 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Tue, 27 May 2025 18:36:43 +0100 Subject: [PATCH 26/29] comments-fixing test --- src/axolotl/cli/config.py | 2 +- .../cut_cross_entropy/__init__.py | 2 +- src/axolotl/integrations/liger/__init__.py | 4 +-- src/axolotl/monkeypatch/unsloth_.py | 8 +++--- src/axolotl/train.py | 4 +-- src/axolotl/utils/config/__init__.py | 2 +- src/axolotl/utils/data/sft.py | 10 +++---- src/axolotl/utils/data/utils.py | 12 +++------ src/axolotl/utils/models.py | 26 +++++-------------- src/axolotl/utils/trainer.py | 22 +++++----------- 10 files changed, 27 insertions(+), 65 deletions(-) diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 421dca5dd5..58e7f06aa2 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -67,7 +67,7 @@ def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: # Log a warning but do not raise an error; JSON is technically valid YAML. # This can happen when you forget to point to a raw GitHub link. LOG.warning( - f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended.", + f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended." ) except json.JSONDecodeError: # If it's not valid JSON, verify it's valid YAML diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index 62e56f197c..a7e94e3637 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -76,7 +76,7 @@ def pre_model_load(self, cfg): ) LOG.info( - f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}", + f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}" ) # The patch checks model_type internally diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 974b1de27a..4e67c322dc 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -84,9 +84,7 @@ def pre_model_load(self, cfg): kwargs["geglu"] = cfg.liger_glu_activation elif "swiglu" in liger_fn_sig.parameters: kwargs["swiglu"] = cfg.liger_glu_activation - LOG.info( - f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}", - ) + LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}") apply_liger_fn(**kwargs) elif cfg.model_config_type == "jamba": from transformers.models.jamba import modeling_jamba diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py index bed780e7d6..61f4eeea03 100644 --- a/src/axolotl/monkeypatch/unsloth_.py +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -189,7 +189,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM): if is_mlp_lora and mlp_no_bias and mlp_not_dora: layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp) else: - LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx) + LOG.warning(f"unable to apply unsloth lora mlp patch to layer {idx}") def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): @@ -215,7 +215,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): layer.self_attn.apply_qkv = apply_lora_qkv else: layer.self_attn.apply_qkv = original_apply_qkv - LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx) + LOG.warning(f"unable to apply unsloth lora qkv patch to layer {idx}") if cfg.unsloth_lora_o: layer_modules = [ getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"] @@ -234,9 +234,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): layer.self_attn.apply_o = apply_lora_o else: layer.self_attn.apply_o = original_apply_o - LOG.warning( - "unable to apply unsloth lora o_proj patch to layer %d", idx - ) + LOG.warning(f"unable to apply unsloth lora o_proj patch to layer {idx}") def patch_unsloth_layernorm(): diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 81d7596c07..da08fe1a64 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -60,9 +60,7 @@ def setup_model_and_tokenizer( `None`), and processor (if multimodal, else `None`). """ # Load tokenizer - LOG.debug( - f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}", - ) + LOG.debug(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") tokenizer = load_tokenizer(cfg) # Load processor for multimodal models if needed diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index bde499fd4b..1cc10f842a 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -20,7 +20,7 @@ from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset -LOG = get_logger(__name__) +LOG = get_logger(__name__, use_environ=True) def choose_device(cfg): diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 8aa9654ecb..88c78174bc 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -184,9 +184,7 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): ) else: total_num_steps = calculate_total_num_steps(cfg, train_dataset) - LOG.info( - f"Maximum number of steps set at {total_num_steps}", - ) + LOG.info(f"Maximum number of steps set at {total_num_steps}") return train_dataset, eval_dataset, total_num_steps, prompters @@ -237,7 +235,7 @@ def load_tokenized_prepared_datasets( try: if cfg.push_dataset_to_hub: LOG.info( - f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})...", + f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..." ) dataset = load_dataset( cfg.push_dataset_to_hub, @@ -258,9 +256,7 @@ def load_tokenized_prepared_datasets( and not cfg.is_preprocess and not cfg.skip_prepare_dataset ): - LOG.info( - f"Loading prepared dataset from disk at {prepared_ds_path}...", - ) + LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") dataset = load_from_disk(str(prepared_ds_path)) LOG.info("Prepared dataset loaded from disk...") else: diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index 6202d336d6..b22f3bcbba 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -174,13 +174,9 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): try: ds_lengths = get_dataset_lengths(dataset, from_arrow=True) min_input_len = np.min(ds_lengths) - LOG.info( - f"min_input_len: {min_input_len}", - ) + LOG.info(f"min_input_len: {min_input_len}") max_input_len = np.max(ds_lengths) - LOG.info( - f"max_input_len: {max_input_len}", - ) + LOG.info(f"max_input_len: {max_input_len}") except AttributeError: pass @@ -208,8 +204,6 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): if prior_len: dropped = prior_len - len(dataset) if dropped: - LOG.warning( - f"Dropped {dropped} long samples from dataset", - ) + LOG.warning(f"Dropped {dropped} long samples from dataset") return dataset diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index bd345ab704..e206f2d675 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -139,9 +139,7 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig): and hasattr(model_config.vision_config, "image_size") ): cfg.image_size = model_config.vision_config.image_size - LOG.debug( - f"Loaded image size: {cfg.image_size} from model config", - ) + LOG.debug(f"Loaded image size: {cfg.image_size} from model config") quant_config_exists = ( hasattr(model_config, "quantization_config") @@ -459,15 +457,9 @@ def load_tokenizer(cfg): {"additional_special_tokens": additional_special_tokens} ) - LOG.debug( - f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}", - ) - LOG.debug( - f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}", - ) - LOG.debug( - f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}", - ) + LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") + LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") + LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") if cfg.chat_template: chat_template_string = get_chat_template_from_config( @@ -524,9 +516,7 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): elif im_height is not None: cfg.image_size = im_height - LOG.debug( - f"Loaded image size: {cfg.image_size} from processor", - ) + LOG.debug(f"Loaded image size: {cfg.image_size} from processor") return processor @@ -1179,7 +1169,7 @@ def adjust_model_config(self) -> None: and self.cfg.sequence_len > self.model.config.max_position_embeddings ): LOG.warning( - f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}", + f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}" ) self.model.config.max_position_embeddings = self.cfg.sequence_len @@ -1576,9 +1566,7 @@ def load_lora(model, cfg, inference=False, config_only=False): if cfg.lora_target_linear: linear_names = find_all_linear_names(model) - LOG.info( - f"found linear modules: {repr(sorted(linear_names))}", - ) + LOG.info(f"found linear modules: {repr(sorted(linear_names))}") lora_target_modules_as_list = ( lora_target_modules if isinstance(lora_target_modules, list) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 6cfefe97f0..c08504d73c 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -402,9 +402,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): .apply(len) .values ) - LOG.debug( - f"total_num_tokens: {total_num_tokens:_}", - ) + LOG.debug(f"total_num_tokens: {total_num_tokens:_}") if update: cfg.total_num_tokens = total_num_tokens @@ -422,9 +420,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): .apply(lambda x: np.sum(np.array(x) != -100)) .sum() ) - LOG.debug( - f"`total_supervised_tokens: {total_supervised_tokens:_}`", - ) + LOG.debug(f"`total_supervised_tokens: {total_supervised_tokens:_}`") if update: cfg.total_supervised_tokens = total_supervised_tokens @@ -449,7 +445,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): * cfg.sequence_parallel_degree ) LOG.debug( - f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}", + f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}" ) else: if cfg.flash_attention and not cfg.multipack_real_batches: @@ -478,9 +474,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): batch_sampler=sampler, ) data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size - LOG.debug( - f"data_loader_len: {data_loader_len}", - ) + LOG.debug(f"data_loader_len: {data_loader_len}") # FIXME: is there a bug here somewhere? the total num steps depends # on the agreed on value for sample_packing_eff_est total_num_steps = int( @@ -502,9 +496,7 @@ def calc_sample_packing_eff_est(estimates: List[float]): ) if update: cfg.sample_packing_eff_est = sample_packing_eff_est - LOG.debug( - f"sample_packing_eff_est: {cfg.sample_packing_eff_est}", - ) + LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}") else: total_num_steps = int( math.ceil( @@ -514,9 +506,7 @@ def calc_sample_packing_eff_est(estimates: List[float]): / cfg.batch_size ) ) - LOG.debug( - f"total_num_steps: {total_num_steps}", - ) + LOG.debug(f"total_num_steps: {total_num_steps}") return total_num_steps From 1df853f8cd88ff6dfc4d44a2a93201b447a79e06 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 28 May 2025 11:04:10 +0100 Subject: [PATCH 27/29] CI --- tests/e2e/multigpu/solo/test_grpo.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py index 575b7a620a..efea52867f 100644 --- a/tests/e2e/multigpu/solo/test_grpo.py +++ b/tests/e2e/multigpu/solo/test_grpo.py @@ -4,7 +4,6 @@ import os import random -import shutil import subprocess # nosec B404 import sys import tempfile @@ -118,7 +117,11 @@ def start_vllm( recursive_kill(process) with open("/tmp/vllm.log", "r", encoding="utf-8") as log_file: print(log_file.read()) - shutil.rmtree("/tmp/vllm.log") + + try: + os.remove("/tmp/vllm.log") + except FileNotFoundError: + pass raise RuntimeError(f"VLLM server process did not start within {wait} seconds.") # return the process From fb5e1d1d48d62140118601fd05272df58150bbc6 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 28 May 2025 13:19:40 +0100 Subject: [PATCH 28/29] fixing trailing commas --- src/axolotl/cli/config.py | 2 +- src/axolotl/prompt_strategies/bradley_terry/chat_template.py | 4 ++-- src/axolotl/utils/data/utils.py | 2 +- src/axolotl/utils/schemas/config.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 58e7f06aa2..d55448da4d 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -83,7 +83,7 @@ def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: with open(output_path, "wb") as file: file.write(content) LOG.info( - f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n", + f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n" ) return output_path diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py index 1f693f1b4c..e655f85a1f 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py +++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py @@ -44,7 +44,7 @@ def _tokenize_single_prompt(self, prompt): if len(chosen_tokenized["input_ids"]) > max_length: LOG.warning( - f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}", + f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}" ) chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length] @@ -62,7 +62,7 @@ def _tokenize_single_prompt(self, prompt): if len(rejected_tokenized["input_ids"]) > max_length: LOG.warning( - f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}", + f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}" ) rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][ diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index b22f3bcbba..5f3b8d3cc6 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -161,7 +161,7 @@ def deduplicate_and_log_datasets( def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault): if "input_ids" not in dataset.column_names: LOG.warning( - "Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling.", + "Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling." ) return dataset diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 568f7cbc83..aafb433cb1 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1308,7 +1308,7 @@ def check_sample_packing_w_sdpa_bf16(cls, data): # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450 LOG.warning( "sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. " - "This may work on H100s.", + "This may work on H100s." ) return data From cef31bc0f14fa4a8878dcec242e6333ad3ca39f5 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 28 May 2025 13:22:05 +0100 Subject: [PATCH 29/29] fixing trailing commas --- src/axolotl/utils/schemas/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index aafb433cb1..75551085b6 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1233,7 +1233,7 @@ def check_sequence_parallel_degree(self): "Please note that logged losses may differ slightly to the non-SP " "losses due to transformers Trainer implementation details. " "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " - "for more details.", + "for more details." ) return self