Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
2d68a1f
quick start
salmanmohammadi May 1, 2025
730fe0d
quick start
salmanmohammadi May 1, 2025
66162cb
refactor log rank zero funcs
winglian May 2, 2025
89d44dd
use multi process logging adapter similar to accelerate
winglian May 2, 2025
823338a
Merge branch 'dist_logging' of github.com:axolotl-ai-cloud/axolotl in…
salmanmohammadi May 7, 2025
39c2690
quick start
salmanmohammadi May 1, 2025
adb78c7
refactor log rank zero funcs
winglian May 2, 2025
aa97c92
use multi process logging adapter similar to accelerate
winglian May 2, 2025
cf7ed4e
wip replacing calls
salmanmohammadi May 15, 2025
014f499
replacing more calls - testing on dist setup
salmanmohammadi May 16, 2025
d0a30f1
updating
salmanmohammadi May 19, 2025
07ae995
seems to be working
salmanmohammadi May 19, 2025
e930ba7
Merge branch 'main' into dist_logging
salmanmohammadi May 19, 2025
589367d
comments
salmanmohammadi May 22, 2025
f734d51
linting
salmanmohammadi May 22, 2025
4199f7c
Merge branch 'main' into dist_logging
salmanmohammadi May 22, 2025
ff0857c
linting
salmanmohammadi May 22, 2025
f011742
linting
salmanmohammadi May 22, 2025
3dbe0f4
linting
salmanmohammadi May 22, 2025
faa1e03
linting
salmanmohammadi May 22, 2025
598fcf6
fixing logging
salmanmohammadi May 22, 2025
9dcb21d
fixing logging
salmanmohammadi May 22, 2025
a4f00f2
debugging
salmanmohammadi May 22, 2025
e37ff32
debugging
salmanmohammadi May 22, 2025
7ce9baa
debugging
salmanmohammadi May 22, 2025
7f12bde
debugging
salmanmohammadi May 22, 2025
42921f6
fixed logging
salmanmohammadi May 22, 2025
9c8403d
configuring use_environ with get_logger
salmanmohammadi May 23, 2025
9dfd461
comments
salmanmohammadi May 23, 2025
dd76a13
comments-fixing test
salmanmohammadi May 27, 2025
c198524
merge conflicts
salmanmohammadi May 27, 2025
1df853f
CI
salmanmohammadi May 28, 2025
20a4621
merging
salmanmohammadi May 28, 2025
ba46957
merging
salmanmohammadi May 28, 2025
fb5e1d1
fixing trailing commas
salmanmohammadi May 28, 2025
cef31bc
fixing trailing commas
salmanmohammadi May 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/llama-3/lora-1b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ base_model: NousResearch/Llama-3.2-1B
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared

val_set_size: 0.1
output_dir: ./outputs/lora-out

Expand Down Expand Up @@ -38,6 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1

optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
Expand Down
5 changes: 3 additions & 2 deletions src/axolotl/cli/checks.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Various checks for Axolotl CLI."""

import logging
import os
from pathlib import Path

from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError

LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def check_accelerate_default_config() -> None:
Expand Down
14 changes: 7 additions & 7 deletions src/axolotl/cli/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Configuration loading and processing."""

import json
import logging
import os
import tempfile
from pathlib import Path
Expand All @@ -22,11 +21,12 @@
validate_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__, use_environ=True)


def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
Expand Down Expand Up @@ -119,12 +119,12 @@ def choose_config(path: Path) -> str:
)

if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
LOG.info(f"Using default YAML file '{yaml_files[0]}'")
return str(yaml_files[0])

