Skip to content

models.py -> loaders/ module refactor#2680

Merged
djsaunde merged 21 commits into
mainfrom
model-load-refactor
May 23, 2025
Merged

models.py -> loaders/ module refactor#2680
djsaunde merged 21 commits into
mainfrom
model-load-refactor

Conversation

@djsaunde
Copy link
Copy Markdown
Collaborator

@djsaunde djsaunde commented May 15, 2025

Description

Title. Breaking out monster axolotl.utils.models file into various submodules (also, moved up a level in the module hierarchy).

Motivation and Context

File / ModelLoader class at present is super large and difficult to read / maintain.

How has this been tested?

TODO

Plan:

  • Get tests passing
  • Additional manual testing

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • New Features

    • Introduced a modular loaders package for model, tokenizer, processor, and adapter loading with enhanced flexibility.
    • Added support for advanced adapters including LoRA, QLoRA, and llama-adapter with quantization-aware training.
    • Implemented improved validation to prevent incompatible configurations like sample packing with shifted-sparse attention.
  • Refactor

    • Replaced legacy monolithic model utilities with the new ModelLoader class and related loader modules.
    • Updated imports and test code to use the new loaders package for model and tokenizer handling.
  • Documentation

    • Enhanced type annotations and docstrings in plugin and loader classes for better clarity and maintainability.
  • Bug Fixes

    • Enforced configuration validation to disallow simultaneous use of sample packing and shifted-sparse attention.
  • Chores

    • Removed deprecated utility modules and consolidated model loading logic into the loaders package.

@djsaunde djsaunde self-assigned this May 15, 2025
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 15, 2025

Walkthrough

This change refactors the model and tokenizer loading system by moving related logic from axolotl.utils.models into a new axolotl.loaders package. The new package modularizes loader utilities, model configuration, patch management, adapters, and tokenizers. It introduces a ModelLoader class, splits out adapter and processor logic, and adds extensive type annotations and validation. Test code and imports are updated to use the new structure. The original monolithic loader file is deleted.

Changes

File(s) Change Summary
src/axolotl/loaders/adapter.py, src/axolotl/loaders/constants.py, src/axolotl/loaders/model.py, src/axolotl/loaders/patch_manager.py, src/axolotl/loaders/processor.py, src/axolotl/loaders/tokenizer.py, src/axolotl/loaders/utils.py, src/axolotl/loaders/__init__.py New loader package: Adds modules for adapters, constants, model loading, patch management, processor loading, tokenizer handling, and utility functions. Exposes key loader functions and classes at the package level.
src/axolotl/utils/models.py, src/axolotl/utils/lora_embeddings.py Deleted: Removes the previous monolithic model/tokenizer loader and LoRA embedding utility. All relevant logic is now in the axolotl.loaders package.
src/axolotl/cli/utils.py, src/axolotl/common/datasets.py, src/axolotl/core/trainer_builder.py, src/axolotl/train.py, src/axolotl/utils/config/__init__.py, src/axolotl/utils/data/rl.py, src/axolotl/monkeypatch/peft/utils.py Refactored imports: Update imports to use the new loader package and its modules. Replace direct function calls with the new ModelLoader class where appropriate.
src/axolotl/integrations/base.py Adds type annotations and improves docstrings for plugin classes and methods. Updates method signatures for clarity and type safety.
src/axolotl/utils/schemas/config.py Adds a validator to prevent enabling both sample packing and shifted-sparse attention simultaneously.
tests/core/test_trainer_builder.py, tests/e2e/patched/test_model_patches.py, tests/e2e/test_load_model.py, tests/test_lora.py, tests/test_tokenizers.py, tests/test_exact_deduplication.py, tests/patched/test_validation.py, tests/test_loaders.py Updates test imports and usages to reference new loader modules and classes. Removes or adds tests for configuration validation as needed. Adjusts test method calls to match refactored loader APIs.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant ModelLoader
    participant Tokenizer
    participant Processor
    participant Adapter

    User->>ModelLoader: instantiate(cfg, tokenizer, ...)
    User->>ModelLoader: load()
    ModelLoader->>Tokenizer: load_tokenizer(cfg)
    ModelLoader->>Processor: load_processor(cfg, tokenizer)
    ModelLoader->>ModelLoader: _build_model()
    ModelLoader->>Adapter: load_adapter(model, cfg, ...)
    ModelLoader-->>User: returns (model, peft_config)
Loading

Poem

In the warren where loaders hop anew,
Models, tokenizers, and patches too,
Split from the old, now modular delight,
Each bunny module fits just right.
With type hints and checks, we leap ahead,
The loader’s garden, freshly spread!
🐇✨
✨ Finishing Touches
  • 📝 Generate Docstrings

🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@codecov
Copy link
Copy Markdown

codecov Bot commented May 16, 2025

Comment thread src/axolotl/cli/utils.py
Comment thread src/axolotl/integrations/base.py
Comment thread src/axolotl/loaders/model.py
Comment thread src/axolotl/utils/lora_embeddings.py
@djsaunde djsaunde marked this pull request as ready for review May 16, 2025 16:44
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 8

🧹 Nitpick comments (9)
src/axolotl/loaders/tokenizer.py (1)

168-170: Environment variable toggling happens on every call

os.environ["TOKENIZERS_PARALLELISM"] = "false" is executed whenever a
GPTNeoXTokenizerFast is encountered—even inside worker processes. Move the
side-effect into the branch guarded by is_main_process() or guard with
os.environ.get to avoid redundant writes.

src/axolotl/loaders/utils.py (1)

190-208: Prefer contextlib.suppress + logging over bare try/except + print

ensure_dtype currently:

  1. Silences all AttributeErrors, hampering debuggability.
  2. Uses print, mixing stdout into logging streams.
-import torch
+import torch
+import contextlib
...
-        try:
-            weight_mismatch = module.weight.dtype != dtype
-        except AttributeError:
-            pass
+        with contextlib.suppress(AttributeError):
+            weight_mismatch = module.weight.dtype != dtype
...
-        if weight_mismatch:
-            print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
+        if weight_mismatch:
+            LOG.debug(
+                "Converting %s.weight from %s to %s", name, module.weight.dtype, dtype
+            )

This keeps the log pipeline consistent and removes repetitive boiler-plate.

🧰 Tools
🪛 Ruff (0.11.9)

193-196: Use contextlib.suppress(AttributeError) instead of try-except-pass

Replace with contextlib.suppress(AttributeError)

(SIM105)


197-200: Use contextlib.suppress(AttributeError) instead of try-except-pass

Replace with contextlib.suppress(AttributeError)

(SIM105)

src/axolotl/loaders/patch_manager.py (1)

128-136: Minor: collapse nested if for readability

Ruff’s SIM102 – can be simplified:

-        if self.cfg.model_config_type == "llama4":
-            if self.cfg.llama4_linearized_experts:
+        if (
+            self.cfg.model_config_type == "llama4"
+            and self.cfg.llama4_linearized_experts
+        ):

No functional change, just reduces indentation.

🧰 Tools
🪛 Ruff (0.11.9)

128-129: Use a single if statement instead of nested if statements

(SIM102)

src/axolotl/loaders/adapter.py (2)

55-60: Operator precedence makes intent opaque

or precedes the long and chain, but reading order is non-obvious.
Add parentheses for clarity; avoids accidental logic regressions if the line is edited later.

-            isinstance(module, cls)
-            or "Linear" in module.__class__.__name__
-            and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
+            isinstance(module, cls)
+            or (
+                "Linear" in module.__class__.__name__
+                and module.__class__.__name__ != "LlamaLinearScalingRotaryEmbedding"
+            )

41-49: Potential memory leak: _orig_to never deleted

After restoring the original .to, you assign None but keep the attribute, preventing GC of the function object. Prefer delattr(param.quant_state, "_orig_to").

src/axolotl/loaders/model.py (2)

206-210: Prefer direct inequality over negated equality

Ruff correctly points out that the expression

not (self.model_config.model_type == "llava")

is harder to read than simply checking inequality:

-            if self.cfg.mean_resizing_embeddings is not None and not (
-                self.model_config.model_type == "llava"
-            ):
+            if self.cfg.mean_resizing_embeddings is not None and (
+                self.model_config.model_type != "llava"
+            ):

Using != conveys intent directly and avoids the double-negative.

🧰 Tools
🪛 Ruff (0.11.9)

206-208: Use self.model_config.model_type != "llava" instead of not self.model_config.model_type == "llava"

Replace with != operator

(SIM201)


339-341: Replace setattr with direct attribute assignment

setattr(self.model, "is_parallelizable", True) and the next line trigger Ruff B010 because the attribute names are hard-coded. A direct assignment is simpler and safer:

-            setattr(self.model, "is_parallelizable", True)
-            setattr(self.model, "model_parallel", True)
+            self.model.is_parallelizable = True
+            self.model.model_parallel = True

No dynamic attribute name is used, so there is no advantage in setattr.

🧰 Tools
🪛 Ruff (0.11.9)

339-339: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)


340-340: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)

src/axolotl/integrations/base.py (2)

356-362: Return-type mismatch between annotation and docstring

get_input_args is annotated to return list[str], yet the docstring says it returns “a list of Pydantic classes”. If Pydantic models are intended, update the type hint:

-    def get_input_args(self) -> list[str]:
+    from pydantic import BaseModel
+    def get_input_args(self) -> list[type[BaseModel]]:

Keeping annotations and documentation in sync improves IDE support and static analysis accuracy.


348-355: Swallowing ImportError hides critical setup failures

