Skip to content

use nanmean for loss aggregation (CP fix)#3033

Merged
winglian merged 7 commits into
mainfrom
cp-debug
Aug 8, 2025
Merged

use nanmean for loss aggregation (CP fix)#3033
winglian merged 7 commits into
mainfrom
cp-debug

Conversation

@djsaunde
Copy link
Copy Markdown
Collaborator

@djsaunde djsaunde commented Aug 7, 2025

Description

Title. If a sub-sequence (via context parallel) is fully masked, we get NaN loss on it.

Motivation and Context

Closes #3026.

Ideally this would just go upstream, but it's not needed there since HF trainer doesn't support context parallelism.

How has this been tested?

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

Summary by CodeRabbit

  • Bug Fixes
    • Improved loss calculation in model evaluation to handle fully masked input sequences without producing errors.
  • Tests
    • Added tests to verify correct application of loss calculation patches during model training and evaluation.
  • Style
    • Updated documentation formatting for improved clarity.

@djsaunde djsaunde self-assigned this Aug 7, 2025
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Aug 7, 2025

Warning

Rate limit exceeded

@djsaunde has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 4 minutes and 47 seconds before requesting another review.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

📥 Commits

Reviewing files that changed from the base of the PR and between 2fac201 and 4bef02d.

📒 Files selected for processing (2)
  • src/axolotl/monkeypatch/trainer_eval_guard.py (0 hunks)
  • src/axolotl/monkeypatch/transformers/trainer_loss_calc.py (1 hunks)
📝 Walkthrough

Walkthrough

This change introduces two new monkeypatches for the transformers.Trainer class to use nanmean instead of mean in loss calculations, addressing NaN issues with fully masked sequence chunks. Supporting unit tests are added, and minor docstring formatting is corrected elsewhere. No public API or exported entity signatures are altered.

Changes

Cohort / File(s) Change Summary
Docstring Formatting
src/axolotl/loaders/model.py
Reformatted the module-level docstring for clarity, changing it from a split, hyphenated multiline string to a single, unhyphenated line. No functional or logic changes.
PatchManager Updates
src/axolotl/loaders/patch_manager.py
Modified the _apply_transformers_patches method to import and invoke two new patch functions (patch_evaluation_loop, patch_maybe_log_save_evaluate) from the new patch module, applying them after the existing patch. Introduced a new boolean patch_fsdp2 flag based on config to conditionally pass to patch_evaluation_loop. No changes to method signatures or error handling.
Trainer Loss Calculation Patches
src/axolotl/monkeypatch/transformers/trainer_loss_calc.py
Added a new module that monkeypatches transformers.Trainer's evaluation_loop and _maybe_log_save_evaluate methods to replace mean() with np.nanmean() in loss aggregation, preventing NaN propagation from fully masked input chunks. Uses dynamic source inspection and string replacement, includes an FSDP2 + torch.compile compatibility patch guard, and logs patch application status. No new exportable entities introduced.
Unit Tests for Trainer Loss Patches
tests/monkeypatch/test_trainer_loss_calc.py
Added a new test module with a class and test method that verify patchability of the patched methods on Trainer. The test asserts that patch checks pass, ensuring compatibility with upstream code and preventing silent breakage of the monkeypatches.
Trainer Utility Cleanup
src/axolotl/utils/trainer.py
Removed import and conditional invocation of patch_evaluation_loop_for_fsdp2 in setup_trainer function, cleaning up unused patch-related code. No other logic or control flow changes.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~15–20 minutes

Assessment against linked issues

