Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.2
transformers==4.46.1
transformers==4.46.2
tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.1.0
Expand Down
83 changes: 83 additions & 0 deletions src/axolotl/monkeypatch/trainer_fsdp_grad_accum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/34645
"""
import inspect

from accelerate.logging import get_logger
from transformers.trainer import Trainer

from axolotl.monkeypatch.unsloth_ import detab_code

LOG = get_logger("axolotl.monkeypatch.trainer_fsdp_grad_accumulation")

ORIGINAL_CONTEXT_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i == len(batch_samples) - 1
else contextlib.nullcontext
)
"""

PATCHED_CONTEXT_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
"""


def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop


def check_training_loop_is_patchable() -> bool:
train_loop = get_training_loop_code()
train_loop, _ = detab_code(train_loop)
return ORIGINAL_CONTEXT_CODE in train_loop


def patch_training_loop_for_fsdp_grad_accum():
"""
monkeypatch for fixing the training loop for FSDP gradient accumulation
"""

train_loop = get_training_loop_code()
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
train_loop
)
train_loop, _ = detab_code(train_loop)
assert (
ORIGINAL_CONTEXT_CODE in train_loop
), "Original _inner_training_loop code not found"

train_loop = train_loop.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
train_loop = train_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)

# load imports necessary
import transformers.trainer

items_to_import = []
for item in dir(transformers.trainer):
if item in train_loop:
items_to_import.append(item)

exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(train_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop", main_process_only=True)
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)
8 changes: 8 additions & 0 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from transformers.utils import is_torch_bf16_gpu_available

from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.monkeypatch.trainer_fsdp_grad_accum import (
patch_training_loop_for_fsdp_grad_accum,
)
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
Expand Down Expand Up @@ -493,6 +496,11 @@ def prepare_opinionated_env(cfg):
def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
if cfg.fsdp:
try:
patch_training_loop_for_fsdp_grad_accum()
except AssertionError:
pass
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder.model_ref = model[1]
Expand Down
15 changes: 15 additions & 0 deletions tests/e2e/patched/test_trainer_fsdp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest

from axolotl.monkeypatch.trainer_fsdp_grad_accum import check_training_loop_is_patchable


class TestTrainerFSDPIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests."""

def test_train_loop_patchable(self):
# ensures the current version of transformers has loss code that matches our patching code
self.assertTrue(
check_training_loop_is_patchable(),
"HF transformers _inner_training_loop has changed and isn't patchable",
)