PluginManager.register() logs the failure but does not propagate it, so the application may proceed without an expected plugin and fail later in obscure ways.

Consider re-raising the exception (or returning a boolean) so callers can decide whether a missing plugin is fatal:

        except ImportError as exc:
            logging.error(f"Failed to load plugin: {plugin_name}")
-            pass
+            raise

Fail-fast behaviour makes configuration errors obvious in CI and production.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between c0a0c75 and 8c3e980.

📒 Files selected for processing (27)
  • src/axolotl/cli/utils.py (2 hunks)
  • src/axolotl/common/datasets.py (1 hunks)
  • src/axolotl/core/trainer_builder.py (1 hunks)
  • src/axolotl/integrations/base.py (14 hunks)
  • src/axolotl/loaders/__init__.py (1 hunks)
  • src/axolotl/loaders/adapter.py (1 hunks)
  • src/axolotl/loaders/constants.py (1 hunks)
  • src/axolotl/loaders/model.py (1 hunks)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
  • src/axolotl/loaders/processor.py (1 hunks)
  • src/axolotl/loaders/tokenizer.py (1 hunks)
  • src/axolotl/loaders/utils.py (1 hunks)
  • src/axolotl/monkeypatch/peft/utils.py (1 hunks)
  • src/axolotl/train.py (3 hunks)
  • src/axolotl/utils/config/__init__.py (1 hunks)
  • src/axolotl/utils/data/rl.py (1 hunks)
  • src/axolotl/utils/lora_embeddings.py (0 hunks)
  • src/axolotl/utils/models.py (0 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • tests/core/test_trainer_builder.py (2 hunks)
  • tests/e2e/patched/test_model_patches.py (3 hunks)
  • tests/e2e/test_load_model.py (3 hunks)
  • tests/patched/test_validation.py (2 hunks)
  • tests/test_exact_deduplication.py (6 hunks)
  • tests/test_loaders.py (3 hunks)
  • tests/test_lora.py (3 hunks)
  • tests/test_tokenizers.py (1 hunks)
💤 Files with no reviewable changes (2)
  • src/axolotl/utils/lora_embeddings.py
  • src/axolotl/utils/models.py
🧰 Additional context used
🧬 Code Graph Analysis (9)
tests/test_tokenizers.py (1)
src/axolotl/loaders/tokenizer.py (1)
  • load_tokenizer (120-281)
src/axolotl/utils/data/rl.py (1)
src/axolotl/loaders/tokenizer.py (1)
  • load_tokenizer (120-281)
src/axolotl/utils/config/__init__.py (1)
src/axolotl/loaders/utils.py (1)
  • load_model_config (135-183)
tests/test_loaders.py (2)
src/axolotl/loaders/model.py (2)
  • _set_device_map_config (370-426)
  • _set_quantization_config (428-503)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/core/trainer_builder.py (1)
src/axolotl/loaders/utils.py (1)
  • ensure_dtype (186-207)
tests/core/test_trainer_builder.py (1)
src/axolotl/loaders/model.py (2)
  • ModelLoader (57-781)
  • load (138-166)
tests/e2e/patched/test_model_patches.py (2)
src/axolotl/loaders/model.py (2)
  • ModelLoader (57-781)
  • load (138-166)
src/axolotl/loaders/tokenizer.py (1)
  • load_tokenizer (120-281)
src/axolotl/loaders/utils.py (4)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
tests/test_exact_deduplication.py (1)
  • cfg (221-235)
src/axolotl/models/mamba/modeling_mamba.py (1)
  • from_pretrained (122-128)
tests/test_perplexity.py (1)
  • model (20-21)
src/axolotl/loaders/adapter.py (3)
src/axolotl/utils/schemas/peft.py (3)
  • LoftQConfig (8-14)
  • LoraConfig (23-122)
  • PeftConfig (17-20)
src/axolotl/loaders/utils.py (1)
  • get_linear_embedding_layers (210-216)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
🪛 Ruff (0.11.9)
src/axolotl/loaders/utils.py

193-196: Use contextlib.suppress(AttributeError) instead of try-except-pass

Replace with contextlib.suppress(AttributeError)

(SIM105)


197-200: Use contextlib.suppress(AttributeError) instead of try-except-pass

Replace with contextlib.suppress(AttributeError)

(SIM105)

src/axolotl/loaders/patch_manager.py

128-129: Use a single if statement instead of nested if statements

(SIM102)

src/axolotl/loaders/model.py

206-208: Use self.model_config.model_type != "llava" instead of not self.model_config.model_type == "llava"

Replace with != operator

(SIM201)


339-339: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)


340-340: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)


570-572: Use False instead of ... and False

Replace with False

(SIM223)


593-593: Do not call getattr with a constant attribute value. It is not any safer than normal property access.

Replace getattr with attribute access

(B009)


779-780: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms (6)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: PyTest (3.11, 2.5.1)
🔇 Additional comments (44)
src/axolotl/core/trainer_builder.py (1)

62-62: Correct import of ensure_dtype from refactored axolotl.loaders.utils
Updating the import location for ensure_dtype aligns with the new modular loader utilities and preserves existing functionality.

src/axolotl/utils/data/rl.py (1)

13-13: Update import for load_tokenizer to new loaders package
This assumes that load_tokenizer is re-exported at the root of axolotl.loaders. Please verify that axolotl/loaders/__init__.py includes load_tokenizer; otherwise, adjust the import to from axolotl.loaders.tokenizer import load_tokenizer.

src/axolotl/common/datasets.py (1)

13-13: Update import for load_processor and load_tokenizer to new loaders package
Ensure that both load_processor and load_tokenizer are exposed in axolotl.loaders/__init__.py; if not, import directly from their modules (e.g., axolotl.loaders.processor, axolotl.loaders.tokenizer).

src/axolotl/utils/config/__init__.py (2)

14-14: Import of MULTIMODAL_AUTO_MODEL_MAPPING: verify correct export location
The constant was relocated to axolotl.loaders.constants. Confirm it’s re-exported in axolotl.loaders/__init__.py; otherwise, import it explicitly from axolotl.loaders.constants.


15-15: Correct import of load_model_config from axolotl.loaders.utils
This aligns with the refactoring of model configuration utilities into the new loaders package.

tests/test_tokenizers.py (1)

9-9: Update test import for load_tokenizer to new loaders package
The test now imports load_tokenizer from axolotl.loaders. Verify that this function is available at the package root, or update the import to from axolotl.loaders.tokenizer import load_tokenizer if needed.

tests/test_lora.py (3)

5-5: Import updated to reflect the new module structure.

The import statement has been updated to use the new axolotl.loaders package instead of axolotl.utils.models, aligning with the refactoring goal of breaking down the large models.py file into smaller submodules.


49-49: Model loading approach updated to use the new class-based interface.

The model loading has been updated from directly calling load_model(cfg, tokenizer) to instantiating the ModelLoader class and calling its load() method. This change is consistent with the refactoring goals and maintains the same functionality.


70-70: Model loading approach consistently updated in both test methods.

The second test method also follows the same pattern of using ModelLoader(cfg, tokenizer).load(), ensuring consistency across the test file.

src/axolotl/monkeypatch/peft/utils.py (1)

78-78: Monkeypatch target path updated to reflect the restructured modules.

The path for the patched function has been updated from axolotl.utils.models.prepare_model_for_kbit_training to axolotl.loaders.model.prepare_model_for_kbit_training, correctly aligning with the new module structure where model loading functionality has been moved to the axolotl.loaders package.

tests/core/test_trainer_builder.py (3)

1-1: Module docstring reformatted to a single line.

The docstring has been simplified to a single-line format, maintaining the same information while being more concise.


6-6: Import updated to reflect the new module structure.

The import statement has been updated to import ModelLoader and load_tokenizer from the new axolotl.loaders package instead of axolotl.utils.models, which aligns with the refactoring goal.


50-50: Model loading approach updated to use the new class-based interface.

The fixture_model function has been updated to use the new ModelLoader class's load() method instead of the previous direct call to load_model. This is consistent with the changes in other test files and maintains the same functionality.

tests/patched/test_validation.py (2)

12-12: Import path updated to reflect the restructured modules.

The import of check_model_config has been updated to use the new axolotl.loaders.utils module instead of the previous axolotl.utils.models, correctly aligning with the refactoring.


1218-1231: New test case added for validating incompatible configuration options.

This new test ensures that the system correctly raises a validation error when both s2_attention and sample_packing are set to True, which represents an incompatible configuration. The test is well-structured and follows the same pattern as other validation tests in this file.

src/axolotl/train.py (3)

31-35: LGTM - Updated imports align with the module refactoring.

The imports have been updated to use the new axolotl.loaders package instead of axolotl.utils.models, which reflects the code organization refactoring described in the PR objectives.


83-84: Good refactoring to use the new ModelLoader class.

The code now uses the ModelLoader class to encapsulate model loading logic instead of directly calling a function. This is a cleaner approach that better follows object-oriented principles.


121-122: LGTM - Consistent usage of ModelLoader for reference model.

The refactoring of reference model loading follows the same pattern as the main model loading, maintaining consistency across the codebase.

src/axolotl/cli/utils.py (2)

23-24: LGTM - Updated imports align with module refactoring.

The imports have been updated to use the new modular structure, with load_processor and load_tokenizer from axolotl.loaders and ModelLoader from axolotl.loaders.model.


322-323: Refactoring simplified the loading pattern.

As noted in a previous review comment, this change removes the redundant load_model function which was essentially a pass-through to ModelLoader(...).load(). The direct instantiation of ModelLoader is cleaner.