Objective Addressed Explanation
Fix evaluation returning eval_loss = nan when using context parallel (Issue #3026) The patch replaces mean with nanmean in loss calculations in evaluation methods to prevent NaNs.

Assessment against linked issues: Out-of-scope changes

No out-of-scope changes detected.

Possibly related PRs

  • models.py -> loaders/ module refactor #2680: Refactors the monolithic axolotl.utils.models module into smaller loader submodules including model.py. Related because this PR also modifies src/axolotl/loaders/model.py docstring formatting within the new loaders structure.

Suggested reviewers

  • winglian
✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch cp-debug

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🔭 Outside diff range comments (1)
src/axolotl/monkeypatch/transformers/trainer_loss_calc.py (1)

1-144: Add validation tests for the actual behavior change.

While the unit tests verify that patches are applied, they don't test the actual behavior change (handling NaN values correctly).

Consider adding integration tests that verify NaN handling:

def test_evaluation_handles_nan():
    """Test that patched evaluation_loop handles NaN values correctly."""
    import numpy as np
    import torch
    from unittest.mock import MagicMock, patch
    
    # Apply the patch
    patch_evaluation_loop()
    
    # Create a mock trainer with NaN losses
    trainer = Trainer(model=MagicMock(), args=MagicMock())
    
    # Simulate losses with NaN values
    with patch.object(trainer, 'prediction_loop') as mock_pred:
        mock_pred.return_value = (
            None,  # predictions
            None,  # label_ids  
            {'loss': torch.tensor([1.0, float('nan'), 2.0])}  # metrics with NaN
        )
        
        metrics = trainer.evaluation_loop(
            dataloader=MagicMock(),
            description="test",
            metric_key_prefix="eval"
        )
        
        # Should handle NaN and return 1.5 (mean of 1.0 and 2.0)
        assert abs(metrics['eval_loss'] - 1.5) < 1e-6
🧹 Nitpick comments (1)
tests/monkeypatch/test_trainer_loss_calc.py (1)

13-39: Consider adding cleanup in tearDown to restore original state.

While the tests clean up before running, they don't restore the original state after completion. This could affect other tests if they depend on an unpatched Trainer class.

Consider adding a tearDown method to restore the original state:

 class TestTrainerLossCalc(unittest.TestCase):
     """
     Unit test class for trainer loss calc monkeypatch
     """
+
+    def tearDown(self):
+        """Clean up patches after each test."""
+        if hasattr(Trainer, "_original_evaluation_loop"):
+            delattr(Trainer, "_original_evaluation_loop")
+        if hasattr(Trainer, "_original_maybe_log_save_evaluate"):
+            delattr(Trainer, "_original_maybe_log_save_evaluate")
 
     def test_patch_evaluation_loop_applies(self):
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ca796fb and 8cb87ea.

📒 Files selected for processing (4)
  • src/axolotl/loaders/model.py (1 hunks)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
  • src/axolotl/monkeypatch/transformers/trainer_loss_calc.py (1 hunks)
  • tests/monkeypatch/test_trainer_loss_calc.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
tests/monkeypatch/test_trainer_loss_calc.py (1)
src/axolotl/monkeypatch/transformers/trainer_loss_calc.py (2)
  • patch_evaluation_loop (20-88)
  • patch_maybe_log_save_evaluate (92-143)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: pre-commit
🔇 Additional comments (2)
src/axolotl/loaders/model.py (1)

1-2: LGTM! Docstring formatting improvement.

Good fix to consolidate the split text into a single line for better readability.

src/axolotl/loaders/patch_manager.py (1)

79-86: LGTM! Proper integration of trainer loss calculation patches.

The new patches are correctly imported and invoked in the appropriate sequence within the transformers patches section.

Comment thread src/axolotl/monkeypatch/transformers/trainer_loss_calc.py Outdated
Comment thread src/axolotl/monkeypatch/transformers/trainer_loss_calc.py
Comment thread src/axolotl/monkeypatch/transformers/trainer_loss_calc.py
Comment thread src/axolotl/monkeypatch/transformers/trainer_loss_calc.py Outdated
@codecov
Copy link
Copy Markdown

codecov Bot commented Aug 7, 2025

Codecov Report

❌ Patch coverage is 91.89189% with 6 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
...lotl/monkeypatch/transformers/trainer_loss_calc.py 91.42% 6 Missing ⚠️

📢 Thoughts on this report? Let us know!


PATCHED_FSDP2_CODE = """
if hasattr(model, "eval") and callable(model.eval):
self.model.eval()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be model.eval() since that's what you checked?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just copying @NanoCode012's patch here... but it looks like you're right. I'll test

Copy link
Copy Markdown
Collaborator

@winglian winglian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm! Let's open some upstream PRs for these changes too.

@winglian winglian merged commit 0ae06d7 into main Aug 8, 2025
15 checks passed
@winglian winglian deleted the cp-debug branch August 8, 2025 12:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Evaluation returns eval_loss = nan when using context parallel

2 participants