Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/axolotl/loaders/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Model loader class implementation for loading, configuring, and patching various
models.
"""
Model loader class implementation for loading, configuring, and patching various models.
"""

import gc
Expand Down
12 changes: 12 additions & 0 deletions src/axolotl/loaders/patch_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,20 @@ def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
patch_prepare_from_posids,
)
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
patch_evaluation_loop,
patch_maybe_log_save_evaluate,
)

patch_fsdp2 = (
self.cfg.torch_compile
and self.cfg.fsdp_config
and self.cfg.fsdp_version == 2
)

patch_prepare_from_posids()
patch_evaluation_loop(patch_fsdp2)
patch_maybe_log_save_evaluate()

def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
Expand Down
78 changes: 0 additions & 78 deletions src/axolotl/monkeypatch/trainer_eval_guard.py

This file was deleted.

165 changes: 165 additions & 0 deletions src/axolotl/monkeypatch/transformers/trainer_loss_calc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
"""
Module for patching transformers Trainer loss calculation to use nanmean.

This is needed for context parallelism since chunks of the input sequences may be fully
masked and return NaNs in the loss calculation.

Also includes a patch for FSDP2 + torch.compile. We need to bundle this together with
the other evaluation_loop patch because we can't patch the same code twice without
raising an OSError.
"""

import importlib
import inspect

from transformers import Trainer

from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

ORIGINAL_EVAL_CODE = {
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()',
"array": 'metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()',
}
PATCHED_EVAL_CODE = {
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(np.concatenate(all_losses)).item()',
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
}

ORIGINAL_FSDP2_CODE = """
model.eval()
"""

PATCHED_FSDP2_CODE = """
if hasattr(model, "eval") and callable(model.eval):
self.model.eval()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be model.eval() since that's what you checked?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just copying @NanoCode012's patch here... but it looks like you're right. I'll test

"""

ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"


def check_evaluation_loop_is_patchable() -> bool:
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values())


def check_evaluation_loop_is_fsdp2_patchable() -> bool:
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
return ORIGINAL_FSDP2_CODE in evaluation_loop_source


# pylint: disable=protected-access
def patch_evaluation_loop(patch_fsdp2: bool):
"""Patch the evaluation_loop method."""
# Check if already patched
if hasattr(Trainer, "_original_evaluation_loop"):
LOG.info("Trainer.evaluation_loop already patched")
return

# Check if the patterns exist
try:
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
except OSError:
return
Trainer.evaluation = evaluation_loop_source
evaluation_loop_source, _ = detab_code(evaluation_loop_source)

# Apply the nanmean patches
evaluation_loop_source = evaluation_loop_source.replace(
ORIGINAL_EVAL_CODE["list"], PATCHED_EVAL_CODE["list"]
)
evaluation_loop_source = evaluation_loop_source.replace(
ORIGINAL_EVAL_CODE["array"], PATCHED_EVAL_CODE["array"]
)

# Apply FSDP2 eval guard patch if needed
if patch_fsdp2 and ORIGINAL_FSDP2_CODE in evaluation_loop_source:
evaluation_loop_source = evaluation_loop_source.replace(
ORIGINAL_FSDP2_CODE, PATCHED_FSDP2_CODE
)
LOG.info("Applied FSDP2 eval guard patch to evaluation_loop")

# Rename the function to avoid conflicts
evaluation_loop_source = evaluation_loop_source.replace(
"def evaluation_loop(",
"def axolotl_evaluation_loop(",
1,
)

# Get the module for necessary imports
module_name = Trainer.__module__
module = importlib.import_module(module_name)

# Import necessary items from the module
items_to_import = []
for item in dir(module):
if item in evaluation_loop_source:
items_to_import.append(item)

Comment thread
djsaunde marked this conversation as resolved.
# Execute the imports and patched method
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(evaluation_loop_source, globals()) # pylint: disable=exec-used # nosec B102

LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation")
Trainer.evaluation_loop = (
axolotl_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
)
Comment thread
djsaunde marked this conversation as resolved.


def check_maybe_log_save_evaluate_is_patchable() -> bool:
maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
return ORIGINAL_MAYBE_CODE in maybe_log_source


# pylint: disable=protected-access
def patch_maybe_log_save_evaluate():
"""Patch the _maybe_log_save_evaluate method."""
# Check if already patched
if hasattr(Trainer, "_original_maybe_log_save_evaluate"):
LOG.info("Trainer._maybe_log_save_evaluate already patched")
return

# Check if the patterns exist
try:
maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
except OSError:
return
Trainer._original_maybe_log_save_evaluate = maybe_log_source
maybe_log_source, _ = detab_code(maybe_log_source)

# Apply the patch
maybe_log_source = maybe_log_source.replace(ORIGINAL_MAYBE_CODE, PATCHED_MAYBE_CODE)

# Rename the function to avoid conflicts
maybe_log_source = maybe_log_source.replace(
"def _maybe_log_save_evaluate(",
"def axolotl_maybe_log_save_evaluate(",
1,
)

# Get the module for necessary imports
module_name = Trainer.__module__
module = importlib.import_module(module_name)

# Import necessary items from the module
items_to_import = []
for item in dir(module):
if item in maybe_log_source:
items_to_import.append(item)

# Execute the imports and patched method
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(maybe_log_source, globals()) # pylint: disable=exec-used # nosec B102

LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate # pylint: disable=undefined-variable # noqa: F821
3 changes: 0 additions & 3 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available

from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.logging import get_logger
Expand Down Expand Up @@ -667,8 +666,6 @@ def setup_trainer(
"""
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder

if cfg.torch_compile and cfg.fsdp_config and cfg.fsdp_version == 2:
patch_evaluation_loop_for_fsdp2()
if cfg.rl:
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
trainer_builder.model_ref = model_ref
Expand Down
28 changes: 28 additions & 0 deletions tests/monkeypatch/test_trainer_loss_calc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Unit tests for trainer loss calc monkeypatch."""

import unittest

from axolotl.monkeypatch.transformers.trainer_loss_calc import (
check_evaluation_loop_is_fsdp2_patchable,
check_evaluation_loop_is_patchable,
check_maybe_log_save_evaluate_is_patchable,
)


class TestTrainerLossCalc(unittest.TestCase):
"""
Unit test class for trainer loss calc monkeypatch
"""

def test_trainer_loss_calc_is_patchable(self):
"""
Test that the upstream transformers code is still patchable. This will fail if
the patched code changes upstream.
"""
assert check_evaluation_loop_is_patchable()
assert check_evaluation_loop_is_fsdp2_patchable()
assert check_maybe_log_save_evaluate_is_patchable()


if __name__ == "__main__":
unittest.main()