src/axolotl/utils/schemas/config.py (1)

473-482: Good addition of validation for incompatible features.

This validator prevents the simultaneous use of sample packing and shifted-sparse attention, which are incompatible features. Early validation of configuration errors provides better user experience by failing fast with a clear error message.

tests/e2e/patched/test_model_patches.py (3)

9-9: LGTM - Updated imports to use the new module structure.

The import statement has been updated to use the refactored module structure, importing ModelLoader and load_tokenizer from axolotl.loaders.


53-53: LGTM - Test updated to use ModelLoader pattern.

The test has been updated to use the new ModelLoader class directly instead of the removed load_model function, consistent with the refactoring throughout the codebase.


86-86: LGTM - Consistent refactoring in tests.

The second test method has also been updated to use the ModelLoader class, maintaining consistency throughout the test suite.

tests/test_exact_deduplication.py (3)

1-5: Docstring formatting has been improved.

The module docstring now uses backticks to properly format module and function names, improving readability and consistency with Python documentation standards.


15-15: Import path updated to reflect the refactored module structure.

The import path for load_processor and load_tokenizer has been correctly updated to use the new axolotl.loaders package instead of the old axolotl.utils.models module, aligning with the refactoring goal.


249-249: Patch decorators updated to use new module path.

The patch paths have been correctly updated to reference axolotl.loaders.load_tokenizer instead of the old path, ensuring that mocking works correctly after the refactoring.

Also applies to: 273-273

src/axolotl/loaders/__init__.py (1)

1-11: Well-structured package initialization.

The __init__.py file is well organized, exposing key components from the submodules at the package level. The use of linter disabling directives is appropriate here to allow for unused imports, which is a common practice for package initializers.

This structure aligns perfectly with the refactoring goal of breaking down the large axolotl.utils.models file into multiple smaller, more manageable submodules.

tests/e2e/test_load_model.py (3)

9-9: Import path updated to use the new module structure.

The import has been correctly updated to use ModelLoader and load_tokenizer from the new axolotl.loaders package.


61-63: Added explicit parameters for ModelLoader instantiation.

The instantiation of ModelLoader now includes explicit parameters inference=False and reference_model=True, which makes the intent clearer and aligns with the new class structure.


76-78: Updated to use the ModelLoader methods.

The test now correctly uses the ModelLoader.load() method and the internal _convert_embedding_modules_dtype method according to the new architecture, replacing the previous standalone function approach.

tests/test_loaders.py (3)

1-1: Updated docstrings to reflect the new module structure.

Both the module and class docstrings have been correctly updated to reference axolotl.loaders as the module under test.

Also applies to: 15-15


10-10: Import path updated to use the new module structure.

The import has been correctly updated to use ModelLoader directly from axolotl.loaders.


53-54: Updated method calls to respect protected method visibility.

The tests now call the protected methods _set_device_map_config() and _set_quantization_config() with appropriate pylint disabling comments, indicating that these methods are now intended for internal use within the ModelLoader class.

This change reflects a more principled approach to method visibility in the refactored code.

Also applies to: 80-81

src/axolotl/loaders/constants.py (1)

13-21: Use typing.Final and lowercase keys for consistency

  1. Annotate the mapping with Final to communicate immutability.
  2. Convert keys to lowercase once (either here or on lookup) so that downstream look-ups do not need to remember the exact case.
-from typing import Final
-
-MULTIMODAL_AUTO_MODEL_MAPPING: Final[dict[str, type]] = {
-    "mllama": MllamaForConditionalGeneration,
-    ...
-}

[ suggest_nitpick ]

src/axolotl/loaders/processor.py (1)

32-43: Image-size extraction can silently mis-set cfg.image_size

When only one dimension is present the function stores an int; when two are
present it stores a tuple[int, int]. Consumers now need to handle a
union. Consider always normalising to a tuple for consistency:

-        if im_width is not None and im_height is not None:
-            cfg.image_size = (im_width, im_height)
-        elif im_width is not None:
-            cfg.image_size = im_width
-        elif im_height is not None:
-            cfg.image_size = im_height
+        if im_width is not None or im_height is not None:
+            cfg.image_size = (im_width or im_height, im_height or im_width)

[ suggest_optional_refactor ]

src/axolotl/loaders/tokenizer.py (3)

87-92: Variable shadowing hides the original tokenizer_path argument

tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json") overwrites the
function parameter of the same name, making subsequent debugging harder.
Rename the local variable to avoid confusion:

-        tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
+        tokenizer_json_path = os.path.join(tokenizer_dir, "tokenizer.json")
 ...
-            with open(tokenizer_path, "r", encoding="utf-8") as f:
+            with open(tokenizer_json_path, "r", encoding="utf-8") as f:

[ suggest_nitpick ]


155-165: Prefer isinstance over string comparisons for tokenizer class checks

Comparing tokenizer.__class__.__name__ is brittle and fails with subclassing.
The transformers library exports the actual classes—use them directly:

-from ... in ["LlamaTokenizer", "LlamaTokenizerFast", ...]
+from transformers import LlamaTokenizer, LlamaTokenizerFast, CodeLlamaTokenizer, CodeLlamaTokenizerFast
+
+if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast, CodeLlamaTokenizer, CodeLlamaTokenizerFast)):
     ...

[ suggest_optional_refactor ]


188-216: ⚠️ Potential issue

LoRA/special-token validation raises on perfectly valid scenarios

The current guard enforces that all linear embedding layers are listed in
lora_modules_to_save just because special tokens change, even when the
user is not doing adapter training (cfg.adapter == False).
Consider restructuring the predicate:

-            if (
-                (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
-                ...
-                and cfg.adapter
+            if (
+                cfg.adapter                                  # only enforce for adapter training
+                and (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
                 ...
             ):

This prevents false-positive ValueErrors when training full-fine-tune models.

Likely an incorrect or invalid review comment.

src/axolotl/loaders/utils.py (1)

62-75: Early exit hides subsequent validation when using compressed-tensors

Returning immediately for compressed-tensors skips the LoRA/token-addition guard (lines 117-132). Users adding new tokens to a compressed-tensors model will not receive the helpful validation error. Consider moving the LoRA check above the return or explicitly re-running it.

src/axolotl/loaders/patch_manager.py (1)

92-99: Side-effect: silently flips flash_attention flag

_apply_flash_attention_patches sets self.cfg.flash_attention = True when xformers_attention & sample_packing are enabled. This mutates user-supplied config, potentially leading to surprising behaviour later in the run (e.g., different kernels selected).

Prefer deriving a local variable or logging a warning so the user knows their config was overridden.

src/axolotl/loaders/adapter.py (2)

89-103: use_dora / use_rslora keys may not exist in upstream PEFT

LoraConfig from PEFT < 0.9 doesn’t accept use_dora or use_rslora. Passing them via **lora_config_kwargs raises TypeError.
Guard by checking inspect.signature(LoraConfig).parameters before injecting.


124-131: ⚠️ Potential issue

setup_quantized_meta_for_peft relies on LOCAL_RANK env

If LOCAL_RANK is unset (common in single-GPU runs), int(None) raises TypeError. Use os.getenv(..., "0") default.

-    rank = int(os.environ.get("LOCAL_RANK", 0))
+    rank = int(os.environ.get("LOCAL_RANK", "0"))

Likely an incorrect or invalid review comment.

src/axolotl/loaders/model.py (1)

656-660: Hard-coding CUDA device breaks non-CUDA environments

self.model_kwargs["device"] = torch.cuda.current_device()

If the process runs on CPU, MPS or NPU, torch.cuda may be unavailable or cuda.is_available() may be False, leading to an exception. Consider deriving the device from get_device_type() or guarding the assignment:

device_type = get_device_type()
if device_type == "cuda":
    self.model_kwargs["device"] = torch.cuda.current_device()

This keeps the path functional on all accelerators.

Comment thread src/axolotl/loaders/constants.py
Comment thread src/axolotl/loaders/processor.py
Comment thread src/axolotl/loaders/processor.py
Comment thread src/axolotl/loaders/tokenizer.py
Comment thread src/axolotl/loaders/utils.py
Comment thread src/axolotl/loaders/patch_manager.py Outdated
Comment thread src/axolotl/loaders/model.py
Comment thread src/axolotl/loaders/model.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (1)
src/axolotl/loaders/model.py (1)

151-157: Consider deprecating post_model_build in favor of pre_lora_load.

As noted in previous comments, using two separate hooks (post_model_build and pre_lora_load) that execute at nearly the same point in the lifecycle creates potential confusion.

-        PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
-
-        # Post-build model configuration
-        self._apply_post_model_load_setup()
-
-        # Load adapters (LoRA, etc.)
-        PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model)
+        # Post-build model configuration before loading adapters
+        self._apply_post_model_load_setup()
+        
+        # Load adapters (LoRA, etc.)
+        PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model)
+        # For backward compatibility
+        PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
🧹 Nitpick comments (8)
src/axolotl/loaders/model.py (8)

206-209: Simplify boolean expression.

Use != instead of not ... == for better readability.

-            if self.cfg.mean_resizing_embeddings is not None and not (
-                self.model_config.model_type == "llava"
-            ):
+            if self.cfg.mean_resizing_embeddings is not None and self.model_config.model_type != "llava":
🧰 Tools
🪛 Ruff (0.11.9)

206-208: Use self.model_config.model_type != "llava" instead of not self.model_config.model_type == "llava"

Replace with != operator

(SIM201)


335-336: Remove or clarify TODO comment.

The TODO comment doesn't provide sufficient context about what needs validation.

-            # TODO: validate this conditional
             self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")

Consider either removing the comment if the conditional is already validated, or add more specific details about what aspect needs validation.


339-340: Use direct attribute assignment instead of setattr.

setattr with constant attribute names doesn't provide any advantages over direct attribute assignment and makes the code less readable.

-            setattr(self.model, "is_parallelizable", True)
-            setattr(self.model, "model_parallel", True)
+            self.model.is_parallelizable = True
+            self.model.model_parallel = True
🧰 Tools
🪛 Ruff (0.11.9)

339-339: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)


