Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ datasets==4.5.0
deepspeed>=0.18.3
trl==0.28.0
hf_xet==1.2.0
kernels==0.11.5
kernels==0.12.1

trackio>=0.16.1
typing-extensions>=4.15.0
Expand Down
19 changes: 8 additions & 11 deletions src/axolotl/core/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,13 +719,16 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}")

# fix for Context Parallel save
if state_dict is None:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was added recently to solve some saving issue . Do the changes below solve it?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just did a quick pass on this, .clone() may be unintentionally placing tensors on GPU.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean? This was a cleanup from changes upstream.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was just added by ved a week ago to fix saving in context parallelism 97a4f28

state_dict = self.accelerator.get_state_dict(self.model)
if state_dict is not None:
state_dict = {
k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in state_dict.items()
}

supported_classes = (
(PreTrainedModel,)
if not is_peft_available()
Expand All @@ -736,6 +739,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()

if isinstance(
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
supported_classes,
Expand All @@ -745,6 +749,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
).save_pretrained(
output_dir,
state_dict=state_dict,
is_main_process=self.accelerator.is_main_process,
)
else:
LOG.info(
Expand All @@ -756,11 +761,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
metadata={"format": "pt"},
)
else:
self.model.save_pretrained(
output_dir,
state_dict=state_dict,
is_main_process=self.accelerator.is_main_process,
)
self.model.save_pretrained(output_dir, state_dict=state_dict)

if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
Expand All @@ -772,11 +773,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
self.data_collator.tokenizer.save_pretrained(output_dir)

# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
13 changes: 13 additions & 0 deletions src/axolotl/integrations/kernels/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,16 @@ def check_experts_implementation(cls, data):
data["experts_implementation"] = "eager"

return data

@model_validator(mode="before")
@classmethod
def disable_mlp_kernel_scattermoe(cls, data):
if data.get("use_scattermoe") is True:
if data.get("lora_mlp_kernel") is True:
LOG.warning(
"Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
)
data["lora_mlp_kernel"] = False
data["mlp_kernel"] = False
Comment on lines +39 to +46

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Warning says "Disabling lora_mlp_kernel" but the code only disables mlp_kernel.

The warning on line 43 tells the user lora_mlp_kernel is being disabled, but data["lora_mlp_kernel"] is never set to False — only data["mlp_kernel"] is. If both should be disabled, add the missing assignment; otherwise, fix the warning text.

Proposed fix (if both should be disabled)
     def disable_mlp_kernel_scattermoe(cls, data):
         if data.get("use_scattermoe") is True:
             if data.get("lora_mlp_kernel") is True:
                 LOG.warning(
                     "Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
                 )
+                data["lora_mlp_kernel"] = False
             data["mlp_kernel"] = False
🤖 Prompt for AI Agents
In `@src/axolotl/integrations/kernels/args.py` around lines 39 - 45, The warning
in disable_mlp_kernel_scattermoe refers to "Disabling lora_mlp_kernel" but the
code only sets data["mlp_kernel"] = False; either set data["lora_mlp_kernel"] =
False as well inside disable_mlp_kernel_scattermoe when
data.get("lora_mlp_kernel") is True, or change the LOG.warning text to
accurately state that mlp_kernel is being disabled; update the message and/or
assignment in the disable_mlp_kernel_scattermoe method accordingly.


return data
Empty file.
18 changes: 18 additions & 0 deletions src/axolotl/integrations/kernels/libs/scattermoe_lora/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0

from . import layers
from .lora_ops import ParallelExperts
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora

__all__ = [
"layers",
"ParallelExperts",
"flatten_sort_count",
"parallel_linear",
"ScatterMoELoRA",
"parallel_linear_lora",
"lora_ops",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
#
# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors
# Adapted from https://github.com/shawntan/scattermoe
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
#
# Modifications and LoRA adaptation Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0

from . import lora_ops, ops

__all__ = ["ops", "lora_ops"]
Loading
Loading