print("Choose a YAML file:")
LOG.info("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
LOG.info(f"{idx + 1}. {file}")

chosen_file = None
while chosen_file is None:
Expand All @@ -133,9 +133,9 @@ def choose_config(path: Path) -> str:
if 1 <= choice <= len(yaml_files):
chosen_file = str(yaml_files[choice - 1])
else:
print("Invalid choice. Please choose a number from the list.")
LOG.info("Invalid choice. Please choose a number from the list.")
except ValueError:
print("Invalid input. Please enter a number.")
LOG.info("Invalid input. Please enter a number.")

return chosen_file

Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/cli/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""CLI to run evaluation on a model."""

import logging
import os
from pathlib import Path
from typing import Union
Expand All @@ -17,8 +16,9 @@
from axolotl.evaluate import evaluate
from axolotl.utils import patch_optimized_env
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/cli/inference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""CLI to run inference on a trained model."""

import importlib
import logging
import sys
from pathlib import Path
from threading import Thread
Expand All @@ -22,8 +21,9 @@
get_chat_template_from_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def get_multi_line_input() -> str:
Expand Down
6 changes: 4 additions & 2 deletions src/axolotl/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

# pylint: disable=redefined-outer-name

import logging
import os
import subprocess # nosec B404
import tempfile
Expand Down Expand Up @@ -31,8 +30,11 @@
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig

LOG = get_logger(__name__)


@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
Expand Down Expand Up @@ -177,7 +179,7 @@ def iter_configs():

do_cli(config=cfg_file, **kwargs)
except subprocess.CalledProcessError as exc:
logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:
raise exc

Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/cli/merge_lora.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""CLI to merge a trained LoRA into a base model."""

import logging
from pathlib import Path
from typing import Union

Expand All @@ -13,8 +12,9 @@
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def do_merge_lora(*, cfg: DictDefault) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/cli/merge_sharded_fsdp_weights.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint."""

import json
import logging
import os
import shutil
from pathlib import Path
Expand All @@ -27,8 +26,9 @@
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/cli/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""CLI to run preprocessing of a dataset."""

import logging
import warnings
from pathlib import Path
from typing import Union
Expand All @@ -20,9 +19,10 @@
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import disable_datasets_caching

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/cli/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
CLI to post-training quantize a model using torchao
"""

import logging
from pathlib import Path
from typing import Union

Expand All @@ -11,9 +10,10 @@
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def do_quantize(
Expand Down
3 changes: 0 additions & 3 deletions src/axolotl/cli/train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""CLI to run training on a model."""

import gc
import logging
import os
from pathlib import Path
from typing import Union
Expand All @@ -22,8 +21,6 @@
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault

LOG = logging.getLogger(__name__)


def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
"""
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import dataclasses
import hashlib
import json
import logging
from functools import wraps
from pathlib import Path
from types import NoneType
Expand All @@ -23,8 +22,9 @@
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def strip_optional_type(field_type: type | str | None):
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/common/datasets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Dataset loading utilities."""

import logging
import math
import random
from dataclasses import dataclass
Expand All @@ -14,10 +13,11 @@
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.tokenization import check_dataset_labels

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


@dataclass
Expand Down
1 change: 0 additions & 1 deletion src/axolotl/core/chat/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ def tokenized(
len(input_ids) : len(input_ids) + len(pending_input_ids)
]
if new_pending_inputs != pending_input_ids:
# logging.warning("tokenization mismatch from concatenation.")
pending_input_ids = new_pending_inputs
input_ids.extend(pending_input_ids)
if pending_weight:
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import importlib
import importlib.util
import inspect
import logging
import math
import os
import sys
Expand Down Expand Up @@ -88,14 +87,15 @@
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType

try:
import torch._dynamo # pylint: disable=ungrouped-imports
except ImportError:
pass

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


class TrainerBuilderBase(abc.ABC):
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/core/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from __future__ import annotations

import logging
import os
from collections import defaultdict
from functools import wraps
Expand Down Expand Up @@ -34,9 +33,10 @@
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/core/trainers/grpo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import importlib
import inspect
import logging
from typing import Any

from trl.trainer.grpo_trainer import RewardFunc
Expand All @@ -13,9 +12,10 @@
AxolotlGRPOTrainer,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.trl import TRLConfig

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


class GRPOStrategy:
Expand Down
5 changes: 2 additions & 3 deletions src/axolotl/core/trainers/mixins/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
"""Module for Axolotl trainer optimizer mixin"""

import logging

from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers.trainer import Trainer
from transformers.utils import is_sagemaker_mp_enabled

from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.utils.logging import get_logger

if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


class OptimizerMixin(Trainer):
Expand Down
5 changes: 3 additions & 2 deletions src/axolotl/core/trainers/mixins/rng_state_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
TODO: Remove when upstream added PR to release
"""

import logging
import os
import random

Expand All @@ -17,7 +16,9 @@
from transformers.trainer_pt_utils import set_rng_state_for_device
from transformers.training_args import ParallelMode

LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


class RngLoaderMixin(Trainer):
Expand Down
Loading
Loading