340-340: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)


419-427: Dead code: consider removing commented-out code block.

This commented-out code block related to reference model GPU assignment serves no functional purpose and could confuse maintainers. Either remove it entirely or add a clear comment explaining why it's preserved.


566-703: Complex model building method could benefit from further modularization.

The _build_model method is quite lengthy (137 lines) with multiple conditional branches for different model types and loading strategies. Consider splitting it into smaller, more focused methods to improve maintainability.

For example:

  1. _build_sharded_quantized_model
  2. _build_llama_model
  3. _build_mamba_model
  4. _build_generic_model
🧰 Tools
🪛 Ruff (0.11.9)

580-580: Do not call getattr with a constant attribute value. It is not any safer than normal property access.

Replace getattr with attribute access

(B009)


580-583: Use direct attribute access instead of getattr.

getattr with constant attribute names doesn't provide any advantages over direct attribute access and makes the code less readable.

-            quantization_config = hasattr(
-                self.model_config, "quantization_config"
-            ) and getattr(self.model_config, "quantization_config")
+            quantization_config = hasattr(
+                self.model_config, "quantization_config"
+            ) and self.model_config.quantization_config
🧰 Tools
🪛 Ruff (0.11.9)

580-580: Do not call getattr with a constant attribute value. It is not any safer than normal property access.

Replace getattr with attribute access

(B009)


643-643: Improve variable naming to avoid pylint suppression.

The variable MambaLMHeadModel uses camel case, which violates Python naming conventions and requires a pylint suppression. Consider using snake_case instead:

-            MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
+            mamba_lm_head_model = fix_mamba_attn_for_loss()

And update the usage below:

-            self.model = MambaLMHeadModel.from_pretrained(
+            self.model = mamba_lm_head_model.from_pretrained(

766-769: Simplify nested if statements.

The nested if statements can be combined into a single condition for better readability.

-            if any(m in name for m in embedding_modules):
-                if hasattr(module, "weight"):
-                    module.to(dist_dtype)
+            if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
+                module.to(dist_dtype)
🧰 Tools
🪛 Ruff (0.11.9)

766-767: Use a single if statement instead of nested if statements

(SIM102)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between 8c3e980 and 0b8440d.

📒 Files selected for processing (1)
  • src/axolotl/loaders/model.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
src/axolotl/loaders/model.py

206-208: Use self.model_config.model_type != "llava" instead of not self.model_config.model_type == "llava"

Replace with != operator

(SIM201)


339-339: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)


340-340: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)


580-580: Do not call getattr with a constant attribute value. It is not any safer than normal property access.

Replace getattr with attribute access

(B009)


766-767: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms (7)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: pre-commit
🔇 Additional comments (1)
src/axolotl/loaders/model.py (1)

1-83: Well-organized model loader module with comprehensive documentation.

The ModelLoader class has excellent docstrings that clearly explain its purpose, lifecycle, and capabilities. This is especially important for a complex component that handles multiple model types, quantization schemes, and adapter integrations.

@djsaunde djsaunde force-pushed the model-load-refactor branch from 0b8440d to fe504b7 Compare May 16, 2025 17:46
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

♻️ Duplicate comments (2)
src/axolotl/loaders/utils.py (1)

90-97: ⚠️ Potential issue

quantization_config may not be a dict – AttributeError risk remains

model_config.quantization_config.get("quant_method") assumes the field is a mapping.
HF frequently stores this as a PretrainedConfig, dataclass, or custom object, and several users ran into AttributeError: 'PretrainedConfig' object has no attribute 'get' in earlier commits (see previous review). The same runtime failure is still possible here.

-    is_compressed_tensors_config = (
-        quant_config_exists
-        and model_config.quantization_config.get("quant_method") == "compressed-tensors"
-    )
+    qc = model_config.quantization_config
+    if quant_config_exists and not isinstance(qc, dict):
+        qc = qc.to_dict() if hasattr(qc, "to_dict") else vars(qc)
+
+    is_compressed_tensors_config = (
+        quant_config_exists and qc.get("quant_method") == "compressed-tensors"
+    )

Same normalisation should be applied a few lines below for the GPTQ check.

src/axolotl/loaders/model.py (1)

151-157: Consider deprecating post_model_build in favor of pre_lora_load.

As noted in a previous comment, it might be worth deprecating post_model_build and using only pre_lora_load, but this would need to be evaluated against potential impacts on third-party plugins.

🧹 Nitpick comments (6)
src/axolotl/loaders/utils.py (1)

193-200: Use contextlib.suppress for cleaner, faster attribute checks

The twin try/except AttributeError: pass blocks slightly inflate byte-code and hide other potential errors. Ruff already flagged this (SIM105). A concise alternative:

-import torch
+import torch, contextlib
 ...
-        try:
-            weight_mismatch = module.weight.dtype != dtype
-        except AttributeError:
-            pass
-        try:
-            bias_mismatch = module.bias.dtype != dtype
-        except AttributeError:
-            pass
+        with contextlib.suppress(AttributeError):
+            weight_mismatch = module.weight.dtype != dtype
+        with contextlib.suppress(AttributeError):
+            bias_mismatch = module.bias.dtype != dtype

Behaviour is identical, code is clearer.

🧰 Tools
🪛 Ruff (0.11.9)

193-196: Use contextlib.suppress(AttributeError) instead of try-except-pass

Replace with contextlib.suppress(AttributeError)

(SIM105)


197-200: Use contextlib.suppress(AttributeError) instead of try-except-pass

Replace with contextlib.suppress(AttributeError)

(SIM105)

src/axolotl/loaders/patch_manager.py (1)

128-136: Nested if can be flattened for clarity (SIM102)

Combining the llama4 model-type branch makes intent clearer and removes one indentation level:

-        if self.cfg.model_config_type == "llama4":
-            if self.cfg.llama4_linearized_experts:
+        if (
+            self.cfg.model_config_type == "llama4"
+            and self.cfg.llama4_linearized_experts
+        ):

Pure readability—behaviour unchanged.

🧰 Tools
🪛 Ruff (0.11.9)

128-129: Use a single if statement instead of nested if statements

(SIM102)

src/axolotl/loaders/model.py (4)

214-241: Consider refactoring repeated attribute checks.

The _adjust_model_config method has multiple similar conditional blocks checking for attributes. This could be refactored into a helper function to reduce repetition.

-        if (
-            hasattr(self.model, "config")
-            and hasattr(self.model.config, "max_position_embeddings")
-            and self.model.config.max_position_embeddings
-            and self.cfg.sequence_len > self.model.config.max_position_embeddings
-        ):
-            LOG.warning(
-                f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}"
-            )
-            self.model.config.max_position_embeddings = self.cfg.sequence_len
-
-        if (
-            hasattr(self.model, "config")
-            and hasattr(self.model.config, "bos_token_id")
-            and self.model.config.bos_token_id
-            and self.model.config.bos_token_id != self.tokenizer.bos_token_id
-        ):
-            self.model.config.bos_token_id = self.tokenizer.bos_token_id
-
-        if (
-            hasattr(self.model, "config")
-            and hasattr(self.model.config, "eos_token_id")
-            and self.model.config.eos_token_id
-            and self.model.config.eos_token_id != self.tokenizer.eos_token_id
-        ):
-            self.model.config.eos_token_id = self.tokenizer.eos_token_id
+        def update_config_if_exists(attr_name, new_value, condition=None, log_msg=None):
+            if (
+                hasattr(self.model, "config")
+                and hasattr(self.model.config, attr_name)
+                and getattr(self.model.config, attr_name)
+                and (condition is None or condition(getattr(self.model.config, attr_name)))
+            ):
+                if log_msg:
+                    LOG.warning(log_msg)
+                setattr(self.model.config, attr_name, new_value)
+
+        update_config_if_exists(
+            "max_position_embeddings",
+            self.cfg.sequence_len,
+            lambda x: self.cfg.sequence_len > x,
+            f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}"
+        )
+
+        update_config_if_exists(
+            "bos_token_id",
+            self.tokenizer.bos_token_id,
+            lambda x: x != self.tokenizer.bos_token_id
+        )
+
+        update_config_if_exists(
+            "eos_token_id", 
+            self.tokenizer.eos_token_id,
+            lambda x: x != self.tokenizer.eos_token_id
+        )

419-427: Consider implementing or removing the commented code.

There's commented-out code about placing reference models on separate GPUs. Either implement this feature (after careful testing) or remove the comments if they're no longer relevant.


505-506: Clarify or remove the comment about sample packing.

The comment "Sample packing uses custom FA2 patch" seems out of place or incomplete, as it doesn't directly relate to what the method does. Consider clarifying this comment or removing it.


566-703: Consider refactoring the complex model building method.

The _build_model method is quite complex with multiple nested conditionals for different model types and configurations. Consider breaking it down into smaller, more focused methods for each model type or configuration pattern.

For example:

