models.py -> loaders/ module refactor #2680
Conversation
Walkthrough
This change refactors the model and tokenizer loading system by moving related logic from
Changes
Sequence Diagram(s)
sequenceDiagram
participant User
participant ModelLoader
participant Tokenizer
participant Processor
participant Adapter
User->>ModelLoader: instantiate(cfg, tokenizer, ...)
User->>ModelLoader: load()
ModelLoader->>Tokenizer: load_tokenizer(cfg)
ModelLoader->>Processor: load_processor(cfg, tokenizer)
ModelLoader->>ModelLoader: _build_model()
ModelLoader->>Adapter: load_adapter(model, cfg, ...)
ModelLoader-->>User: returns (model, peft_config)
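The call order in the diagram can be sketched in plain Python. This is a stand-in, not axolotl's implementation: `SketchModelLoader`, its `_step` helper, and the placeholder return values are hypothetical, and the only details taken from the diagram are the order of the steps and the `(model, peft_config)` return shape.

```python
from dataclasses import dataclass, field


@dataclass
class SketchModelLoader:
    """Hypothetical stand-in that only mirrors the call order in the diagram."""

    cfg: dict
    calls: list = field(default_factory=list)

    def _step(self, name, result):
        # Record the step name so the order can be inspected afterwards.
        self.calls.append(name)
        return result

    def load(self):
        self._step("load_tokenizer", "tokenizer")
        self._step("load_processor", "processor")
        model = self._step("_build_model", "model")
        self._step("load_adapter", None)
        # Assumption: no adapter configured means no PEFT config is returned.
        adapter = self.cfg.get("adapter")
        peft_config = {"adapter": adapter} if adapter else None
        return model, peft_config


loader = SketchModelLoader({"adapter": "lora"})
model, peft_config = loader.load()
```

In the real package the equivalent call, per the review comments, is `ModelLoader(cfg, tokenizer).load()`.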
Actionable comments posted: 8
🧹 Nitpick comments (9)
src/axolotl/loaders/tokenizer.py (1)
168-170: Environment variable toggling happens on every call
`os.environ["TOKENIZERS_PARALLELISM"] = "false"` is executed whenever a `GPTNeoXTokenizerFast` is encountered, even inside worker processes. Move the side-effect into the branch guarded by `is_main_process()` or guard with `os.environ.get` to avoid redundant writes.
src/axolotl/loaders/utils.py (1)
190-208: Prefer `contextlib.suppress` + logging over bare `try/except` + `print`
`ensure_dtype` currently:
- Silences all `AttributeError`s, hampering debuggability.
- Uses `print` rather than the logger.

    -import torch
    +import torch
    +import contextlib
     ...
    -        try:
    -            weight_mismatch = module.weight.dtype != dtype
    -        except AttributeError:
    -            pass
    +        with contextlib.suppress(AttributeError):
    +            weight_mismatch = module.weight.dtype != dtype
     ...
    -        if weight_mismatch:
    -            print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
    +        if weight_mismatch:
    +            LOG.debug(
    +                "Converting %s.weight from %s to %s", name, module.weight.dtype, dtype
    +            )

This keeps the log pipeline consistent and removes repetitive boilerplate.
🧰 Tools
🪛 Ruff (0.11.9)
193-196: Use `contextlib.suppress(AttributeError)` instead of `try`-`except`-`pass`
Replace with `contextlib.suppress(AttributeError)` (SIM105)
197-200: Use `contextlib.suppress(AttributeError)` instead of `try`-`except`-`pass`
Replace with `contextlib.suppress(AttributeError)` (SIM105)
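A minimal sketch of the `contextlib.suppress` pattern that SIM105 recommends, using plain objects instead of `torch` modules. The function name `find_dtype_mismatches` and the `SimpleNamespace` stand-ins are hypothetical and are not part of `ensure_dtype`.

```python
import contextlib
import logging
from types import SimpleNamespace

LOG = logging.getLogger(__name__)


def find_dtype_mismatches(modules, dtype):
    """Return "name.attr" for every weight/bias whose dtype differs from `dtype`.

    `modules` is an iterable of (name, module) pairs; a module may lack
    `weight` or `bias`, or have them set to None.
    """
    mismatched = []
    for name, module in modules:
        for attr in ("weight", "bias"):
            differs = False
            # A missing attribute, or `.dtype` on None, raises AttributeError;
            # suppress it and treat that case as "no mismatch".
            with contextlib.suppress(AttributeError):
                differs = getattr(module, attr).dtype != dtype
            if differs:
                LOG.debug("%s.%s does not match %s", name, attr, dtype)
                mismatched.append(f"{name}.{attr}")
    return mismatched


modules = [
    ("linear", SimpleNamespace(weight=SimpleNamespace(dtype="float32"), bias=None)),
    ("norm", SimpleNamespace()),
]
result = find_dtype_mismatches(modules, "bfloat16")
```

Initialising `differs` before the `with` block matters: when the exception is suppressed, the assignment inside the block never runs.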
src/axolotl/loaders/patch_manager.py (1)
128-136: Minor: collapse nested `if` for readability
Flagged by Ruff's `SIM102`; can be simplified:

    -        if self.cfg.model_config_type == "llama4":
    -            if self.cfg.llama4_linearized_experts:
    +        if (
    +            self.cfg.model_config_type == "llama4"
    +            and self.cfg.llama4_linearized_experts
    +        ):

No functional change, just reduces indentation.
🧰 Tools
🪛 Ruff (0.11.9)
128-129: Use a single `if` statement instead of nested `if` statements (SIM102)
src/axolotl/loaders/adapter.py (2)
55-60: Operator precedence makes intent opaque
`or` precedes the long `and` chain, but reading order is non-obvious. Add parentheses for clarity; this avoids accidental logic regressions if the line is edited later.

    -            isinstance(module, cls)
    -            or "Linear" in module.__class__.__name__
    -            and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
    +            isinstance(module, cls)
    +            or (
    +                "Linear" in module.__class__.__name__
    +                and module.__class__.__name__ != "LlamaLinearScalingRotaryEmbedding"
    +            )
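To illustrate the precedence point, the sketch below enumerates all boolean inputs and compares the unparenthesized expression with the two possible groupings. The function names are made up for the illustration; the precedence rule itself (`and` binds more tightly than `or`) is standard Python.

```python
from itertools import product


def unparenthesized(a, b, c):
    # Python parses this as `a or (b and c)`.
    return a or b and c


def explicit(a, b, c):
    return a or (b and c)


def misread(a, b, c):
    # The grouping a reader might wrongly assume from left-to-right order.
    return (a or b) and c


all_inputs = list(product([False, True], repeat=3))
same_as_explicit = all(unparenthesized(*x) == explicit(*x) for x in all_inputs)
differs_from_misread = [x for x in all_inputs if unparenthesized(*x) != misread(*x)]
```

The suggested parentheses therefore do not change behaviour; they only make the existing grouping visible.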
41-49: Potential memory leak: `_orig_to` never deleted
After restoring the original `.to`, you assign `None` but keep the attribute, preventing GC of the function object. Prefer `delattr(param.quant_state, "_orig_to")`.
src/axolotl/loaders/model.py (2)
206-210: Prefer direct inequality over negated equality
Ruff correctly points out that the expression `not (self.model_config.model_type == "llava")` is harder to read than simply checking inequality:

    -        if self.cfg.mean_resizing_embeddings is not None and not (
    -            self.model_config.model_type == "llava"
    -        ):
    +        if self.cfg.mean_resizing_embeddings is not None and (
    +            self.model_config.model_type != "llava"
    +        ):

Using `!=` conveys intent directly and avoids the double-negative.
🧰 Tools
🪛 Ruff (0.11.9)
206-208: Use `self.model_config.model_type != "llava"` instead of `not self.model_config.model_type == "llava"`
Replace with `!=` operator (SIM201)
339-341: Replace `setattr` with direct attribute assignment
`setattr(self.model, "is_parallelizable", True)` and the next line trigger Ruff B010 because the attribute names are hard-coded. A direct assignment is simpler and safer:

    -            setattr(self.model, "is_parallelizable", True)
    -            setattr(self.model, "model_parallel", True)
    +            self.model.is_parallelizable = True
    +            self.model.model_parallel = True

No dynamic attribute name is used, so there is no advantage in `setattr`.
🧰 Tools
🪛 Ruff (0.11.9)
339-339: Do not call `setattr` with a constant attribute value. It is not any safer than normal property access.
Replace `setattr` with assignment (B010)
340-340: Do not call `setattr` with a constant attribute value. It is not any safer than normal property access.
Replace `setattr` with assignment (B010)
src/axolotl/integrations/base.py (2)
356-362: Return-type mismatch between annotation and docstring
`get_input_args` is annotated to return `list[str]`, yet the docstring says it returns "a list of Pydantic classes". If Pydantic models are intended, update the type hint:

    -    def get_input_args(self) -> list[str]:
    +    from pydantic import BaseModel
    +    def get_input_args(self) -> list[type[BaseModel]]:

Keeping annotations and documentation in sync improves IDE support and static analysis accuracy.
348-355: Swallowing `ImportError` hides critical setup failures
`PluginManager.register()` logs the failure but does not propagate it, so the application may proceed without an expected plugin and fail later in obscure ways.
Consider re-raising the exception (or returning a boolean) so callers can decide whether a missing plugin is fatal:

         except ImportError as exc:
             logging.error(f"Failed to load plugin: {plugin_name}")
    -        pass
    +        raise

Fail-fast behaviour makes configuration errors obvious in CI and production.
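A minimal sketch of the fail-fast variant described above. `register_plugin` and the `registry` dict are hypothetical names used for illustration; they are not the `PluginManager.register()` implementation.

```python
import importlib
import logging

LOG = logging.getLogger(__name__)


def register_plugin(plugin_name, registry):
    """Import `plugin_name` and record it; re-raise if the import fails."""
    try:
        module = importlib.import_module(plugin_name)
    except ImportError:
        LOG.error("Failed to load plugin: %s", plugin_name)
        raise  # let the caller decide whether a missing plugin is fatal
    registry[plugin_name] = module
    return module


registry = {}
register_plugin("json", registry)  # a stdlib module stands in for a real plugin

try:
    register_plugin("hypothetical_missing_plugin_xyz", registry)
    missing_raised = False
except ImportError:
    missing_raised = True
```

A caller that wants the old lenient behaviour can still wrap the call in its own `try`/`except ImportError`, which keeps the decision at the call site.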
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting
📒 Files selected for processing (27)
- src/axolotl/cli/utils.py (2 hunks)
- src/axolotl/common/datasets.py (1 hunks)
- src/axolotl/core/trainer_builder.py (1 hunks)
- src/axolotl/integrations/base.py (14 hunks)
- src/axolotl/loaders/__init__.py (1 hunks)
- src/axolotl/loaders/adapter.py (1 hunks)
- src/axolotl/loaders/constants.py (1 hunks)
- src/axolotl/loaders/model.py (1 hunks)
- src/axolotl/loaders/patch_manager.py (1 hunks)
- src/axolotl/loaders/processor.py (1 hunks)
- src/axolotl/loaders/tokenizer.py (1 hunks)
- src/axolotl/loaders/utils.py (1 hunks)
- src/axolotl/monkeypatch/peft/utils.py (1 hunks)
- src/axolotl/train.py (3 hunks)
- src/axolotl/utils/config/__init__.py (1 hunks)
- src/axolotl/utils/data/rl.py (1 hunks)
- src/axolotl/utils/lora_embeddings.py (0 hunks)
- src/axolotl/utils/models.py (0 hunks)
- src/axolotl/utils/schemas/config.py (1 hunks)
- tests/core/test_trainer_builder.py (2 hunks)
- tests/e2e/patched/test_model_patches.py (3 hunks)
- tests/e2e/test_load_model.py (3 hunks)
- tests/patched/test_validation.py (2 hunks)
- tests/test_exact_deduplication.py (6 hunks)
- tests/test_loaders.py (3 hunks)
- tests/test_lora.py (3 hunks)
- tests/test_tokenizers.py (1 hunks)
💤 Files with no reviewable changes (2)
- src/axolotl/utils/lora_embeddings.py
- src/axolotl/utils/models.py
🧰 Additional context used
🧬 Code Graph Analysis (9)
tests/test_tokenizers.py (1)
src/axolotl/loaders/tokenizer.py (1)
load_tokenizer (120-281)
src/axolotl/utils/data/rl.py (1)
src/axolotl/loaders/tokenizer.py (1)
load_tokenizer (120-281)
src/axolotl/utils/config/__init__.py (1)
src/axolotl/loaders/utils.py (1)
load_model_config (135-183)
tests/test_loaders.py (2)
src/axolotl/loaders/model.py (2)
_set_device_map_config (370-426)
_set_quantization_config (428-503)
src/axolotl/utils/dict.py (1)
DictDefault (6-38)
src/axolotl/core/trainer_builder.py (1)
src/axolotl/loaders/utils.py (1)
ensure_dtype (186-207)
tests/core/test_trainer_builder.py (1)
src/axolotl/loaders/model.py (2)
ModelLoader (57-781)
load (138-166)
tests/e2e/patched/test_model_patches.py (2)
src/axolotl/loaders/model.py (2)
ModelLoader (57-781)
load (138-166)
src/axolotl/loaders/tokenizer.py (1)
load_tokenizer (120-281)
src/axolotl/loaders/utils.py (4)
src/axolotl/utils/dict.py (1)
DictDefault (6-38)
tests/test_exact_deduplication.py (1)
cfg (221-235)
src/axolotl/models/mamba/modeling_mamba.py (1)
from_pretrained (122-128)
tests/test_perplexity.py (1)
model (20-21)
src/axolotl/loaders/adapter.py (3)
src/axolotl/utils/schemas/peft.py (3)
LoftQConfig (8-14)
LoraConfig (23-122)
PeftConfig (17-20)
src/axolotl/loaders/utils.py (1)
get_linear_embedding_layers (210-216)
src/axolotl/utils/dict.py (1)
DictDefault (6-38)
🪛 Ruff (0.11.9)
src/axolotl/loaders/utils.py
193-196: Use contextlib.suppress(AttributeError) instead of try-except-pass
Replace with contextlib.suppress(AttributeError)
(SIM105)
197-200: Use contextlib.suppress(AttributeError) instead of try-except-pass
Replace with contextlib.suppress(AttributeError)
(SIM105)
src/axolotl/loaders/patch_manager.py
128-129: Use a single if statement instead of nested if statements
(SIM102)
src/axolotl/loaders/model.py
206-208: Use self.model_config.model_type != "llava" instead of not self.model_config.model_type == "llava"
Replace with != operator
(SIM201)
339-339: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
340-340: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
570-572: Use False instead of ... and False
Replace with False
(SIM223)
593-593: Do not call getattr with a constant attribute value. It is not any safer than normal property access.
Replace getattr with attribute access
(B009)
779-780: Use a single if statement instead of nested if statements
(SIM102)
⏰ Context from checks skipped due to timeout of 90000ms (6)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: PyTest (3.11, 2.5.1)
🔇 Additional comments (44)
src/axolotl/core/trainer_builder.py (1)
62-62: Correct import of `ensure_dtype` from refactored `axolotl.loaders.utils`
Updating the import location for `ensure_dtype` aligns with the new modular loader utilities and preserves existing functionality.
src/axolotl/utils/data/rl.py (1)
13-13: Update import for `load_tokenizer` to new loaders package
This assumes that `load_tokenizer` is re-exported at the root of `axolotl.loaders`. Please verify that `axolotl/loaders/__init__.py` includes `load_tokenizer`; otherwise, adjust the import to `from axolotl.loaders.tokenizer import load_tokenizer`.
src/axolotl/common/datasets.py (1)
13-13: Update import for `load_processor` and `load_tokenizer` to new loaders package
Ensure that both `load_processor` and `load_tokenizer` are exposed in `axolotl.loaders/__init__.py`; if not, import directly from their modules (e.g., `axolotl.loaders.processor`, `axolotl.loaders.tokenizer`).
src/axolotl/utils/config/__init__.py (2)
14-14: Import of `MULTIMODAL_AUTO_MODEL_MAPPING`: verify correct export location
The constant was relocated to `axolotl.loaders.constants`. Confirm it's re-exported in `axolotl.loaders/__init__.py`; otherwise, import it explicitly from `axolotl.loaders.constants`.
15-15: Correct import of `load_model_config` from `axolotl.loaders.utils`
This aligns with the refactoring of model configuration utilities into the new loaders package.
tests/test_tokenizers.py (1)
9-9: Update test import for `load_tokenizer` to new loaders package
The test now imports `load_tokenizer` from `axolotl.loaders`. Verify that this function is available at the package root, or update the import to `from axolotl.loaders.tokenizer import load_tokenizer` if needed.
tests/test_lora.py (3)
5-5: Import updated to reflect the new module structure.
The import statement has been updated to use the new `axolotl.loaders` package instead of `axolotl.utils.models`, aligning with the refactoring goal of breaking down the large `models.py` file into smaller submodules.
49-49: Model loading approach updated to use the new class-based interface.
The model loading has been updated from directly calling `load_model(cfg, tokenizer)` to instantiating the `ModelLoader` class and calling its `load()` method. This change is consistent with the refactoring goals and maintains the same functionality.
70-70: Model loading approach consistently updated in both test methods.
The second test method also follows the same pattern of using `ModelLoader(cfg, tokenizer).load()`, ensuring consistency across the test file.
src/axolotl/monkeypatch/peft/utils.py (1)
78-78: Monkeypatch target path updated to reflect the restructured modules.
The path for the patched function has been updated from `axolotl.utils.models.prepare_model_for_kbit_training` to `axolotl.loaders.model.prepare_model_for_kbit_training`, correctly aligning with the new module structure where model loading functionality has been moved to the `axolotl.loaders` package.
tests/core/test_trainer_builder.py (3)
1-1: Module docstring reformatted to a single line.
The docstring has been simplified to a single-line format, maintaining the same information while being more concise.
6-6: Import updated to reflect the new module structure.
The import statement has been updated to import `ModelLoader` and `load_tokenizer` from the new `axolotl.loaders` package instead of `axolotl.utils.models`, which aligns with the refactoring goal.
50-50: Model loading approach updated to use the new class-based interface.
The `fixture_model` function has been updated to use the new `ModelLoader` class's `load()` method instead of the previous direct call to `load_model`. This is consistent with the changes in other test files and maintains the same functionality.
tests/patched/test_validation.py (2)
12-12: Import path updated to reflect the restructured modules.
The import of `check_model_config` has been updated to use the new `axolotl.loaders.utils` module instead of the previous `axolotl.utils.models`, correctly aligning with the refactoring.
1218-1231: New test case added for validating incompatible configuration options.
This new test ensures that the system correctly raises a validation error when both `s2_attention` and `sample_packing` are set to `True`, which represents an incompatible configuration. The test is well-structured and follows the same pattern as other validation tests in this file.
src/axolotl/train.py (3)
31-35: LGTM - Updated imports align with the module refactoring.
The imports have been updated to use the new `axolotl.loaders` package instead of `axolotl.utils.models`, which reflects the code organization refactoring described in the PR objectives.
83-84: Good refactoring to use the new ModelLoader class.
The code now uses the `ModelLoader` class to encapsulate model loading logic instead of directly calling a function. This is a cleaner approach that better follows object-oriented principles.
121-122: LGTM - Consistent usage of ModelLoader for reference model.
The refactoring of reference model loading follows the same pattern as the main model loading, maintaining consistency across the codebase.
src/axolotl/cli/utils.py (2)
23-24: LGTM - Updated imports align with module refactoring.
The imports have been updated to use the new modular structure, with `load_processor` and `load_tokenizer` from `axolotl.loaders` and `ModelLoader` from `axolotl.loaders.model`.
322-323: Refactoring simplified the loading pattern.
As noted in a previous review comment, this change removes the redundant `load_model` function which was essentially a pass-through to `ModelLoader(...).load()`. The direct instantiation of `ModelLoader` is cleaner.
src/axolotl/utils/schemas/config.py (1)
473-482: Good addition of validation for incompatible features.
This validator prevents the simultaneous use of sample packing and shifted-sparse attention, which are incompatible features. Early validation of configuration errors provides better user experience by failing fast with a clear error message.
tests/e2e/patched/test_model_patches.py (3)
9-9: LGTM - Updated imports to use the new module structure.
The import statement has been updated to use the refactored module structure, importing `ModelLoader` and `load_tokenizer` from `axolotl.loaders`.
53-53: LGTM - Test updated to use ModelLoader pattern.
The test has been updated to use the new `ModelLoader` class directly instead of the removed `load_model` function, consistent with the refactoring throughout the codebase.
86-86: LGTM - Consistent refactoring in tests.
The second test method has also been updated to use the `ModelLoader` class, maintaining consistency throughout the test suite.
tests/test_exact_deduplication.py (3)
1-5: Docstring formatting has been improved.
The module docstring now uses backticks to properly format module and function names, improving readability and consistency with Python documentation standards.
15-15: Import path updated to reflect the refactored module structure.
The import path for `load_processor` and `load_tokenizer` has been correctly updated to use the new `axolotl.loaders` package instead of the old `axolotl.utils.models` module, aligning with the refactoring goal.
249-249: Patch decorators updated to use new module path.
The patch paths have been correctly updated to reference `axolotl.loaders.load_tokenizer` instead of the old path, ensuring that mocking works correctly after the refactoring.
Also applies to: 273-273
src/axolotl/loaders/__init__.py (1)
1-11: Well-structured package initialization.
The `__init__.py` file is well organized, exposing key components from the submodules at the package level. The use of linter disabling directives is appropriate here to allow for unused imports, which is a common practice for package initializers.
This structure aligns perfectly with the refactoring goal of breaking down the large `axolotl.utils.models` file into multiple smaller, more manageable submodules.
tests/e2e/test_load_model.py (3)
9-9: Import path updated to use the new module structure.
The import has been correctly updated to use `ModelLoader` and `load_tokenizer` from the new `axolotl.loaders` package.
61-63: Added explicit parameters for `ModelLoader` instantiation.
The instantiation of `ModelLoader` now includes explicit parameters `inference=False` and `reference_model=True`, which makes the intent clearer and aligns with the new class structure.
76-78: Updated to use the `ModelLoader` methods.
The test now correctly uses the `ModelLoader.load()` method and the internal `_convert_embedding_modules_dtype` method according to the new architecture, replacing the previous standalone function approach.
tests/test_loaders.py (3)
1-1: Updated docstrings to reflect the new module structure.
Both the module and class docstrings have been correctly updated to reference `axolotl.loaders` as the module under test.
Also applies to: 15-15
10-10: Import path updated to use the new module structure.
The import has been correctly updated to use `ModelLoader` directly from `axolotl.loaders`.
53-54: Updated method calls to respect protected method visibility.
The tests now call the protected methods `_set_device_map_config()` and `_set_quantization_config()` with appropriate pylint disabling comments, indicating that these methods are now intended for internal use within the `ModelLoader` class.
This change reflects a more principled approach to method visibility in the refactored code.
Also applies to: 80-81
src/axolotl/loaders/constants.py (1)
13-21: Use `typing.Final` and lowercase keys for consistency
- Annotate the mapping with `Final` to communicate immutability.
- Convert keys to lowercase once (either here or on lookup) so that downstream look-ups do not need to remember the exact case.

    -from typing import Final
    -
    -MULTIMODAL_AUTO_MODEL_MAPPING: Final[dict[str, type]] = {
    -    "mllama": MllamaForConditionalGeneration,
    -    ...
    -}

[ suggest_nitpick ]
src/axolotl/loaders/processor.py (1)
32-43: Image-size extraction can silently mis-set `cfg.image_size`
When only one dimension is present the function stores an `int`; when two are present it stores a `tuple[int, int]`. Consumers now need to handle a union. Consider always normalising to a tuple for consistency:

    -    if im_width is not None and im_height is not None:
    -        cfg.image_size = (im_width, im_height)
    -    elif im_width is not None:
    -        cfg.image_size = im_width
    -    elif im_height is not None:
    -        cfg.image_size = im_height
    +    if im_width is not None or im_height is not None:
    +        cfg.image_size = (im_width or im_height, im_height or im_width)

[ suggest_optional_refactor ]
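The normalisation can be sketched as a small standalone helper. `normalize_image_size` is a hypothetical name, and unlike the `or`-based suggestion above it compares against `None` explicitly, so a literal `0` would not be treated as missing; that difference is a deliberate assumption of this sketch.

```python
def normalize_image_size(width=None, height=None):
    """Return a (width, height) tuple, or None when neither dimension is known.

    A missing dimension falls back to the one that is present,
    which yields a square size.
    """
    if width is None and height is None:
        return None
    w = width if width is not None else height
    h = height if height is not None else width
    return (w, h)


square = normalize_image_size(width=224)
rect = normalize_image_size(width=640, height=480)
```

Consumers can then rely on a single `tuple[int, int] | None` shape instead of handling an `int`-or-tuple union.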
src/axolotl/loaders/tokenizer.py (3)
87-92: Variable shadowing hides the original `tokenizer_path` argument
`tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")` overwrites the function parameter of the same name, making subsequent debugging harder. Rename the local variable to avoid confusion:

    -    tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
    +    tokenizer_json_path = os.path.join(tokenizer_dir, "tokenizer.json")
     ...
    -    with open(tokenizer_path, "r", encoding="utf-8") as f:
    +    with open(tokenizer_json_path, "r", encoding="utf-8") as f:

[ suggest_nitpick ]
155-165: Prefer `isinstance` over string comparisons for tokenizer class checks
Comparing `tokenizer.__class__.__name__` is brittle and fails with subclassing. The `transformers` library exports the actual classes; use them directly:

    -from ... in ["LlamaTokenizer", "LlamaTokenizerFast", ...]
    +from transformers import LlamaTokenizer, LlamaTokenizerFast, CodeLlamaTokenizer, CodeLlamaTokenizerFast
    +
    +if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast, CodeLlamaTokenizer, CodeLlamaTokenizerFast)):
     ...

[ suggest_optional_refactor ]
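The subclassing problem can be shown without `transformers` at all. In this sketch `BaseTokenizer` and `CustomTokenizer` are hypothetical stand-ins for a library tokenizer and a user subclass.

```python
class BaseTokenizer:
    """Hypothetical stand-in for a library tokenizer class."""


class CustomTokenizer(BaseTokenizer):
    """A user-defined subclass with a different class name."""


def matches_by_name(tokenizer):
    # Brittle: only an exact class-name match passes.
    return tokenizer.__class__.__name__ in ["BaseTokenizer"]


def matches_by_type(tokenizer):
    # Robust: instances of subclasses pass as well.
    return isinstance(tokenizer, (BaseTokenizer,))


custom = CustomTokenizer()
name_check = matches_by_name(custom)
type_check = matches_by_type(custom)
```

The name-based check rejects the subclass even though it behaves like its parent, which is the failure mode the comment describes.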
188-216: ⚠️ Potential issue
LoRA/special-token validation raises on perfectly valid scenarios
The current guard enforces that all linear embedding layers are listed in `lora_modules_to_save` just because special tokens change, even when the user is not doing adapter training (`cfg.adapter == False`). Consider restructuring the predicate:

    -    if (
    -        (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
    -        ...
    -        and cfg.adapter
    +    if (
    +        cfg.adapter  # only enforce for adapter training
    +        and (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
             ...
         ):

This prevents false-positive `ValueError`s when training full-fine-tune models.
Likely an incorrect or invalid review comment.
src/axolotl/loaders/utils.py (1)
62-75: Early exit hides subsequent validation when using compressed-tensors
Returning immediately for `compressed-tensors` skips the LoRA/token-addition guard (lines 117-132). Users adding new tokens to a compressed-tensors model will not receive the helpful validation error. Consider moving the LoRA check above the return or explicitly re-running it.
src/axolotl/loaders/patch_manager.py (1)
92-99: Side-effect: silently flips `flash_attention` flag
`_apply_flash_attention_patches` sets `self.cfg.flash_attention = True` when `xformers_attention & sample_packing` are enabled. This mutates user-supplied config, potentially leading to surprising behaviour later in the run (e.g., different kernels selected).
Prefer deriving a local variable or logging a warning so the user knows their config was overridden.
src/axolotl/loaders/adapter.py (2)
89-103: `use_dora`/`use_rslora` keys may not exist in upstream PEFT
`LoraConfig` from PEFT < 0.9 doesn't accept `use_dora` or `use_rslora`. Passing them via `**lora_config_kwargs` raises `TypeError`. Guard by checking `inspect.signature(LoraConfig).parameters` before injecting.
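A minimal sketch of the `inspect.signature` guard, assuming a config class with explicit keyword parameters. `OldLoraConfig` and `supported_kwargs` are hypothetical; the sketch does not import PEFT and makes no claim about which PEFT versions accept which arguments.

```python
import inspect
from dataclasses import dataclass


@dataclass
class OldLoraConfig:
    """Hypothetical config class that predates the newer keyword arguments."""

    r: int = 8
    lora_alpha: int = 16


def supported_kwargs(config_cls, candidate_kwargs):
    """Keep only the kwargs that `config_cls` accepts by name."""
    params = inspect.signature(config_cls).parameters
    return {k: v for k, v in candidate_kwargs.items() if k in params}


wanted = {"r": 4, "use_dora": True, "use_rslora": False}
filtered = supported_kwargs(OldLoraConfig, wanted)
config = OldLoraConfig(**filtered)  # no TypeError: unknown keys were dropped
```

If the target class accepts `**kwargs`, a name-based filter like this would drop keys it could have handled, so the check would need adjusting in that case.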
124-131: ⚠️ Potential issue
`setup_quantized_meta_for_peft` relies on `LOCAL_RANK` env
If `LOCAL_RANK` is unset (common in single-GPU runs), `int(None)` raises `TypeError`. Use `os.getenv(..., "0")` default.

    -    rank = int(os.environ.get("LOCAL_RANK", 0))
    +    rank = int(os.environ.get("LOCAL_RANK", "0"))

Likely an incorrect or invalid review comment.
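One way to see why this comment was flagged as likely invalid: `dict.get` (and `os.environ.get`) returns the supplied default when the key is missing, so `int` never receives `None`, and an integer default works as well as a string one. The helper name `local_rank` and its `env` parameter are hypothetical, added so the behaviour can be checked without touching the real environment.

```python
import os


def local_rank(env=None):
    """Read LOCAL_RANK from a mapping, defaulting to 0 when it is unset."""
    env = os.environ if env is None else env
    # `.get` returns the default for a missing key, so `int` never sees None.
    return int(env.get("LOCAL_RANK", 0))


unset_rank = local_rank({})
set_rank = local_rank({"LOCAL_RANK": "3"})
```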
src/axolotl/loaders/model.py (1)
656-660: Hard-coding CUDA device breaks non-CUDA environments

    self.model_kwargs["device"] = torch.cuda.current_device()

If the process runs on CPU, MPS or NPU, `torch.cuda` may be unavailable or `cuda.is_available()` may be `False`, leading to an exception. Consider deriving the device from `get_device_type()` or guarding the assignment:

    device_type = get_device_type()
    if device_type == "cuda":
        self.model_kwargs["device"] = torch.cuda.current_device()

This keeps the path functional on all accelerators.
Actionable comments posted: 0
♻️ Duplicate comments (1)
src/axolotl/loaders/model.py (1)
151-157: Consider deprecating `post_model_build` in favor of `pre_lora_load`.
As noted in previous comments, using two separate hooks (`post_model_build` and `pre_lora_load`) that execute at nearly the same point in the lifecycle creates potential confusion.

    -        PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
    -
    -        # Post-build model configuration
    -        self._apply_post_model_load_setup()
    -
    -        # Load adapters (LoRA, etc.)
    -        PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model)
    +        # Post-build model configuration before loading adapters
    +        self._apply_post_model_load_setup()
    +
    +        # Load adapters (LoRA, etc.)
    +        PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model)
    +        # For backward compatibility
    +        PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
🧹 Nitpick comments (8)
src/axolotl/loaders/model.py (8)
206-209: Simplify boolean expression.
Use `!=` instead of `not ... ==` for better readability.

    -        if self.cfg.mean_resizing_embeddings is not None and not (
    -            self.model_config.model_type == "llava"
    -        ):
    +        if self.cfg.mean_resizing_embeddings is not None and self.model_config.model_type != "llava":

🧰 Tools
🪛 Ruff (0.11.9)
206-208: Use `self.model_config.model_type != "llava"` instead of `not self.model_config.model_type == "llava"`
Replace with `!=` operator (SIM201)
335-336: Remove or clarify TODO comment.
The TODO comment doesn't provide sufficient context about what needs validation.

    -            # TODO: validate this conditional
                 self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")

Consider either removing the comment if the conditional is already validated, or add more specific details about what aspect needs validation.
339-340: Use direct attribute assignment instead of `setattr`.
`setattr` with constant attribute names doesn't provide any advantages over direct attribute assignment and makes the code less readable.

    -            setattr(self.model, "is_parallelizable", True)
    -            setattr(self.model, "model_parallel", True)
    +            self.model.is_parallelizable = True
    +            self.model.model_parallel = True

🧰 Tools
🪛 Ruff (0.11.9)
339-339: Do not call `setattr` with a constant attribute value. It is not any safer than normal property access.
Replace `setattr` with assignment (B010)
340-340: Do not call `setattr` with a constant attribute value. It is not any safer than normal property access.
Replace `setattr` with assignment (B010)
419-427: Dead code: consider removing commented-out code block.
This commented-out code block related to reference model GPU assignment serves no functional purpose and could confuse maintainers. Either remove it entirely or add a clear comment explaining why it's preserved.
566-703: Complex model building method could benefit from further modularization.
The `_build_model` method is quite lengthy (137 lines) with multiple conditional branches for different model types and loading strategies. Consider splitting it into smaller, more focused methods to improve maintainability.
For example:
- `_build_sharded_quantized_model`
- `_build_llama_model`
- `_build_mamba_model`
- `_build_generic_model`
🧰 Tools
🪛 Ruff (0.11.9)
580-580: Do not call `getattr` with a constant attribute value. It is not any safer than normal property access.
Replace `getattr` with attribute access (B009)
580-583: Use direct attribute access instead of `getattr`.
`getattr` with constant attribute names doesn't provide any advantages over direct attribute access and makes the code less readable.

    -        quantization_config = hasattr(
    -            self.model_config, "quantization_config"
    -        ) and getattr(self.model_config, "quantization_config")
    +        quantization_config = hasattr(
    +            self.model_config, "quantization_config"
    +        ) and self.model_config.quantization_config

🧰 Tools
🪛 Ruff (0.11.9)
580-580: Do not call `getattr` with a constant attribute value. It is not any safer than normal property access.
Replace `getattr` with attribute access (B009)
643-643: Improve variable naming to avoid pylint suppression.
The variable `MambaLMHeadModel` uses camel case, which violates Python naming conventions and requires a pylint suppression. Consider using snake_case instead:

    -            MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
    +            mamba_lm_head_model = fix_mamba_attn_for_loss()

And update the usage below:

    -            self.model = MambaLMHeadModel.from_pretrained(
    +            self.model = mamba_lm_head_model.from_pretrained(

766-769: Simplify nested if statements.
The nested if statements can be combined into a single condition for better readability.

    -            if any(m in name for m in embedding_modules):
    -                if hasattr(module, "weight"):
    -                    module.to(dist_dtype)
    +            if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
    +                module.to(dist_dtype)

🧰 Tools
🪛 Ruff (0.11.9)
766-767: Use a single `if` statement instead of nested `if` statements (SIM102)
📜 Review details
📒 Files selected for processing (1)
src/axolotl/loaders/model.py(1 hunks)
🧰 Additional context used
🪛 Ruff (0.11.9)
src/axolotl/loaders/model.py
206-208: Use self.model_config.model_type != "llava" instead of not self.model_config.model_type == "llava"
Replace with != operator
(SIM201)
339-339: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
340-340: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
580-580: Do not call getattr with a constant attribute value. It is not any safer than normal property access.
Replace getattr with attribute access
(B009)
766-767: Use a single if statement instead of nested if statements
(SIM102)
⏰ Context from checks skipped due to timeout of 90000ms (7)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: pre-commit
🔇 Additional comments (1)
src/axolotl/loaders/model.py (1)
1-83: Well-organized model loader module with comprehensive documentation.
The `ModelLoader` class has excellent docstrings that clearly explain its purpose, lifecycle, and capabilities. This is especially important for a complex component that handles multiple model types, quantization schemes, and adapter integrations.
0b8440d to
fe504b7
Compare
Actionable comments posted: 5
♻️ Duplicate comments (2)
src/axolotl/loaders/utils.py (1)
90-97: ⚠️ Potential issue
`quantization_config` may not be a dict: AttributeError risk remains
`model_config.quantization_config.get("quant_method")` assumes the field is a mapping. HF frequently stores this as a `PretrainedConfig`, dataclass, or custom object, and several users ran into `AttributeError: 'PretrainedConfig' object has no attribute 'get'` in earlier commits (see previous review). The same runtime failure is still possible here.

    -    is_compressed_tensors_config = (
    -        quant_config_exists
    -        and model_config.quantization_config.get("quant_method") == "compressed-tensors"
    -    )
    +    qc = model_config.quantization_config
    +    if quant_config_exists and not isinstance(qc, dict):
    +        qc = qc.to_dict() if hasattr(qc, "to_dict") else vars(qc)
    +
    +    is_compressed_tensors_config = (
    +        quant_config_exists and qc.get("quant_method") == "compressed-tensors"
    +    )

Same normalisation should be applied a few lines below for the GPTQ check.
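The normalisation in the diff above can be factored into one helper so that the compressed-tensors and GPTQ checks share it. This is a sketch under assumptions: `quant_config_as_dict` and `ObjectStyleConfig` are hypothetical names, and the object-style class only imitates a config that exposes `to_dict`.

```python
def quant_config_as_dict(quantization_config):
    """Return a plain dict view of a quantization config, or {} when it is None."""
    if quantization_config is None:
        return {}
    if isinstance(quantization_config, dict):
        return quantization_config
    if hasattr(quantization_config, "to_dict"):
        return quantization_config.to_dict()
    # Fall back to instance attributes for simple objects.
    return vars(quantization_config)


class ObjectStyleConfig:
    """Hypothetical object-style config exposing `to_dict`."""

    def __init__(self, quant_method):
        self.quant_method = quant_method

    def to_dict(self):
        return {"quant_method": self.quant_method}


from_dict = quant_config_as_dict({"quant_method": "gptq"}).get("quant_method")
from_obj = quant_config_as_dict(ObjectStyleConfig("compressed-tensors")).get("quant_method")
from_none = quant_config_as_dict(None).get("quant_method")
```

Returning an empty dict for `None` lets callers use `.get("quant_method")` unconditionally instead of pairing every lookup with an existence check.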
src/axolotl/loaders/model.py (1)
151-157: Consider deprecating post_model_build in favor of pre_lora_load.
As noted in a previous comment, it might be worth deprecating post_model_build and using only pre_lora_load, but this would need to be evaluated against potential impacts on third-party plugins.
🧹 Nitpick comments (6)
src/axolotl/loaders/utils.py (1)
193-200: Use contextlib.suppress for cleaner, faster attribute checks
The twin try/except AttributeError: pass blocks slightly inflate byte-code and hide other potential errors. Ruff already flagged this (SIM105). A concise alternative:
-import torch
+import torch, contextlib
 ...
-        try:
-            weight_mismatch = module.weight.dtype != dtype
-        except AttributeError:
-            pass
-        try:
-            bias_mismatch = module.bias.dtype != dtype
-        except AttributeError:
-            pass
+        with contextlib.suppress(AttributeError):
+            weight_mismatch = module.weight.dtype != dtype
+        with contextlib.suppress(AttributeError):
+            bias_mismatch = module.bias.dtype != dtype
Behaviour is identical, code is clearer.
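The suppress pattern only matches the try/except version if the flag has a default before the block runs. Here is a small self-contained sketch of that detail, using plain stand-in classes instead of torch modules (WithWeight, NoWeight, and weight_mismatch are hypothetical names, and dtypes are plain strings for illustration):

```python
import contextlib


class WithWeight:
    """Stand-in for a module that has a weight with a dtype."""

    class _Param:
        dtype = "float32"

    weight = _Param()


class NoWeight:
    """Stand-in for a module without a weight attribute."""


def weight_mismatch(module, dtype) -> bool:
    # The default must be assigned first: if the attribute lookup fails,
    # the assignment inside the suppressed block never happens.
    mismatch = False
    with contextlib.suppress(AttributeError):
        mismatch = module.weight.dtype != dtype
    return mismatch
```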
🧰 Tools
🪛 Ruff (0.11.9)
193-196: Use contextlib.suppress(AttributeError) instead of try-except-pass
Replace with contextlib.suppress(AttributeError)
(SIM105)
197-200: Use contextlib.suppress(AttributeError) instead of try-except-pass
Replace with contextlib.suppress(AttributeError)
(SIM105)
src/axolotl/loaders/patch_manager.py (1)
128-136: Nested if can be flattened for clarity (SIM102)
Combining the llama4 model-type branch makes intent clearer and removes one indentation level:
- if self.cfg.model_config_type == "llama4":
-     if self.cfg.llama4_linearized_experts:
+ if (
+     self.cfg.model_config_type == "llama4"
+     and self.cfg.llama4_linearized_experts
+ ):
Pure readability; behaviour unchanged.
🧰 Tools
🪛 Ruff (0.11.9)
128-129: Use a single if statement instead of nested if statements
(SIM102)
src/axolotl/loaders/model.py (4)
214-241: Consider refactoring repeated attribute checks.
The _adjust_model_config method has multiple similar conditional blocks checking for attributes. This could be refactored into a helper function to reduce repetition.
-        if (
-            hasattr(self.model, "config")
-            and hasattr(self.model.config, "max_position_embeddings")
-            and self.model.config.max_position_embeddings
-            and self.cfg.sequence_len > self.model.config.max_position_embeddings
-        ):
-            LOG.warning(
-                f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}"
-            )
-            self.model.config.max_position_embeddings = self.cfg.sequence_len
-
-        if (
-            hasattr(self.model, "config")
-            and hasattr(self.model.config, "bos_token_id")
-            and self.model.config.bos_token_id
-            and self.model.config.bos_token_id != self.tokenizer.bos_token_id
-        ):
-            self.model.config.bos_token_id = self.tokenizer.bos_token_id
-
-        if (
-            hasattr(self.model, "config")
-            and hasattr(self.model.config, "eos_token_id")
-            and self.model.config.eos_token_id
-            and self.model.config.eos_token_id != self.tokenizer.eos_token_id
-        ):
-            self.model.config.eos_token_id = self.tokenizer.eos_token_id
+        def update_config_if_exists(attr_name, new_value, condition=None, log_msg=None):
+            if (
+                hasattr(self.model, "config")
+                and hasattr(self.model.config, attr_name)
+                and getattr(self.model.config, attr_name)
+                and (condition is None or condition(getattr(self.model.config, attr_name)))
+            ):
+                if log_msg:
+                    LOG.warning(log_msg)
+                setattr(self.model.config, attr_name, new_value)
+
+        update_config_if_exists(
+            "max_position_embeddings",
+            self.cfg.sequence_len,
+            lambda x: self.cfg.sequence_len > x,
+            f"increasing model.config.max_position_embeddings from {self.model.config.max_position_embeddings} to {self.cfg.sequence_len}"
+        )
+
+        update_config_if_exists(
+            "bos_token_id",
+            self.tokenizer.bos_token_id,
+            lambda x: x != self.tokenizer.bos_token_id
+        )
+
+        update_config_if_exists(
+            "eos_token_id",
+            self.tokenizer.eos_token_id,
+            lambda x: x != self.tokenizer.eos_token_id
+        )
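One caveat with the helper as proposed: the f-string log message is built at the call site, so self.model.config.max_position_embeddings is read before any hasattr guard runs. A variant that takes the message as a callable avoids that. This is a standalone sketch under that assumption; the signature and parameter names (should_update, log_msg) are hypothetical, and SimpleNamespace merely stands in for a model and its config.

```python
import logging
from types import SimpleNamespace
from typing import Any, Callable, Optional

LOG = logging.getLogger(__name__)


def update_config_if_exists(
    model: Any,
    attr_name: str,
    new_value: Any,
    should_update: Callable[[Any], bool],
    log_msg: Optional[Callable[[Any], str]] = None,
) -> bool:
    """Set model.config.<attr_name> when it exists, is truthy, and should_update agrees.

    Returns True if the attribute was changed.
    """
    config = getattr(model, "config", None)
    current = getattr(config, attr_name, None)
    if not current or not should_update(current):
        return False
    if log_msg is not None:
        # The message is only formatted once we know the attribute exists.
        LOG.warning(log_msg(current))
    setattr(config, attr_name, new_value)
    return True


# Usage with stand-in objects (not real model classes).
model = SimpleNamespace(config=SimpleNamespace(max_position_embeddings=2048, bos_token_id=1))
changed = update_config_if_exists(
    model,
    "max_position_embeddings",
    4096,
    should_update=lambda cur: 4096 > cur,
    log_msg=lambda cur: f"increasing max_position_embeddings from {cur} to 4096",
)
```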
419-427: Consider implementing or removing the commented code.
There's commented-out code about placing reference models on separate GPUs. Either implement this feature (after careful testing) or remove the comments if they're no longer relevant.
505-506: Clarify or remove the comment about sample packing.
The comment "Sample packing uses custom FA2 patch" seems out of place or incomplete, as it doesn't directly relate to what the method does. Consider clarifying this comment or removing it.
566-703: Consider refactoring the complex model building method.
The _build_model method is quite complex with multiple nested conditionals for different model types and configurations. Consider breaking it down into smaller, more focused methods for each model type or configuration pattern.
For example:
def _build_model(self) -> bool:
    """Load model, with load strategy depending on config."""
    skip_move_to_device = False

    # Dispatch to specialized builders based on model type/config
    if self._should_use_sharded_qlora_loading():
        skip_move_to_device = self._build_sharded_qlora_model()
    elif self._is_llama_model_without_special_requirements():
        skip_move_to_device = self._build_llama_model()
    elif self._is_mamba_model():
        skip_move_to_device = self._build_mamba_model()
    elif self._is_custom_model_type():
        skip_move_to_device = self._build_custom_model_type()
    else:
        skip_move_to_device = self._build_standard_model()

    if is_deepspeed_zero3_enabled():
        skip_move_to_device = True

    return skip_move_to_device
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge Base: Disabled due to data retention organization setting
📒 Files selected for processing (27)
src/axolotl/cli/utils.py(2 hunks)src/axolotl/common/datasets.py(1 hunks)src/axolotl/core/trainer_builder.py(1 hunks)src/axolotl/integrations/base.py(14 hunks)src/axolotl/loaders/__init__.py(1 hunks)src/axolotl/loaders/adapter.py(1 hunks)src/axolotl/loaders/constants.py(1 hunks)src/axolotl/loaders/model.py(1 hunks)src/axolotl/loaders/patch_manager.py(1 hunks)src/axolotl/loaders/processor.py(1 hunks)src/axolotl/loaders/tokenizer.py(1 hunks)src/axolotl/loaders/utils.py(1 hunks)src/axolotl/monkeypatch/peft/utils.py(1 hunks)src/axolotl/train.py(3 hunks)src/axolotl/utils/config/__init__.py(1 hunks)src/axolotl/utils/data/rl.py(1 hunks)src/axolotl/utils/lora_embeddings.py(0 hunks)src/axolotl/utils/models.py(0 hunks)src/axolotl/utils/schemas/config.py(1 hunks)tests/core/test_trainer_builder.py(2 hunks)tests/e2e/patched/test_model_patches.py(3 hunks)tests/e2e/test_load_model.py(3 hunks)tests/patched/test_validation.py(2 hunks)tests/test_exact_deduplication.py(6 hunks)tests/test_loaders.py(3 hunks)tests/test_lora.py(3 hunks)tests/test_tokenizers.py(1 hunks)
💤 Files with no reviewable changes (2)
- src/axolotl/utils/lora_embeddings.py
- src/axolotl/utils/models.py
✅ Files skipped from review due to trivial changes (2)
- src/axolotl/utils/config/__init__.py
- src/axolotl/common/datasets.py
🚧 Files skipped from review as they are similar to previous changes (19)
- src/axolotl/utils/data/rl.py
- src/axolotl/core/trainer_builder.py
- tests/core/test_trainer_builder.py
- tests/test_lora.py
- tests/test_tokenizers.py
- src/axolotl/loaders/constants.py
- tests/e2e/patched/test_model_patches.py
- src/axolotl/cli/utils.py
- tests/test_loaders.py
- tests/test_exact_deduplication.py
- src/axolotl/monkeypatch/peft/utils.py
- tests/e2e/test_load_model.py
- src/axolotl/utils/schemas/config.py
- tests/patched/test_validation.py
- src/axolotl/loaders/__init__.py
- src/axolotl/train.py
- src/axolotl/loaders/processor.py
- src/axolotl/loaders/tokenizer.py
- src/axolotl/integrations/base.py
🧰 Additional context used
🧬 Code Graph Analysis (3)
src/axolotl/loaders/model.py (9)
src/axolotl/integrations/base.py (13)
PluginManager(290-576)cfg(331-332)cfg(335-336)pre_model_load(92-98)pre_model_load(393-401)post_model_build(100-108)post_model_build(403-413)pre_lora_load(121-130)pre_lora_load(427-436)post_lora_load(132-141)post_lora_load(438-447)post_model_load(110-119)post_model_load(415-425)src/axolotl/loaders/adapter.py (2)
load_adapter(164-181)load_lora(71-161)src/axolotl/loaders/patch_manager.py (4)
PatchManager(27-353)has_flash_attn(45-47)apply_pre_model_load_patches(49-64)apply_post_model_load_patches(66-90)src/axolotl/loaders/utils.py (3)
get_linear_embedding_layers(210-216)get_module_class_from_name(16-41)load_model_config(135-183)src/axolotl/models/mamba/__init__.py (1)
fix_mamba_attn_for_loss(16-24)src/axolotl/utils/bench.py (1)
log_gpu_memory_usage(94-111)src/axolotl/utils/distributed.py (2)
get_device_count(33-39)get_device_type(22-30)src/axolotl/utils/schemas/enums.py (1)
RLType(6-14)src/axolotl/models/mamba/modeling_mamba.py (3)
tie_weights(61-62)from_pretrained(122-128)MambaLMHeadModel(17-128)
src/axolotl/loaders/adapter.py (3)
src/axolotl/utils/schemas/peft.py (3)
LoftQConfig(8-14)LoraConfig(23-122)PeftConfig(17-20)src/axolotl/loaders/utils.py (1)
get_linear_embedding_layers(210-216)src/axolotl/utils/dict.py (1)
DictDefault(6-38)
src/axolotl/loaders/utils.py (4)
src/axolotl/utils/dict.py (1)
DictDefault(6-38)tests/test_exact_deduplication.py (1)
cfg(221-235)src/axolotl/models/mamba/modeling_mamba.py (1)
from_pretrained(122-128)tests/test_perplexity.py (1)
model(20-21)
🪛 Ruff (0.11.9)
src/axolotl/loaders/patch_manager.py
128-129: Use a single if statement instead of nested if statements
(SIM102)
src/axolotl/loaders/utils.py
193-196: Use contextlib.suppress(AttributeError) instead of try-except-pass
Replace with contextlib.suppress(AttributeError)
(SIM105)
197-200: Use contextlib.suppress(AttributeError) instead of try-except-pass
Replace with contextlib.suppress(AttributeError)
(SIM105)
⏰ Context from checks skipped due to timeout of 90000ms (6)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
🔇 Additional comments (4)
src/axolotl/loaders/patch_manager.py (1)
295-301:
❌ Incorrect review comment
self.cfg.device may be missing – guard before access
_patch_llama_flash_attention accesses self.cfg.device, but that field is not guaranteed by the schema shown in other modules. If absent, an AttributeError is raised during model load.
- if self.cfg.device not in ["mps", "cpu"] and not self.inference:
+ if getattr(self.cfg, "device", None) not in ["mps", "cpu"] and not self.inference:
Please verify that all config paths indeed define device; otherwise add a default ("cuda"/"cpu").
🏁 Script executed:
#!/bin/bash
# Find definitions of config classes to inspect schema for 'device'
rg -n "class .*Config" -n src/axolotl
# Search for occurrences of 'device' in configuration-related files
rg -n "device" -n src/axolotl
Length of output: 43317
No missing device attribute on self.cfg
The device field is declared in AxolotlInputConfig (see utils/schemas/config.py line 144) with a default of None, so accessing self.cfg.device cannot raise AttributeError. If you need to treat None differently (e.g. default to "cuda" or skip patching when unset), add an explicit None check rather than guarding against a missing attribute.
Likely an incorrect or invalid review comment.
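If None does need separate handling, the explicit check mentioned above might look like the sketch below. The function names and the "cuda" fallback are assumptions made for illustration; the source does not say which default the project would choose.

```python
from typing import Optional


def resolve_device(device: Optional[str], default: str = "cuda") -> str:
    """Map an unset device to an explicit default (the default is an assumption)."""
    return default if device is None else device


def should_patch_flash_attention(device: Optional[str], inference: bool) -> bool:
    """Mirror the reviewed condition, but with None resolved explicitly first."""
    return resolve_device(device) not in ("mps", "cpu") and not inference
```

Resolving the value once keeps the membership test readable and makes the chosen default visible in a single place.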
src/axolotl/loaders/model.py (3)
122-123: Good job with default initialization.
Setting self.auto_model_loader = AutoModelForCausalLM as the default is a sensible choice, as it's later overridden for multimodal models in _set_auto_model_loader.
138-167: Well-structured orchestration method.
The load() method nicely orchestrates the entire loading process with clear stages, properly integrating with the plugin system at each stage. The method signature correctly indicates that the returned LoRA config can be None.
647-648: Good job fixing the potential KeyError issue.
The previous implementation used del, which could raise KeyError if the keys didn't exist. The current implementation uses .pop(key, None), which safely handles missing keys.
75a1284 to a2c5720
Actionable comments posted: 2
♻️ Duplicate comments (6)
src/axolotl/loaders/tokenizer.py (1)
48-51: 🛠️ Refactor suggestion
Guard against cfg.output_dir being None before os.path.join
output_dir originates from the run-time configuration. If the user omits --output_dir, os.path.join(None, "tokenizer") will raise a TypeError, breaking training/inference when added_tokens_overrides is set.
- tokenizer_dir = os.path.join(output_dir, "tokenizer")
+ if output_dir is None:
+     import tempfile
+     output_dir = tempfile.mkdtemp(prefix="axolotl_tokenizer_")
+     LOG.warning(
+         "`cfg.output_dir` is None – writing modified tokenizer to %s", output_dir
+     )
+ tokenizer_dir = os.path.join(output_dir, "tokenizer")
src/axolotl/loaders/utils.py (1)
90-97: quantization_config may not be a plain dict – avoid .get() directly
Duplicate of a previously-raised issue.
model_config.quantization_config can be a nested PretrainedConfig or a dataclass. Calling .get() will raise AttributeError.
Refactor to normalise the object first:
- is_compressed_tensors_config = (
-     quant_config_exists
-     and model_config.quantization_config.get("quant_method") == "compressed-tensors"
- )
+ qc = model_config.quantization_config
+ if quant_config_exists and not isinstance(qc, dict):
+     qc = qc.to_dict() if hasattr(qc, "to_dict") else vars(qc)
+
+ is_compressed_tensors_config = (
+     quant_config_exists and qc.get("quant_method") == "compressed-tensors"
+ )
src/axolotl/loaders/adapter.py (2)
51-68: Prefer full module paths and explicit output-head filtering in find_all_linear_names
The current implementation has two issues:
- It keeps only the last token of the module name, which might lead to missing intended targets
- It blindly removes output_embedding, which might hide an intentional request to wrap that head
The solution is to return the full dotted path rather than only the last token and let the caller decide how to shorten it. Also, consider making the output embedding filtering optional.
-def find_all_linear_names(model):
+def find_all_linear_names(
+    model,
+    filter_output: bool = True,
+    warn_on_drop: bool = False,
+) -> list[str]:
     cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
     lora_module_names = set()
     for name, module in model.named_modules():
         if (
             isinstance(module, cls)
             or "Linear" in module.__class__.__name__
             and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
         ):
-            names = name.split(".")
-            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+            # keep full dotted path
+            lora_module_names.add(name)

     embedding_modules = get_linear_embedding_layers(model.config.model_type)
     output_embedding = embedding_modules[1]
-    if output_embedding in lora_module_names:  # needed for 16-bit
+    # Optionally filter out the model's output-head
+    if output_embedding in lora_module_names and filter_output:
+        if warn_on_drop:
+            LOG.warning(f"Dropping output-embedding layer '{output_embedding}' from LoRA targets")
         lora_module_names.remove(output_embedding)

     return list(lora_module_names)
104-116: Ensure r and lora_alpha have valid defaults
PEFT's LoraConfig requires positive integers for r and lora_alpha; passing None will trigger a ValueError. Guard against unset YAML values by providing sensible defaults or omitting the arguments when they're None.
 lora_config = LoraConfig(
-    r=cfg.lora_r,
-    lora_alpha=cfg.lora_alpha,
+    r=cfg.lora_r or 8,  # fallback to 8 if unset
+    lora_alpha=cfg.lora_alpha or 16,  # fallback to 16 if unset
     target_modules=lora_target_modules,
     layers_to_transform=cfg.peft_layers_to_transform,
     layers_pattern=cfg.peft_layers_pattern,
     lora_dropout=cfg.lora_dropout,
     fan_in_fan_out=cfg.lora_fan_in_fan_out,
     modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
     bias="none",
     task_type="CAUSAL_LM",
     **lora_config_kwargs,
 )
src/axolotl/loaders/model.py (2)
340-341: Resolve device-placement conditional and remove TODO
The current code unconditionally appends :{self.cfg.local_rank} to the device string, even when not in DDP or on backends that don't support indexing (e.g. MPS).
Replace the TODO with a clear conditional that properly handles different device types:
- # TODO: validate this conditional
- self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
+ device = get_device_type()
+ # only append ":local_rank" for CUDA multi-GPU setups
+ if getattr(self.cfg, "world_size", 1) > 1 and device.type == "cuda":
+     target = f"{device}:{self.cfg.local_rank}"
+ else:
+     target = str(device)
+ self.model.to(target)
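The device-string rule in that suggestion can be isolated into a pure function, which makes it easy to unit-test without any GPU. This is only a sketch: target_device is a hypothetical name, and it works on plain strings rather than the torch.device object the suggestion appears to assume.

```python
def target_device(device_type: str, local_rank: int, world_size: int = 1) -> str:
    """Build the string passed to model.to().

    Only CUDA in a multi-process run gets an explicit ":<local_rank>" index;
    other backends (e.g. "mps", "cpu") are returned unchanged.
    """
    if device_type == "cuda" and world_size > 1:
        return f"{device_type}:{local_rank}"
    return device_type
```

Whether single-process CUDA runs should also receive an index is a design choice the review leaves open; this sketch follows the suggested diff and omits it.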
631-643: Improve the Mamba model loading implementation
The code contains a FIXME comment acknowledging that the Mamba model loading is "janky at best and hacked together".
This implementation uses a workaround with fix_mamba_attn_for_loss() and has special handling for model kwargs, which could lead to unexpected behavior if the underlying model implementation changes. Consider refactoring this in a future PR for better maintainability and reliability.
🧹 Nitpick comments (6)
src/axolotl/loaders/tokenizer.py (1)
87-88: Variable shadowing reduces readability
Re-using the name tokenizer_path for the JSON file path (tokenizer.json) shadows the original tokenizer_path argument. This is harmless but confusing when reading or debugging.
- tokenizer_path = os.path.join(tokenizer_dir, "tokenizer.json")
+ tokenizer_json_path = os.path.join(tokenizer_dir, "tokenizer.json")
src/axolotl/loaders/utils.py (2)
186-205: Use contextlib.suppress for cleaner attribute-checks
The nested try/except AttributeError blocks can be simplified:
+import contextlib
 ...
-        try:
-            weight_mismatch = module.weight.dtype != dtype
-        except AttributeError:
-            pass
-        try:
-            bias_mismatch = module.bias.dtype != dtype
-        except AttributeError:
-            pass
+        with contextlib.suppress(AttributeError):
+            weight_mismatch = module.weight.dtype != dtype
+        with contextlib.suppress(AttributeError):
+            bias_mismatch = module.bias.dtype != dtype
The intent remains clear while reducing boilerplate.
🧰 Tools
🪛 Ruff (0.11.9)
190-193: Use contextlib.suppress(AttributeError) instead of try-except-pass
Replace with contextlib.suppress(AttributeError)
(SIM105)
194-197: Use contextlib.suppress(AttributeError) instead of try-except-pass
Replace with contextlib.suppress(AttributeError)
(SIM105)
161-175: Catching only ValueError misses common HF config errors
AutoConfig.from_pretrained also raises OSError (e.g., missing files) and EnvironmentError, which in Python 3 is an alias of OSError. Consider broadening the exception list so special-case handling for Mamba isn’t skipped on those errors.
- except ValueError as err:
+ except (ValueError, OSError) as err:
src/axolotl/loaders/patch_manager.py (2)
112-118: Simplify nested if statements
The nested if statements can be combined for better readability and to address the static analysis warning.
- if self.cfg.model_config_type == "llama4":
-     if self.cfg.llama4_linearized_experts:
-         from axolotl.monkeypatch.models.llama4.modeling import (
-             patch_llama4_linearized_modeling,
-         )
-
-         patch_llama4_linearized_modeling()
+ if self.cfg.model_config_type == "llama4" and self.cfg.llama4_linearized_experts:
+     from axolotl.monkeypatch.models.llama4.modeling import (
+         patch_llama4_linearized_modeling,
+     )
+
+     patch_llama4_linearized_modeling()
🧰 Tools
🪛 Ruff (0.11.9)
112-113: Use a single if statement instead of nested if statements
(SIM102)
351-371: Simplify nested if statements
Similar to the previous suggestion, consider combining these nested if statements to address the static analysis warning.
- if (
-     self.model_config.model_type in ["llama", "llama4"]
-     and not self.cfg.trust_remote_code
-     and not self.cfg.gptq
- ):
-     # TODO(MengqingCao): split these patches seperately
-     if self.cfg.flash_attention and not self.inference:
+ if (
+     self.model_config.model_type in ["llama", "llama4"]
+     and not self.cfg.trust_remote_code
+     and not self.cfg.gptq
+     and self.cfg.flash_attention
+     and not self.inference
+ ):
+     # TODO(MengqingCao): split these patches seperately
🧰 Tools
🪛 Ruff (0.11.9)
351-357: Use a single if statement instead of nested if statements
(SIM102)
src/axolotl/loaders/model.py (1)
658-662: Use safer key removal for model kwargs
The current implementation unconditionally deletes keys from self.model_kwargs, which could raise KeyError if the keys don't exist.
- del self.model_kwargs["torch_dtype"]
- del self.model_kwargs["device_map"]
+ self.model_kwargs.pop("torch_dtype", None)
+ self.model_kwargs.pop("device_map", None)
Using pop with a default value of None makes the code more robust by avoiding KeyError if the key is already removed.
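The difference between del and pop with a default is easy to check in isolation. Below is a minimal illustration; drop_keys is a hypothetical helper name, and the keys are just the two mentioned in the comment.

```python
from typing import Any, Dict, Iterable


def drop_keys(kwargs: Dict[str, Any], keys: Iterable[str]) -> Dict[str, Any]:
    """Remove each key if present; missing keys are ignored instead of raising."""
    for key in keys:
        kwargs.pop(key, None)
    return kwargs


model_kwargs = {"torch_dtype": "bfloat16", "trust_remote_code": False}
# "device_map" is absent here; del would raise KeyError, pop(..., None) does not.
drop_keys(model_kwargs, ["torch_dtype", "device_map"])
```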
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (29)
src/axolotl/cli/utils.py(2 hunks)src/axolotl/common/datasets.py(1 hunks)src/axolotl/core/trainer_builder.py(1 hunks)src/axolotl/core/trainers/grpo/trainer.py(1 hunks)src/axolotl/integrations/base.py(14 hunks)src/axolotl/loaders/__init__.py(1 hunks)src/axolotl/loaders/adapter.py(1 hunks)src/axolotl/loaders/constants.py(1 hunks)src/axolotl/loaders/model.py(1 hunks)src/axolotl/loaders/patch_manager.py(1 hunks)src/axolotl/loaders/processor.py(1 hunks)src/axolotl/loaders/tokenizer.py(1 hunks)src/axolotl/loaders/utils.py(1 hunks)src/axolotl/monkeypatch/peft/utils.py(1 hunks)src/axolotl/train.py(3 hunks)src/axolotl/utils/config/__init__.py(1 hunks)src/axolotl/utils/ctx_managers/sequence_parallel.py(1 hunks)src/axolotl/utils/data/rl.py(1 hunks)src/axolotl/utils/lora_embeddings.py(0 hunks)src/axolotl/utils/models.py(0 hunks)src/axolotl/utils/schemas/config.py(1 hunks)tests/core/test_trainer_builder.py(2 hunks)tests/e2e/patched/test_model_patches.py(3 hunks)tests/e2e/test_load_model.py(3 hunks)tests/patched/test_validation.py(2 hunks)tests/test_exact_deduplication.py(6 hunks)tests/test_loaders.py(3 hunks)tests/test_lora.py(3 hunks)tests/test_tokenizers.py(1 hunks)
💤 Files with no reviewable changes (2)
- src/axolotl/utils/lora_embeddings.py
- src/axolotl/utils/models.py
✅ Files skipped from review due to trivial changes (4)
- src/axolotl/utils/data/rl.py
- tests/test_tokenizers.py
- src/axolotl/utils/ctx_managers/sequence_parallel.py
- tests/test_exact_deduplication.py
🚧 Files skipped from review as they are similar to previous changes (17)
- src/axolotl/common/datasets.py
- src/axolotl/utils/config/__init__.py
- src/axolotl/core/trainer_builder.py
- tests/core/test_trainer_builder.py
- tests/e2e/patched/test_model_patches.py
- src/axolotl/train.py
- tests/test_lora.py
- tests/patched/test_validation.py
- src/axolotl/cli/utils.py
- src/axolotl/loaders/constants.py
- src/axolotl/monkeypatch/peft/utils.py
- tests/e2e/test_load_model.py
- tests/test_loaders.py
- src/axolotl/loaders/__init__.py
- src/axolotl/loaders/processor.py
- src/axolotl/utils/schemas/config.py
- src/axolotl/integrations/base.py
🧰 Additional context used
🧬 Code Graph Analysis (2)
src/axolotl/loaders/adapter.py (3)
src/axolotl/utils/schemas/peft.py (3)
LoftQConfig(8-14)LoraConfig(23-122)PeftConfig(17-20)src/axolotl/loaders/utils.py (1)
get_linear_embedding_layers(207-213)src/axolotl/utils/dict.py (1)
DictDefault(6-38)
src/axolotl/loaders/utils.py (4)
src/axolotl/utils/dict.py (1)
DictDefault(6-38)tests/test_exact_deduplication.py (1)
cfg(221-235)src/axolotl/models/mamba/modeling_mamba.py (1)
from_pretrained(122-128)tests/test_perplexity.py (1)
model(20-21)
🪛 Ruff (0.11.9)
src/axolotl/loaders/patch_manager.py
112-113: Use a single if statement instead of nested if statements
(SIM102)
351-357: Use a single if statement instead of nested if statements
(SIM102)
src/axolotl/loaders/utils.py
190-193: Use contextlib.suppress(AttributeError) instead of try-except-pass
Replace with contextlib.suppress(AttributeError)
(SIM105)
194-197: Use contextlib.suppress(AttributeError) instead of try-except-pass
Replace with contextlib.suppress(AttributeError)
(SIM105)
⏰ Context from checks skipped due to timeout of 90000ms (6)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: PyTest (3.11, 2.5.1)
🔇 Additional comments (10)
src/axolotl/core/trainers/grpo/trainer.py (1)
46-46: Confirm ring-attention import path exists in refactored package
The get_ring_attn_group symbol has moved from axolotl.monkeypatch.ring_attn.patch to axolotl.monkeypatch.ring_attn. Please double-check that
- src/axolotl/monkeypatch/ring_attn/__init__.py re-exports get_ring_attn_group (or defines it directly), and
- All other call-sites have been updated in the PR.
A quick rg get_ring_attn_group scan in the repo branch will reveal any stale imports.
src/axolotl/loaders/patch_manager.py (3)
356-356: Address the TODO comment in a future PR
The TODO comment indicates that these patches should be split separately. Consider creating a follow-up task to refactor this method for better organization and maintainability.
This would involve extracting the Flash Attention-specific patches into their own dedicated methods in the PatchManager class to improve code organization. Has this been captured as a future task?
345-347: Good error handling for unsupported configuration
The implementation correctly raises a NotImplementedError when shifted-sparse attention is requested without flash attention, making it clear to users that this combination is not supported.
206-216: Well-handled auto_map configuration
The implementation now correctly handles auto_map access for both dictionary-like and object configs, addressing the previous review comment. This robust implementation prevents potential AttributeError issues.
src/axolotl/loaders/adapter.py (2)
146-152: Robust error handling when printing trainable parameters
The code properly catches and logs exceptions that might occur during model.print_trainable_parameters(), ensuring the process continues even if this non-critical operation fails.
123-142: Good support for CPU-based LoRA loading
The implementation supports loading LoRA weights onto CPU with proper memory settings when cfg.lora_on_cpu is set, which is important for large models that don't fit in GPU memory.
src/axolotl/loaders/model.py (4)
154-156: Consider deprecating post_model_build in favor of pre_lora_load
According to a past review comment by the developer, post_model_build could potentially be deprecated in favor of pre_lora_load.
Before making this change, consider the implications for any third-party plugins that might rely on the current hook structure. What is the plan for deprecating post_model_build?
167-168: Well-structured plugin hooks
The code provides well-defined plugin hooks at different stages of the model loading process (pre_model_load, post_model_build, pre_lora_load, post_lora_load, post_model_load), allowing for extensible customization.
25-52: Comprehensive imports and organization
The file has well-organized imports, grouping standard library, third-party, and internal imports appropriately. This makes the code more readable and maintainable.
126-130: Good integration with PatchManager
The ModelLoader correctly initializes and uses the PatchManager for applying patches, which centralizes patch management and improves code organization.
Actionable comments posted: 0
♻️ Duplicate comments (4)
src/axolotl/loaders/utils.py (2)
90-94: ⚠️ Potential issue
Fix potential AttributeError with non-dict quantization_config
The code assumes that model_config.quantization_config is a dict-like object with a .get() method, but it could be a nested PretrainedConfig or a custom dataclass, which would cause an AttributeError.
-is_compressed_tensors_config = (
-    quant_config_exists
-    and model_config.quantization_config.get("quant_method") == "compressed-tensors"
-)
+qc = model_config.quantization_config
+if quant_config_exists and not isinstance(qc, dict):
+    qc = qc.to_dict() if hasattr(qc, "to_dict") else vars(qc)
+
+is_compressed_tensors_config = (
+    quant_config_exists and qc.get("quant_method") == "compressed-tensors"
+)
105-109: ⚠️ Potential issue
Same issue with quant_config type checking
Similar to the previous issue, this code also assumes that model_config.quantization_config is a dict-like object. The same fix should be applied here to avoid AttributeError.
-quant_config_method_is_gptq = (
-    quant_config_exists
-    and "quant_method" in model_config.quantization_config
-    and model_config.quantization_config["quant_method"] == "gptq"
-)
+quant_config_method_is_gptq = False
+if quant_config_exists:
+    qc = model_config.quantization_config
+    if not isinstance(qc, dict):
+        qc = qc.to_dict() if hasattr(qc, "to_dict") else vars(qc)
+
+    quant_config_method_is_gptq = (
+        "quant_method" in qc and qc["quant_method"] == "gptq"
+    )
src/axolotl/loaders/patch_manager.py (2)
298-305: ⚠️ Potential issue
Missing flash attention check for s2_attention
The s2_attention path doesn't check if flash attention is available via self.has_flash_attn before applying flash attention patches, unlike other code paths. This could lead to runtime errors if flash attention is not installed.
 elif self.cfg.s2_attention:
+    if not self.has_flash_attn:
+        raise ValueError("Shifted-sparse attention requires flash attention to be installed")
     LOG.info("patching w/ flash-enabled, shifted-sparse attention")
     replace_llama_attn_with_flash_attn(
         packed=False,
         cross_entropy=self.cfg.flash_attn_cross_entropy,
         rms_norm=self.cfg.flash_attn_rms_norm,
         use_shifted_sparse_attn=True,
     )
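A guard of this kind can be expressed without importing the optional package. The sketch below is an assumption about how such a check could be written, not how has_flash_attn is actually implemented in PatchManager; both function names are hypothetical.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


def require_flash_attn_for_s2(s2_attention: bool, flash_available: bool) -> None:
    """Fail early when shifted-sparse attention is requested without flash attention."""
    if s2_attention and not flash_available:
        raise ValueError(
            "Shifted-sparse attention requires flash attention to be installed"
        )
```

A caller could pass has_module("flash_attn") as flash_available; failing early gives a clearer message than an ImportError raised deep inside the patching code.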
258-269: ⚠️ Potential issue
Potential import/usage mismatch for flash attention functions
The function patch_fa_llama_cross_entropy is imported conditionally when self.has_flash_attn is True, but then used with a separate condition check. If self.has_flash_attn is False but self.cfg.flash_attn_cross_entropy is True, this would cause a runtime error as the function wouldn't be defined.
To fix this, either adjust the import condition to match the usage, or adjust the usage condition to match the import:
- if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
+ if self.has_flash_attn and self.cfg.flash_attn_cross_entropy:
      from axolotl.monkeypatch.llama_attn_hijack_flash import (
          patch_fa_llama_cross_entropy,
      )
      patch_fa_llama_cross_entropy()
Or alternatively, move the imports outside the conditional blocks:
- if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
-     from axolotl.monkeypatch.llama_attn_hijack_flash import (
-         patch_fa_llama_cross_entropy,
-     )
+ # Always import these functions regardless of flash_attn availability
+ # to avoid runtime errors when referencing them
+ from axolotl.monkeypatch.llama_attn_hijack_flash import (
+     patch_fa_llama_cross_entropy,
+     patch_llama_rms_norm,
+ )
+
+ if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
      patch_fa_llama_cross_entropy()
🧹 Nitpick comments (1)
src/axolotl/loaders/utils.py (1)
186-202: Consider using logging instead of print statements
The function uses print statements for logging dtype conversions. Consider using the logger (LOG) that's already defined in this module for consistency.
- print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
+ LOG.debug(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
- print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
+ LOG.debug(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
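When switching from print to the logger, the message can also use the logging module's deferred %-style arguments, so the string is only formatted if DEBUG is enabled. A small sketch follows; the logger name and the helper function are hypothetical, and the dtypes are passed as plain strings for illustration.

```python
import logging

LOG = logging.getLogger("axolotl.example")  # hypothetical logger name


def log_dtype_conversion(name: str, old_dtype: str, new_dtype: str) -> None:
    """Emit the conversion notice at DEBUG level with lazily formatted arguments."""
    LOG.debug("Converting module %s.weight: %s -> %s", name, old_dtype, new_dtype)
```

With the default WARNING level nothing is printed, which is the main behavioural difference from print: conversions become opt-in diagnostics.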
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/axolotl/loaders/patch_manager.py(1 hunks)src/axolotl/loaders/utils.py(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (6)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
🔇 Additional comments (8)
src/axolotl/loaders/utils.py (2)
16-41: Well-implemented recursive search function
The recursive module search function is cleanly implemented with proper type hints and docstrings.
135-183: Good error handling and config overrides implementation
This function properly handles special cases like Mamba models and provides clear error reporting. The config overrides mechanism is also well-implemented.
src/axolotl/loaders/patch_manager.py (6)
1-26: Good module structure and docstring
The file has a clear purpose, proper imports, and good organization. The docstring explains the module's role well.
28-75: Well-designed class structure with clear responsibility separation
The `PatchManager` class has a clean design with proper initialization, and clear separation between pre-model and post-model patching responsibilities.
200-232: Improved auto_map handling for PretrainedConfig
The code now correctly handles both dictionary-style and attribute-style access to `auto_map` with proper type checking. This is a good improvement.
179-199: Good error handling in sequence parallelism setup
The code checks if parallelism is already configured before applying patches, which prevents redundant setup. This is a good defensive programming practice.
351-391: Well-structured model patching with clear logging
The post-model patching methods are well-organized with proper logging of actions, making the code behavior transparent to users.
159-178: ❌ Incorrect review comment
Potential inconsistency in Unsloth patching
The `_apply_unsloth_self_attention_patch` method is called during pre-model load patching, but the same functionality (`patch_self_attn_lora`) is also called in `_patch_llama_derived_model` at line 280. This could lead to the same patch being applied twice.
Please verify that this doesn't cause any issues, or refactor to ensure the patch is only applied once:
🏁 Script executed:
#!/bin/bash
# Check for duplicate patch applications
echo "Checking for duplicate patch_self_attn_lora calls:"
rg -p "patch_self_attn_lora" --type python

Length of output: 184

Here’s a more reliable check for both occurrences and context:
#!/bin/bash
set -e
echo "1) All occurrences of patch_self_attn_lora:"
rg -n "patch_self_attn_lora"
echo -e "\n2) Context around _patch_llama_derived_model in patch_manager.py:"
rg -n "def _patch_llama_derived_model" -A20 src/axolotl/loaders/patch_manager.py || true

Let’s get a complete picture of where and how many times `patch_self_attn_lora` is applied, and the surrounding context in `_patch_llama_derived_model`:
#!/bin/bash
set -e
echo "1) All occurrences of patch_self_attn_lora:"
rg -n "patch_self_attn_lora" .
echo -e "\n2) Context around _patch_llama_derived_model in patch_manager.py:"
rg -n "def _patch_llama_derived_model" -A20 src/axolotl/loaders/patch_manager.py

Let’s list all `patch_self_attn_lora` calls and see how they’re used in both methods:
#!/bin/bash
set -e
echo "1) All occurrences of patch_self_attn_lora:"
rg -n "patch_self_attn_lora" -g "*.py"
echo -e "\n2) Context around _apply_unsloth_self_attention_patch:"
rg -n -C5 "_apply_unsloth_self_attention_patch" src/axolotl/loaders/patch_manager.py
echo -e "\n3) Context around _patch_llama_derived_model:"
rg -n -C5 "_patch_llama_derived_model" src/axolotl/loaders/patch_manager.py

🏁 Script executed:
#!/bin/bash
set -e
echo "1) All occurrences of patch_self_attn_lora (with surrounding context):"
rg -n -C3 "patch_self_attn_lora" -g "*.py"
echo -e "\n2) Context around _apply_unsloth_self_attention_patch in patch_manager.py:"
rg -n -C5 "_apply_unsloth_self_attention_patch" src/axolotl/loaders/patch_manager.py
echo -e "\n3) Context around _patch_llama_derived_model in patch_manager.py:"
rg -n -C20 "_patch_llama_derived_model" src/axolotl/loaders/patch_manager.py

Length of output: 9016
Incorrect duplication concern: distinct patch functions
The two `patch_self_attn_lora` calls import entirely different functions and apply separate patches:
- In `_apply_unsloth_self_attention_patch`, it imports `axolotl.monkeypatch.lora_kernels.patch_self_attn_lora(cfg)` to inject optimized LoRA kernel logic based on the config.
- In `_patch_llama_derived_model`, it imports `axolotl.monkeypatch.unsloth_.patch_self_attn_lora()` to apply Unsloth-specific attention patches, guarded by a global flag to ensure idempotence.

Because these are different modules and serve different patching responsibilities, there’s no double-application risk.
Likely an incorrect or invalid review comment.
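The "global flag to ensure idempotence" pattern mentioned in this analysis can be sketched generically as follows. All names here (`_PATCH_APPLIED`, `Target`, `apply_patch_once`) are hypothetical; this illustrates the technique and is not the code in `axolotl.monkeypatch.unsloth_`.

```python
_PATCH_APPLIED = False  # hypothetical module-level guard flag


class Target:
    """Hypothetical class standing in for an attention module."""

    def forward(self) -> str:
        return "original"


def apply_patch_once() -> bool:
    """Monkeypatch Target.forward; return True only when the patch is newly applied."""
    global _PATCH_APPLIED
    if _PATCH_APPLIED:
        # A second call is a no-op, so the wrapper is never stacked on itself.
        return False
    original = Target.forward

    def patched(self) -> str:
        return "patched(" + original(self) + ")"

    Target.forward = patched
    _PATCH_APPLIED = True
    return True
```

Without the flag, calling the patch function twice would wrap the already-patched method again, which is the double-application risk the original comment was worried about.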
Force-pushed from 5c24934 to 8fd7325.
Caution
Inline review comments failed to post. This is likely due to GitHub's limits when posting large numbers of comments. If you are seeing this consistently it is likely a permissions issue. Please check "Moderation" -> "Code review limits" under your organization settings.
Actionable comments posted: 1
♻️ Duplicate comments (3)
src/axolotl/loaders/model.py (3)
155-161: Consider deprecating `post_model_build` in favor of `pre_lora_load`
The plugin hooks `post_model_build` and `pre_lora_load` are called sequentially with no intervening code. As noted in a previous comment, consider deprecating `post_model_build` and standardizing on `pre_lora_load`.
If there are concerns about breaking third-party plugins, consider adding a deprecation warning when `post_model_build` is used, directing users to migrate to `pre_lora_load`.
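One way the deprecation warning could look, as a rough sketch. `BasePlugin` and `run_model_build_hooks` are hypothetical names chosen for illustration; the real plugin base class and hook dispatch in axolotl may differ.

```python
import warnings


class BasePlugin:
    """Hypothetical plugin base class with no-op hooks."""

    def post_model_build(self, cfg, model):
        pass

    def pre_lora_load(self, cfg, model):
        pass


def run_model_build_hooks(plugin: BasePlugin, cfg, model) -> None:
    """Call both hooks, warning if the plugin still overrides the older one."""
    overrides_old_hook = (
        type(plugin).post_model_build is not BasePlugin.post_model_build
    )
    if overrides_old_hook:
        warnings.warn(
            "post_model_build is deprecated; implement pre_lora_load instead",
            DeprecationWarning,
            stacklevel=2,
        )
        plugin.post_model_build(cfg, model)
    plugin.pre_lora_load(cfg, model)
```

Warning only when a subclass actually overrides the old hook keeps existing third-party plugins working while leaving plugins that never used it silent.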
332-341: 🛠️ Refactor suggestion
Validate and improve the device mapping conditional
The current implementation unconditionally appends `:{self.cfg.local_rank}` to the device string, which may not be appropriate for all device types (like MPS) or non-DDP setups.
- # TODO: validate this conditional
- self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
+ device = get_device_type()
+ # Only append device index for CUDA multi-GPU setups
+ if device.type == "cuda" and getattr(self.cfg, "world_size", 1) > 1:
+     target = f"{device}:{self.cfg.local_rank}"
+ else:
+     target = str(device)
+ self.model.to(target)
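The decision logic in this suggestion can be isolated into a small pure function, sketched below. `resolve_target_device` is a hypothetical helper, and it assumes the device type is available as a plain string such as `"cuda"` or `"mps"`; whether `get_device_type()` returns a string or a `torch.device` is not shown in this review.

```python
def resolve_target_device(device_type: str, local_rank: int, world_size: int = 1) -> str:
    """Return a device string, adding an index only for multi-GPU CUDA runs."""
    if device_type == "cuda" and world_size > 1:
        # Each rank targets its own GPU in a multi-process CUDA setup.
        return f"{device_type}:{local_rank}"
    # Single-process runs and non-CUDA backends use the bare device type.
    return device_type
```

For example, `resolve_target_device("cuda", 1, world_size=2)` yields `"cuda:1"`, while `resolve_target_device("mps", 0)` yields `"mps"`. Keeping the rule in one function also makes the "TODO: validate this conditional" easy to cover with a unit test.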
568-574: ⚠️ Potential issue
Remove dead code branch with `and False` condition
This conditional will never evaluate to `True` due to the `and False` clause, making this entire FSDP-sharded path unreachable.
- if (
-     (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
-     and not self.qlora_fsdp
-     and False
- ):
+ if (
+     self.cfg.fsdp
+     and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
+     and not self.qlora_fsdp
+ ):
If this branch is no longer needed, delete the entire block instead to improve code clarity.
🧹 Nitpick comments (2)
src/axolotl/loaders/model.py (2)
516-534: Simplify attention configuration with a mapping dictionary
The current approach for setting attention implementation uses multiple if-elif branches with repeated code patterns. This could be simplified for better maintainability.
- if self.cfg.flex_attention:
-     self.model_kwargs["attn_implementation"] = "flex_attention"
-     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-         "flex_attention"
-     )
-
- elif self.cfg.flash_attention:
-     if not self.cfg.sample_packing and self.cfg.s2_attention:
-         pass
-     self.model_kwargs["attn_implementation"] = "flash_attention_2"
-     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-         "flash_attention_2"
-     )
- elif self.cfg.sdp_attention:
-     self.model_kwargs["attn_implementation"] = "sdpa"
-     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-         "sdpa"
-     )
- elif self.cfg.eager_attention:
-     self.model_kwargs["attn_implementation"] = "eager"
-     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-         "eager"
-     )
+ # Map config flags to attention implementations
+ attention_mapping = {
+     "flex_attention": self.cfg.flex_attention,
+     "flash_attention_2": self.cfg.flash_attention and not (self.cfg.sample_packing and not self.cfg.s2_attention),
+     "sdpa": self.cfg.sdp_attention,
+     "eager": self.cfg.eager_attention
+ }
+
+ # Find the first matching attention implementation
+ for impl, condition in attention_mapping.items():
+     if condition:
+         self.model_kwargs["attn_implementation"] = impl
+         self.model_config._attn_implementation = impl  # pylint: disable=protected-access
+         break
This approach is more maintainable and makes it easier to add new attention implementations in the future.
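A table-driven variant of this idea, as a sketch under stated assumptions. The names `ATTENTION_PRIORITY` and `select_attn_implementation` are hypothetical. Note that the removed branch sets `flash_attention_2` whenever `flash_attention` is enabled (its inner `if` only contains `pass`), so this sketch keeps that behavior rather than adding an extra `sample_packing` / `s2_attention` condition.

```python
from types import SimpleNamespace
from typing import Optional

# Ordered (config flag, implementation name) pairs; earlier entries win,
# mirroring the order of the original if/elif chain.
ATTENTION_PRIORITY = (
    ("flex_attention", "flex_attention"),
    ("flash_attention", "flash_attention_2"),
    ("sdp_attention", "sdpa"),
    ("eager_attention", "eager"),
)


def select_attn_implementation(cfg) -> Optional[str]:
    """Return the first enabled attention implementation, or None if none is set."""
    for flag, impl in ATTENTION_PRIORITY:
        if getattr(cfg, flag, None):
            return impl
    return None


# Usage with a stand-in config object (SimpleNamespace is only for illustration).
cfg = SimpleNamespace(flash_attention=True, sdp_attention=True)
chosen = select_attn_implementation(cfg)
```

The caller would then assign `chosen` to both `model_kwargs["attn_implementation"]` and the config attribute in one place, and adding a new backend becomes a one-line change to the table.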
361-364: Replace multiple GC calls with a cleaner approach
The current implementation calls `gc.collect()` three times in a row, which is unnecessary and doesn't follow best practices.
- for _ in range(3):
-     gc.collect()
-     torch.cuda.empty_cache()
+ # Clean up memory once
+ gc.collect()
+ torch.cuda.empty_cache()
Multiple GC calls generally don't provide additional benefit. A single call followed by cache clearing is sufficient.
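A small helper along these lines, as a sketch only. `free_memory` is a hypothetical name, and torch is treated as optional here purely so the example runs anywhere; the loader itself always has torch available.

```python
import gc


def free_memory() -> int:
    """Run one garbage-collection pass, then release cached CUDA memory if possible."""
    collected = gc.collect()  # number of unreachable objects found
    try:
        import torch
    except ImportError:
        # Assumption for this sketch: fall back gracefully when torch is absent.
        return collected
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return collected
```

Wrapping the cleanup in one function gives a single place to adjust the policy if a repeated pass ever turns out to be needed for a specific workload.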
🛑 Comments failed to post (1)
src/axolotl/loaders/model.py (1)
629-641: 🛠️ Refactor suggestion
Improve the Mamba model loading implementation
The current implementation for loading Mamba models is explicitly marked as "janky" and uses direct dictionary manipulation which could lead to `KeyError` exceptions.
- # FIXME this is janky at best and hacked together to make it work
- MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
-
- self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"]
- self.model_kwargs["device"] = torch.cuda.current_device()
- del self.model_kwargs["torch_dtype"]
- del self.model_kwargs["device_map"]
+ # Load Mamba model with proper configuration
+ MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
+
+ # Safely transfer and remove kwargs to prevent KeyError
+ dtype = self.model_kwargs.pop("torch_dtype", None)
+ self.model_kwargs["dtype"] = dtype
+ self.model_kwargs["device"] = torch.cuda.current_device()
+ self.model_kwargs.pop("device_map", None)
Consider creating a dedicated method for Mamba model loading in a future PR to better encapsulate this special case.
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/axolotl/loaders/model.py around lines 629 to 641, the Mamba model loading code directly manipulates the model_kwargs dictionary and is marked as "janky," risking KeyError exceptions. Refactor this by extracting the Mamba model loading logic into a dedicated method that safely handles dictionary keys and prepares the arguments cleanly before calling from_pretrained. This will encapsulate the special case and improve code clarity and robustness.
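The key-handling part of that dedicated method might look like the sketch below. `prepare_mamba_kwargs` is a hypothetical helper, and the device is passed in as a plain argument so the example does not require CUDA; in the loader it would come from `torch.cuda.current_device()`.

```python
from typing import Any, Dict


def prepare_mamba_kwargs(model_kwargs: Dict[str, Any], device: Any) -> Dict[str, Any]:
    """Return a copy of model_kwargs using the key names the Mamba path expects."""
    kwargs = dict(model_kwargs)  # leave the caller's dict untouched
    # pop() with a default avoids KeyError when a key was never set.
    kwargs["dtype"] = kwargs.pop("torch_dtype", None)
    kwargs["device"] = device
    kwargs.pop("device_map", None)
    return kwargs
```

Returning a new dict instead of mutating `self.model_kwargs` in place is a design choice for this sketch: it keeps the Mamba special case from leaking renamed keys into code paths that run afterwards.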
Description
Title. Breaking out the monster `axolotl.utils.models` file into various submodules (also, moved up a level in the module hierarchy).
Motivation and Context
File / `ModelLoader` class at present is super large and difficult to read / maintain.
How has this been tested?
TODO
Plan:
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Summary by CodeRabbit
New Features
Refactor
Documentation
Bug Fixes
Chores