def _build_model(self) -> bool:
    """Load model, with load strategy depending on config."""
    skip_move_to_device = False
    
    # Dispatch to specialized builders based on model type/config
    if self._should_use_sharded_qlora_loading():
        skip_move_to_device = self._build_sharded_qlora_model()
    elif self._is_llama_model_without_special_requirements():
        skip_move_to_device = self._build_llama_model()
    elif self._is_mamba_model():
        skip_move_to_device = self._build_mamba_model()
    elif self._is_custom_model_type():
        skip_move_to_device = self._build_custom_model_type()
    else:
        skip_move_to_device = self._build_standard_model()
    
    if is_deepspeed_zero3_enabled():
        skip_move_to_device = True
        
    return skip_move_to_device
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between 0b8440d and fe504b7.

📒 Files selected for processing (27)
  • src/axolotl/cli/utils.py (2 hunks)
  • src/axolotl/common/datasets.py (1 hunks)
  • src/axolotl/core/trainer_builder.py (1 hunks)
  • src/axolotl/integrations/base.py (14 hunks)
  • src/axolotl/loaders/__init__.py (1 hunks)
  • src/axolotl/loaders/adapter.py (1 hunks)
  • src/axolotl/loaders/constants.py (1 hunks)
  • src/axolotl/loaders/model.py (1 hunks)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
  • src/axolotl/loaders/processor.py (1 hunks)
  • src/axolotl/loaders/tokenizer.py (1 hunks)
  • src/axolotl/loaders/utils.py (1 hunks)
  • src/axolotl/monkeypatch/peft/utils.py (1 hunks)
  • src/axolotl/train.py (3 hunks)
  • src/axolotl/utils/config/__init__.py (1 hunks)
  • src/axolotl/utils/data/rl.py (1 hunks)
  • src/axolotl/utils/lora_embeddings.py (0 hunks)
  • src/axolotl/utils/models.py (0 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • tests/core/test_trainer_builder.py (2 hunks)
  • tests/e2e/patched/test_model_patches.py (3 hunks)
  • tests/e2e/test_load_model.py (3 hunks)
  • tests/patched/test_validation.py (2 hunks)
  • tests/test_exact_deduplication.py (6 hunks)
  • tests/test_loaders.py (3 hunks)
  • tests/test_lora.py (3 hunks)
  • tests/test_tokenizers.py (1 hunks)
💤 Files with no reviewable changes (2)
  • src/axolotl/utils/lora_embeddings.py
  • src/axolotl/utils/models.py
✅ Files skipped from review due to trivial changes (2)
  • src/axolotl/utils/config/init.py
  • src/axolotl/common/datasets.py
🚧 Files skipped from review as they are similar to previous changes (19)
  • src/axolotl/utils/data/rl.py
  • src/axolotl/core/trainer_builder.py
  • tests/core/test_trainer_builder.py
  • tests/test_lora.py
  • tests/test_tokenizers.py
  • src/axolotl/loaders/constants.py
  • tests/e2e/patched/test_model_patches.py
  • src/axolotl/cli/utils.py
  • tests/test_loaders.py
  • tests/test_exact_deduplication.py
  • src/axolotl/monkeypatch/peft/utils.py
  • tests/e2e/test_load_model.py
  • src/axolotl/utils/schemas/config.py
  • tests/patched/test_validation.py
  • src/axolotl/loaders/init.py
  • src/axolotl/train.py
  • src/axolotl/loaders/processor.py
  • src/axolotl/loaders/tokenizer.py
  • src/axolotl/integrations/base.py
🧰 Additional context used
🧬 Code Graph Analysis (3)
src/axolotl/loaders/model.py (9)
src/axolotl/integrations/base.py (13)
  • PluginManager (290-576)
  • cfg (331-332)
  • cfg (335-336)
  • pre_model_load (92-98)
  • pre_model_load (393-401)
  • post_model_build (100-108)
  • post_model_build (403-413)
  • pre_lora_load (121-130)
  • pre_lora_load (427-436)
  • post_lora_load (132-141)
  • post_lora_load (438-447)
  • post_model_load (110-119)
  • post_model_load (415-425)
src/axolotl/loaders/adapter.py (2)
  • load_adapter (164-181)
  • load_lora (71-161)
src/axolotl/loaders/patch_manager.py (4)
  • PatchManager (27-353)
  • has_flash_attn (45-47)
  • apply_pre_model_load_patches (49-64)
  • apply_post_model_load_patches (66-90)
src/axolotl/loaders/utils.py (3)
  • get_linear_embedding_layers (210-216)
  • get_module_class_from_name (16-41)
  • load_model_config (135-183)
src/axolotl/models/mamba/__init__.py (1)
  • fix_mamba_attn_for_loss (16-24)
src/axolotl/utils/bench.py (1)
  • log_gpu_memory_usage (94-111)
src/axolotl/utils/distributed.py (2)
  • get_device_count (33-39)
  • get_device_type (22-30)
src/axolotl/utils/schemas/enums.py (1)
  • RLType (6-14)
src/axolotl/models/mamba/modeling_mamba.py (3)
  • tie_weights (61-62)
  • from_pretrained (122-128)
  • MambaLMHeadModel (17-128)
src/axolotl/loaders/adapter.py (3)
src/axolotl/utils/schemas/peft.py (3)
  • LoftQConfig (8-14)
  • LoraConfig (23-122)
  • PeftConfig (17-20)
src/axolotl/loaders/utils.py (1)
  • get_linear_embedding_layers (210-216)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/loaders/utils.py (4)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
tests/test_exact_deduplication.py (1)
  • cfg (221-235)
src/axolotl/models/mamba/modeling_mamba.py (1)
  • from_pretrained (122-128)
tests/test_perplexity.py (1)
  • model (20-21)
🪛 Ruff (0.11.9)
src/axolotl/loaders/patch_manager.py

128-129: Use a single if statement instead of nested if statements

(SIM102)

src/axolotl/loaders/utils.py

193-196: Use contextlib.suppress(AttributeError) instead of try-except-pass

Replace with contextlib.suppress(AttributeError)

(SIM105)


197-200: Use contextlib.suppress(AttributeError) instead of try-except-pass

Replace with contextlib.suppress(AttributeError)

(SIM105)

⏰ Context from checks skipped due to timeout of 90000ms (6)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
🔇 Additional comments (4)
src/axolotl/loaders/patch_manager.py (1)

295-301:

Details

❌ Incorrect review comment

self.cfg.device may be missing – guard before access

_patch_llama_flash_attention accesses self.cfg.device, but that field is not guaranteed by the schema shown in other modules. If absent, an AttributeError is raised during model-load.

-            if self.cfg.device not in ["mps", "cpu"] and not self.inference:
+            if getattr(self.cfg, "device", None) not in ["mps", "cpu"] and not self.inference:

Please verify that all config paths indeed define device; otherwise add a default ("cuda"/"cpu").


🏁 Script executed:

#!/bin/bash
# Find definitions of config classes to inspect schema for 'device'
rg -n "class .*Config" -n src/axolotl

# Search for occurrences of 'device' in configuration-related files
rg -n "device" -n src/axolotl

Length of output: 43317


No missing device attribute on self.cfg

The device field is declared in AxolotlInputConfig (see utils/schemas/config.py line 144) with a default of None, so accessing self.cfg.device cannot raise AttributeError. If you need to treat None differently (e.g. default to "cuda" or skip patching when unset), add an explicit None check rather than guarding against a missing attribute.

Likely an incorrect or invalid review comment.

src/axolotl/loaders/model.py (3)

122-123: Good job with default initialization.

Setting self.auto_model_loader = AutoModelForCausalLM as the default is a sensible choice, as it's later overridden for multimodal models in _set_auto_model_loader.


138-167: Well-structured orchestration method.

The load() method nicely orchestrates the entire loading process with clear stages, properly integrating with the plugin system at each stage. The method signature correctly indicates that the returned LoRA config can be None.


647-648: Good job fixing the potential KeyError issue.

The previous implementation used del which could raise KeyError if the keys didn't exist. The current implementation uses .pop(key, None) which safely handles missing keys.

Comment thread src/axolotl/loaders/adapter.py
Comment thread src/axolotl/loaders/adapter.py
Comment thread src/axolotl/loaders/model.py Outdated
Comment thread src/axolotl/loaders/model.py
@djsaunde djsaunde force-pushed the model-load-refactor branch from 75a1284 to a2c5720 Compare May 21, 2025 22:54
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

♻️ Duplicate comments (6)
src/axolotl/loaders/tokenizer.py (1)

48-51: 🛠️ Refactor suggestion

Guard against cfg.output_dir being None before os.path.join

output_dir originates from the run-time configuration.
If the user omits --output_dir, os.path.join(None, "tokenizer")
will raise a TypeError, breaking training/inference when
added_tokens_overrides is set.

-    tokenizer_dir = os.path.join(output_dir, "tokenizer")
+    if output_dir is None:
+        import tempfile
+        output_dir = tempfile.mkdtemp(prefix="axolotl_tokenizer_")
+        LOG.warning(
+            "`cfg.output_dir` is None – writing modified tokenizer to %s", output_dir
+        )
+    tokenizer_dir = os.path.join(output_dir, "tokenizer")
src/axolotl/loaders/utils.py (1)

90-97: quantization_config may not be a plain dict – avoid .get() directly
Duplicate of a previously-raised issue.

model_config.quantization_config can be a nested PretrainedConfig
or a dataclass. Calling .get() will raise AttributeError.

Refactor to normalise the object first:

-    is_compressed_tensors_config = (
-        quant_config_exists
-        and model_config.quantization_config.get("quant_method") == "compressed-tensors"
-    )
+    qc = model_config.quantization_config
+    if quant_config_exists and not isinstance(qc, dict):
+        qc = qc.to_dict() if hasattr(qc, "to_dict") else vars(qc)
+
+    is_compressed_tensors_config = (
+        quant_config_exists and qc.get("quant_method") == "compressed-tensors"
+    )
src/axolotl/loaders/adapter.py (2)

51-68: Prefer full module paths and explicit output-head filtering in find_all_linear_names

The current implementation has two issues:

  1. It keeps only the last token of the module name which might lead to missing intended targets
  2. It blindly removes output_embedding which might hide an intentional request to wrap that head

The solution is to return the full dotted path rather than only the last token and let the caller decide how to shorten it. Also, consider making the output embedding filtering optional.

-def find_all_linear_names(model):
+def find_all_linear_names(
+    model,
+    filter_output: bool = True,
+    warn_on_drop: bool = False,
+) -> list[str]:
     cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
     lora_module_names = set()
     for name, module in model.named_modules():
         if (
             isinstance(module, cls)
             or "Linear" in module.__class__.__name__
             and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
         ):
-            names = name.split(".")
-            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+            # keep full dotted path
+            lora_module_names.add(name)
 
     embedding_modules = get_linear_embedding_layers(model.config.model_type)
     output_embedding = embedding_modules[1]
-    if output_embedding in lora_module_names:  # needed for 16-bit
+    # Optionally filter out the model's output-head
+    if output_embedding in lora_module_names and filter_output:
+        if warn_on_drop:
+            LOG.warning(f"Dropping output-embedding layer '{output_embedding}' from LoRA targets")
         lora_module_names.remove(output_embedding)
 
     return list(lora_module_names)

104-116: Ensure r and lora_alpha have valid defaults

PEFT's LoraConfig requires positive integers for r and lora_alpha; passing None will trigger a ValueError. Guard against unset YAML values by providing sensible defaults or omitting the arguments when they're None.

     lora_config = LoraConfig(
-        r=cfg.lora_r,
-        lora_alpha=cfg.lora_alpha,
+        r=cfg.lora_r or 8,            # fallback to 8 if unset
+        lora_alpha=cfg.lora_alpha or 16,  # fallback to 16 if unset
         target_modules=lora_target_modules,
         layers_to_transform=cfg.peft_layers_to_transform,
         layers_pattern=cfg.peft_layers_pattern,
         lora_dropout=cfg.lora_dropout,
         fan_in_fan_out=cfg.lora_fan_in_fan_out,
         modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
         bias="none",
         task_type="CAUSAL_LM",
         **lora_config_kwargs,
     )
src/axolotl/loaders/model.py (2)

340-341: Resolve device-placement conditional and remove TODO

The current code unconditionally appends :{self.cfg.local_rank} to the device string—even when not in DDP or on backends that don't support indexing (e.g. MPS).

Replace the TODO with a clear conditional that properly handles different device types:

-            # TODO: validate this conditional
-            self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
+            device = get_device_type()
+            # only append ":local_rank" for CUDA multi-GPU setups
+            if getattr(self.cfg, "world_size", 1) > 1 and device.type == "cuda":
+                target = f"{device}:{self.cfg.local_rank}"
+            else:
+                target = str(device)
+            self.model.to(target)

631-643: Improve the Mamba model loading implementation

The code contains a FIXME comment acknowledging that the Mamba model loading is "janky at best and hacked together".

This implementation uses a workaround with fix_mamba_attn_for_loss() and has special handling for model kwargs, which could lead to unexpected behavior if the underlying model implementation changes. Consider refactoring this in a future PR for better maintainability and reliability.

🧹 Nitpick comments (6)
src/axolotl/loaders/tokenizer.py (1)

87-88: Variable shadowing reduces readability

Re-using the name tokenizer_path for the JSON file path
(tokenizer.json) shadows the original tokenizer_path argument.
This is harmless but confusing when reading or debugging.

-        tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
+        tokenizer_json_path = os.path.join(tokenizer_dir, "tokenizer.json")
src/axolotl/loaders/utils.py (2)

186-205: Use contextlib.suppress for cleaner attribute-checks

The nested try/except AttributeError blocks can be simplified:

+import contextlib
 ...
-        try:
-            weight_mismatch = module.weight.dtype != dtype
-        except AttributeError:
-            pass
-        try:
-            bias_mismatch = module.bias.dtype != dtype
-        except AttributeError:
-            pass
+        with contextlib.suppress(AttributeError):
+            weight_mismatch = module.weight.dtype != dtype
+        with contextlib.suppress(AttributeError):
+            bias_mismatch = module.bias.dtype != dtype

The intent remains clear while reducing boilerplate.

🧰 Tools
🪛 Ruff (0.11.9)

190-193: Use contextlib.suppress(AttributeError) instead of try-except-pass

Replace with contextlib.suppress(AttributeError)

(SIM105)


194-197: Use contextlib.suppress(AttributeError) instead of try-except-pass

Replace with contextlib.suppress(AttributeError)

(SIM105)


161-175: Catching only ValueError misses common HF config errors

AutoConfig.from_pretrained also raises OSError (e.g., missing files)
and EnvironmentError. Consider broadening the exception list so
special-case handling for Mamba isn’t skipped on those errors.

-    except ValueError as err:
+    except (ValueError, OSError) as err:
src/axolotl/loaders/patch_manager.py (2)

112-118: Simplify nested if statements

The nested if statements can be combined for better readability and to address the static analysis warning.

-        if self.cfg.model_config_type == "llama4":
-            if self.cfg.llama4_linearized_experts:
-                from axolotl.monkeypatch.models.llama4.modeling import (
-                    patch_llama4_linearized_modeling,
-                )
-
-                patch_llama4_linearized_modeling()
+        if self.cfg.model_config_type == "llama4" and self.cfg.llama4_linearized_experts:
+            from axolotl.monkeypatch.models.llama4.modeling import (
+                patch_llama4_linearized_modeling,
+            )
+
+            patch_llama4_linearized_modeling()
🧰 Tools
🪛 Ruff (0.11.9)

112-113: Use a single if statement instead of nested if statements

(SIM102)


351-371: Simplify nested if statements

Similar to the previous suggestion, consider combining these nested if statements to address the static analysis warning.

-        if (
-            self.model_config.model_type in ["llama", "llama4"]
-            and not self.cfg.trust_remote_code
-            and not self.cfg.gptq
-        ):
-            # TODO(MengqingCao): split these patches seperately
-            if self.cfg.flash_attention and not self.inference:
+        if (
+            self.model_config.model_type in ["llama", "llama4"]
+            and not self.cfg.trust_remote_code
+            and not self.cfg.gptq
+            and self.cfg.flash_attention 
+            and not self.inference
+        ):
+            # TODO(MengqingCao): split these patches seperately
🧰 Tools
🪛 Ruff (0.11.9)

351-357: Use a single if statement instead of nested if statements

(SIM102)

src/axolotl/loaders/model.py (1)

658-662: Use safer key removal for model kwargs

The current implementation unconditionally deletes keys from self.model_kwargs, which could raise KeyError if the keys don't exist.

-                    del self.model_kwargs["torch_dtype"]
-                    del self.model_kwargs["device_map"]
+                    self.model_kwargs.pop("torch_dtype", None)
+                    self.model_kwargs.pop("device_map", None)

Using pop with a default value of None makes the code more robust by avoiding KeyError if the key is already removed.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 75a1284 and a2c5720.

📒 Files selected for processing (29)
  • src/axolotl/cli/utils.py (2 hunks)
  • src/axolotl/common/datasets.py (1 hunks)
  • src/axolotl/core/trainer_builder.py (1 hunks)
  • src/axolotl/core/trainers/grpo/trainer.py (1 hunks)
  • src/axolotl/integrations/base.py (14 hunks)
  • src/axolotl/loaders/__init__.py (1 hunks)
  • src/axolotl/loaders/adapter.py (1 hunks)
  • src/axolotl/loaders/constants.py (1 hunks)
  • src/axolotl/loaders/model.py (1 hunks)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
  • src/axolotl/loaders/processor.py (1 hunks)
  • src/axolotl/loaders/tokenizer.py (1 hunks)
  • src/axolotl/loaders/utils.py (1 hunks)
  • src/axolotl/monkeypatch/peft/utils.py (1 hunks)
  • src/axolotl/train.py (3 hunks)
  • src/axolotl/utils/config/__init__.py (1 hunks)
  • src/axolotl/utils/ctx_managers/sequence_parallel.py (1 hunks)
  • src/axolotl/utils/data/rl.py (1 hunks)
  • src/axolotl/utils/lora_embeddings.py (0 hunks)
  • src/axolotl/utils/models.py (0 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • tests/core/test_trainer_builder.py (2 hunks)
  • tests/e2e/patched/test_model_patches.py (3 hunks)
  • tests/e2e/test_load_model.py (3 hunks)
  • tests/patched/test_validation.py (2 hunks)
  • tests/test_exact_deduplication.py (6 hunks)
  • tests/test_loaders.py (3 hunks)
  • tests/test_lora.py (3 hunks)
  • tests/test_tokenizers.py (1 hunks)
💤 Files with no reviewable changes (2)
  • src/axolotl/utils/lora_embeddings.py
  • src/axolotl/utils/models.py
✅ Files skipped from review due to trivial changes (4)
  • src/axolotl/utils/data/rl.py
  • tests/test_tokenizers.py
  • src/axolotl/utils/ctx_managers/sequence_parallel.py
  • tests/test_exact_deduplication.py
🚧 Files skipped from review as they are similar to previous changes (17)
  • src/axolotl/common/datasets.py
  • src/axolotl/utils/config/init.py
  • src/axolotl/core/trainer_builder.py
  • tests/core/test_trainer_builder.py
  • tests/e2e/patched/test_model_patches.py
  • src/axolotl/train.py
  • tests/test_lora.py
  • tests/patched/test_validation.py
  • src/axolotl/cli/utils.py
  • src/axolotl/loaders/constants.py
  • src/axolotl/monkeypatch/peft/utils.py
  • tests/e2e/test_load_model.py
  • tests/test_loaders.py
  • src/axolotl/loaders/init.py
  • src/axolotl/loaders/processor.py
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/integrations/base.py
🧰 Additional context used
🧬 Code Graph Analysis (2)
src/axolotl/loaders/adapter.py (3)
src/axolotl/utils/schemas/peft.py (3)
  • LoftQConfig (8-14)
  • LoraConfig (23-122)
  • PeftConfig (17-20)
src/axolotl/loaders/utils.py (1)
  • get_linear_embedding_layers (207-213)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/loaders/utils.py (4)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
tests/test_exact_deduplication.py (1)
  • cfg (221-235)
src/axolotl/models/mamba/modeling_mamba.py (1)
  • from_pretrained (122-128)
tests/test_perplexity.py (1)
  • model (20-21)
🪛 Ruff (0.11.9)
src/axolotl/loaders/patch_manager.py

112-113: Use a single if statement instead of nested if statements

(SIM102)


351-357: Use a single if statement instead of nested if statements

(SIM102)

src/axolotl/loaders/utils.py

190-193: Use contextlib.suppress(AttributeError) instead of try-except-pass

Replace with contextlib.suppress(AttributeError)

(SIM105)


194-197: Use contextlib.suppress(AttributeError) instead of try-except-pass

Replace with contextlib.suppress(AttributeError)

(SIM105)

⏰ Context from checks skipped due to timeout of 90000ms (6)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
  • GitHub Check: PyTest (3.11, 2.5.1)
🔇 Additional comments (10)
src/axolotl/core/trainers/grpo/trainer.py (1)

46-46: Confirm ring-attention import path exists in refactored package

The get_ring_attn_group symbol has moved from axolotl.monkeypatch.ring_attn.patch
to axolotl.monkeypatch.ring_attn. Please double-check that

  1. src/axolotl/monkeypatch/ring_attn/__init__.py re-exports get_ring_attn_group
    (or defines it directly), and
  2. All other call-sites have been updated in the PR.

A quick rg get_ring_attn_group scan in the repo branch will reveal any stale
imports.

src/axolotl/loaders/patch_manager.py (3)

356-356: Address the TODO comment in a future PR

The TODO comment indicates that these patches should be split separately. Consider creating a follow-up task to refactor this method for better organization and maintainability.

This would involve extracting the Flash Attention-specific patches into their own dedicated methods in the PatchManager class to improve code organization. Has this been captured as a future task?


345-347: Good error handling for unsupported configuration

The implementation correctly raises a NotImplementedError when shifted-sparse attention is requested without flash attention, making it clear to users that this combination is not supported.


206-216: Well-handled auto_map configuration

The implementation now correctly handles auto_map access for both dictionary-like and object configs, addressing the previous review comment. This robust implementation prevents potential AttributeError issues.

src/axolotl/loaders/adapter.py (2)

146-152: Robust error handling when printing trainable parameters

The code properly catches and logs exceptions that might occur during model.print_trainable_parameters(), ensuring the process continues even if this non-critical operation fails.


123-142: Good support for CPU-based LoRA loading

The implementation supports loading LoRA weights onto CPU with proper memory settings when cfg.lora_on_cpu is set, which is important for large models that don't fit in GPU memory.

src/axolotl/loaders/model.py (4)

154-156: Consider deprecating post_model_build in favor of pre_lora_load

According to a past review comment by the developer, post_model_build could potentially be deprecated in favor of pre_lora_load.

Before making this change, consider the implications for any third-party plugins that might rely on the current hook structure. What is the plan for deprecating post_model_build?


167-168: Well-structured plugin hooks

The code provides well-defined plugin hooks at different stages of the model loading process (pre_model_load, post_model_build, pre_lora_load, post_lora_load, post_model_load), allowing for extensible customization.


25-52: Comprehensive imports and organization

The file has well-organized imports, grouping standard library, third-party, and internal imports appropriately. This makes the code more readable and maintainable.


126-130: Good integration with PatchManager

The ModelLoader correctly initializes and uses the PatchManager for applying patches, which centralizes patch management and improves code organization.

Comment thread src/axolotl/loaders/patch_manager.py
Comment thread src/axolotl/loaders/patch_manager.py
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (4)
src/axolotl/loaders/utils.py (2)

90-94: ⚠️ Potential issue

Fix potential AttributeError with non-dict quantization_config

The code assumes that model_config.quantization_config is a dict-like object with a .get() method, but it could be a nested PretrainedConfig or a custom dataclass which would cause an AttributeError.

-is_compressed_tensors_config = (
-    quant_config_exists
-    and model_config.quantization_config.get("quant_method") == "compressed-tensors"
-)
+qc = model_config.quantization_config
+if quant_config_exists and not isinstance(qc, dict):
+    qc = qc.to_dict() if hasattr(qc, "to_dict") else vars(qc)
+
+is_compressed_tensors_config = (
+    quant_config_exists and qc.get("quant_method") == "compressed-tensors"
+)

105-109: ⚠️ Potential issue

Same issue with quant_config type checking

Similar to the previous issue, this code also assumes that model_config.quantization_config is a dict-like object. The same fix should be applied here to avoid AttributeError.

-quant_config_method_is_gptq = (
-    quant_config_exists
-    and "quant_method" in model_config.quantization_config
-    and model_config.quantization_config["quant_method"] == "gptq"
-)
+quant_config_method_is_gptq = False
+if quant_config_exists:
+    qc = model_config.quantization_config
+    if not isinstance(qc, dict):
+        qc = qc.to_dict() if hasattr(qc, "to_dict") else vars(qc)
+    
+    quant_config_method_is_gptq = (
+        "quant_method" in qc and qc["quant_method"] == "gptq"
+    )
src/axolotl/loaders/patch_manager.py (2)

298-305: ⚠️ Potential issue

Missing flash attention check for s2_attention

The s2_attention path doesn't check if flash attention is available via self.has_flash_attn before applying flash attention patches, unlike other code paths. This could lead to runtime errors if flash attention is not installed.

        elif self.cfg.s2_attention:
+            if not self.has_flash_attn:
+                raise ValueError("Shifted-sparse attention requires flash attention to be installed")
            LOG.info("patching w/ flash-enabled, shifted-sparse attention")
            replace_llama_attn_with_flash_attn(
                packed=False,
                cross_entropy=self.cfg.flash_attn_cross_entropy,
                rms_norm=self.cfg.flash_attn_rms_norm,
                use_shifted_sparse_attn=True,
            )

258-269: ⚠️ Potential issue

Potential import/usage mismatch for flash attention functions

The function patch_fa_llama_cross_entropy is imported conditionally when self.has_flash_attn is True, but then used with a separate condition check. If self.has_flash_attn is False but self.cfg.flash_attn_cross_entropy is True, this would cause a runtime error as the function wouldn't be defined.

To fix this, either adjust the import condition to match the usage, or adjust the usage condition to match the import:

-        if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
+        if self.has_flash_attn and self.cfg.flash_attn_cross_entropy:
            from axolotl.monkeypatch.llama_attn_hijack_flash import (
                patch_fa_llama_cross_entropy,
            )

            patch_fa_llama_cross_entropy()

Or alternatively, move the imports outside the conditional blocks:

-        if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
-            from axolotl.monkeypatch.llama_attn_hijack_flash import (
-                patch_fa_llama_cross_entropy,
-            )
+        # Always import these functions regardless of flash_attn availability
+        # to avoid runtime errors when referencing them
+        from axolotl.monkeypatch.llama_attn_hijack_flash import (
+            patch_fa_llama_cross_entropy,
+            patch_llama_rms_norm,
+        )
+        
+        if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
            patch_fa_llama_cross_entropy()
🧹 Nitpick comments (1)
src/axolotl/loaders/utils.py (1)

186-202: Consider using logging instead of print statements

The function uses print statements for logging dtype conversions. Consider using the logger (LOG) that's already defined in this module for consistency.

-            print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
+            LOG.debug(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
-            print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
+            LOG.debug(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a2c5720 and 74a2a74.

📒 Files selected for processing (2)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
  • src/axolotl/loaders/utils.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (6)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.5.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
🔇 Additional comments (8)
src/axolotl/loaders/utils.py (2)

16-41: Well-implemented recursive search function

The recursive module search function is cleanly implemented with proper type hints and docstrings.


135-183: Good error handling and config overrides implementation

This function properly handles special cases like Mamba models and provides clear error reporting. The config overrides mechanism is also well-implemented.

src/axolotl/loaders/patch_manager.py (6)

1-26: Good module structure and docstring

The file has a clear purpose, proper imports, and good organization. The docstring explains the module's role well.


28-75: Well-designed class structure with clear responsibility separation

The PatchManager class has a clean design with proper initialization, and clear separation between pre-model and post-model patching responsibilities.


200-232: Improved auto_map handling for PretrainedConfig

The code now correctly handles both dictionary-style and attribute-style access to auto_map with proper type checking. This is a good improvement.


179-199: Good error handling in sequence parallelism setup

The code checks if parallelism is already configured before applying patches, which prevents redundant setup. This is a good defensive programming practice.


351-391: Well-structured model patching with clear logging

The post-model patching methods are well-organized with proper logging of actions, making the code behavior transparent to users.


159-178:

Details

❌ Incorrect review comment

Potential inconsistency in Unsloth patching

The _apply_unsloth_self_attention_patch method is called during pre-model load patching, but the same functionality (patch_self_attn_lora) is also called in _patch_llama_derived_model at line 280. This could lead to the same patch being applied twice.

Please verify that this doesn't cause any issues, or refactor to ensure the patch is only applied once:


🏁 Script executed:

#!/bin/bash
# Check for duplicate patch applications
echo "Checking for duplicate patch_self_attn_lora calls:"
rg -p "patch_self_attn_lora" --type python

Length of output: 184


Here’s a more reliable check for both occurrences and context:

#!/bin/bash
set -e

echo "1) All occurrences of patch_self_attn_lora:"
rg -n "patch_self_attn_lora"

echo -e "\n2) Context around _patch_llama_derived_model in patch_manager.py:"
rg -n "def _patch_llama_derived_model" -A20 src/axolotl/loaders/patch_manager.py || true

Let’s get a complete picture of where and how many times patch_self_attn_lora is applied, and the surrounding context in _patch_llama_derived_model:

#!/bin/bash
set -e

echo "1) All occurrences of patch_self_attn_lora:"
rg -n "patch_self_attn_lora" .

echo -e "\n2) Context around _patch_llama_derived_model in patch_manager.py:"
rg -n "def _patch_llama_derived_model" -A20 src/axolotl/loaders/patch_manager.py

Let’s list all patch_self_attn_lora calls and see how they’re used in both methods:

#!/bin/bash
set -e

echo "1) All occurrences of patch_self_attn_lora:"
rg -n "patch_self_attn_lora" -g "*.py"

echo -e "\n2) Context around _apply_unsloth_self_attention_patch:"
rg -n -C5 "_apply_unsloth_self_attention_patch" src/axolotl/loaders/patch_manager.py

echo -e "\n3) Context around _patch_llama_derived_model:"
rg -n -C5 "_patch_llama_derived_model" src/axolotl/loaders/patch_manager.py

🏁 Script executed:

#!/bin/bash
set -e

echo "1) All occurrences of patch_self_attn_lora (with surrounding context):"
rg -n -C3 "patch_self_attn_lora" -g "*.py"

echo -e "\n2) Context around _apply_unsloth_self_attention_patch in patch_manager.py:"
rg -n -C5 "_apply_unsloth_self_attention_patch" src/axolotl/loaders/patch_manager.py

echo -e "\n3) Context around _patch_llama_derived_model in patch_manager.py:"
rg -n -C20 "_patch_llama_derived_model" src/axolotl/loaders/patch_manager.py

Length of output: 9016


Incorrect duplication concern: distinct patch functions

The two patch_self_attn_lora calls import entirely different functions and apply separate patches:

  • In _apply_unsloth_self_attention_patch, it imports
    axolotl.monkeypatch.lora_kernels.patch_self_attn_lora(cfg)
    to inject optimized LoRA kernel logic based on the config.
  • In _patch_llama_derived_model, it imports
    axolotl.monkeypatch.unsloth_.patch_self_attn_lora()
    to apply Unsloth-specific attention patches, guarded by a global flag to ensure idempotence.

Because these are different modules and serve different patching responsibilities, there’s no double-application risk.

Likely an incorrect or invalid review comment.

Copy link
Copy Markdown
Collaborator

@NanoCode012 NanoCode012 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First review pass

Do we need to add loaders to quarto and remove any modules?

Comment thread src/axolotl/integrations/base.py Outdated
Comment thread src/axolotl/integrations/base.py Outdated
@djsaunde djsaunde force-pushed the model-load-refactor branch from 5c24934 to 8fd7325 Compare May 22, 2025 13:37
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Inline review comments failed to post. This is likely due to GitHub's limits when posting large numbers of comments. If you are seeing this consistently it is likely a permissions issue. Please check "Moderation" -> "Code review limits" under your organization settings.

Actionable comments posted: 1

♻️ Duplicate comments (3)
src/axolotl/loaders/model.py (3)

155-161: Consider deprecating post_model_build in favor of pre_lora_load

The plugin hooks post_model_build and pre_lora_load are called sequentially with no intervening code. As noted in a previous comment, consider deprecating post_model_build and standardizing on pre_lora_load.

If there are concerns about breaking third-party plugins, consider adding a deprecation warning when post_model_build is used, directing users to migrate to pre_lora_load.


332-341: 🛠️ Refactor suggestion

Validate and improve the device mapping conditional

The current implementation unconditionally appends :{self.cfg.local_rank} to the device string, which may not be appropriate for all device types (like MPS) or non-DDP setups.

- # TODO: validate this conditional
- self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
+ device = get_device_type()
+ # Only append device index for CUDA multi-GPU setups
+ if device.type == "cuda" and getattr(self.cfg, "world_size", 1) > 1:
+     target = f"{device}:{self.cfg.local_rank}"
+ else:
+     target = str(device)
+ self.model.to(target)

568-574: ⚠️ Potential issue

Remove dead code branch with and False condition

This conditional will never evaluate to True due to the and False clause, making this entire FSDP-sharded path unreachable.

- if (
-     (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
-     and not self.qlora_fsdp
-     and False
- ):
+ if (
+     self.cfg.fsdp 
+     and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+     and not self.qlora_fsdp
+ ):

If this branch is no longer needed, delete the entire block instead to improve code clarity.

🧹 Nitpick comments (2)
src/axolotl/loaders/model.py (2)

516-534: Simplify attention configuration with a mapping dictionary

The current approach for setting attention implementation uses multiple if-elif branches with repeated code patterns. This could be simplified for better maintainability.

- if self.cfg.flex_attention:
-     self.model_kwargs["attn_implementation"] = "flex_attention"
-     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-         "flex_attention"
-     )
- 
- elif self.cfg.flash_attention:
-     if not self.cfg.sample_packing and self.cfg.s2_attention:
-         pass
-     self.model_kwargs["attn_implementation"] = "flash_attention_2"
-     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-         "flash_attention_2"
-     )
- elif self.cfg.sdp_attention:
-     self.model_kwargs["attn_implementation"] = "sdpa"
-     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-         "sdpa"
-     )
- elif self.cfg.eager_attention:
-     self.model_kwargs["attn_implementation"] = "eager"
-     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-         "eager"
-     )
+ # Map config flags to attention implementations
+ attention_mapping = {
+     "flex_attention": self.cfg.flex_attention,
+     "flash_attention_2": self.cfg.flash_attention and not (self.cfg.sample_packing and not self.cfg.s2_attention),
+     "sdpa": self.cfg.sdp_attention,
+     "eager": self.cfg.eager_attention
+ }
+ 
+ # Find the first matching attention implementation
+ for impl, condition in attention_mapping.items():
+     if condition:
+         self.model_kwargs["attn_implementation"] = impl
+         self.model_config._attn_implementation = impl  # pylint: disable=protected-access
+         break

This approach is more maintainable and makes it easier to add new attention implementations in the future.


361-364: Replace multiple GC calls with a cleaner approach

The current implementation calls gc.collect() three times in a row, which is unnecessary and doesn't follow best practices.

- for _ in range(3):
-     gc.collect()
-     torch.cuda.empty_cache()
+ # Clean up memory once
+ gc.collect()
+ torch.cuda.empty_cache()

Multiple GC calls generally don't provide additional benefit. A single call followed by cache clearing is sufficient.

🛑 Comments failed to post (1)
src/axolotl/loaders/model.py (1)

629-641: 🛠️ Refactor suggestion

Improve the Mamba model loading implementation

The current implementation for loading Mamba models is explicitly marked as "janky" and uses direct dictionary manipulation which could lead to KeyError exceptions.

- # FIXME this is janky at best and hacked together to make it work
- MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
- 
- self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"]
- self.model_kwargs["device"] = torch.cuda.current_device()
- del self.model_kwargs["torch_dtype"]
- del self.model_kwargs["device_map"]
+ # Load Mamba model with proper configuration
+ MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
+ 
+ # Safely transfer and remove kwargs to prevent KeyError
+ dtype = self.model_kwargs.pop("torch_dtype", None)
+ self.model_kwargs["dtype"] = dtype
+ self.model_kwargs["device"] = torch.cuda.current_device()
+ self.model_kwargs.pop("device_map", None)

Consider creating a dedicated method for Mamba model loading in a future PR to better encapsulate this special case.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/axolotl/loaders/model.py around lines 629 to 641, the Mamba model loading
code directly manipulates the model_kwargs dictionary and is marked as "janky,"
risking KeyError exceptions. Refactor this by extracting the Mamba model loading
logic into a dedicated method that safely handles dictionary keys and prepares
the arguments cleanly before calling from_pretrained. This will encapsulate the
special case and improve code clarity and robustness.

@djsaunde djsaunde merged commit b5f1e53 into main May 23, 2025
16 checks passed
@djsaunde djsaunde deleted the model-load-refactor branch May 23, 2025 19:51
This was referenced May 27, 2025
This was referenced Aug 18, 2025
@coderabbitai coderabbitai Bot mentioned this pull request Dec 23, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants