Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ windows = [
"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0 ; (sys_platform == 'win32')",
"xformers>=0.0.22.post7 ; (sys_platform == 'win32')",
]
mac = [
"unsloth[huggingface]",
"mlx>=0.12.0 ; sys_platform == 'darwin' and platform_machine == 'arm64'",
"mlx-lm>=0.9.0 ; sys_platform == 'darwin' and platform_machine == 'arm64'",
]
base = [
"unsloth[huggingface]",
]
Expand Down
41 changes: 20 additions & 21 deletions unsloth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,9 @@ def is_bf16_supported():

# For Gradio HF Spaces?
# if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ:
import triton

if DEVICE_TYPE == "cuda":
import triton

libcuda_dirs = lambda: None
if Version(triton.__version__) >= Version("3.0.0"):
try:
Expand Down Expand Up @@ -308,23 +308,22 @@ def is_bf16_supported():
# TODO: check triton for intel installed properly.
pass

from .models import *
from .models import __version__
from .save import *
from .chat_templates import *
from .tokenizer_utils import *
from .trainer import *

# Export dataprep utilities for CLI and downstream users
from .dataprep.raw_text import RawTextDataLoader, TextPreprocessor
from unsloth_zoo.rl_environments import (
check_python_modules,
create_locked_down_function,
execute_with_time_limit,
Benchmarker,
is_port_open,
launch_openenv,
)
elif DEVICE_TYPE != "mps":
from .models import *
from .models import __version__
Comment on lines +311 to +313
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Move top-level exports out of device setup elif chain

This elif DEVICE_TYPE != "mps" block is attached to the earlier if/elif chain that handles CUDA/HIP/XPU initialization, so it is skipped on all those supported devices; as a result, the package no longer imports FastLanguageModel/other top-level symbols (breaking common usage like from unsloth import FastLanguageModel) and also skips _patch_trl_trainer() on those environments. This is a regression for standard GPU users introduced by the new conditional structure.

Useful? React with 👍 / 👎.

from .save import *
from .chat_templates import *
from .tokenizer_utils import *
from .trainer import *
from .dataprep.raw_text import RawTextDataLoader, TextPreprocessor
from unsloth_zoo.rl_environments import (
check_python_modules,
create_locked_down_function,
execute_with_time_limit,
Benchmarker,
is_port_open,
launch_openenv,
)

# Patch TRL trainers for backwards compatibility
_patch_trl_trainer()
# Patch TRL trainers for backwards compatibility
_patch_trl_trainer()
8 changes: 6 additions & 2 deletions unsloth/device_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,23 @@ def get_device_type():
return "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
return "xpu"
elif hasattr(torch, "mps") and torch.mps.is_available():
return "mps"
# Check torch.accelerator
if hasattr(torch, "accelerator"):
if not torch.accelerator.is_available():
raise NotImplementedError(
"Unsloth cannot find any torch accelerator? You need a GPU."
)
accelerator = str(torch.accelerator.current_accelerator())
if accelerator in ("cuda", "xpu", "hip"):
if accelerator in ("cuda", "xpu", "hip", "mps"):
raise RuntimeError(
f"Unsloth: Weirdly `torch.cuda.is_available()`, `torch.xpu.is_available()` and `is_hip` all failed.\n"
f"But `torch.accelerator.current_accelerator()` works with it being = `{accelerator}`\n"
f"Please reinstall torch - it's most likely broken :("
)
raise NotImplementedError(
"Unsloth currently only works on NVIDIA, AMD and Intel GPUs."
"Unsloth currently only works on NVIDIA, AMD, Intel GPUs, MAC Silicon and MLX."
)


Expand All @@ -64,6 +66,8 @@ def get_device_type():
DEVICE_TYPE_TORCH = DEVICE_TYPE
if DEVICE_TYPE_TORCH == "hip":
DEVICE_TYPE_TORCH = "cuda"
elif DEVICE_TYPE_TORCH == "mps":
DEVICE_TYPE_TORCH = "mps"


@functools.cache
Expand Down
37 changes: 21 additions & 16 deletions unsloth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .llama import FastLlamaModel
from .loader import FastLanguageModel, FastVisionModel, FastTextModel, FastModel
from .mistral import FastMistralModel
from .qwen2 import FastQwen2Model
from .qwen3 import FastQwen3Model
from .qwen3_moe import FastQwen3MoeModel
from .granite import FastGraniteModel
from .sentence_transformer import FastSentenceTransformer
from ..device_type import DEVICE_TYPE

try:
from .falcon_h1 import FastFalconH1Model
except:
# transformers_version < 4.53.0 does not have falcon_h1 so silently skip it for now
pass
from .dpo import PatchDPOTrainer, PatchKTOTrainer
from ._utils import is_bfloat16_supported, is_vLLM_available, __version__
from .rl import PatchFastRL, vLLMSamplingParams
if DEVICE_TYPE != "mps":
from .llama import FastLlamaModel
from .loader import FastLanguageModel, FastVisionModel, FastTextModel, FastModel
from .mistral import FastMistralModel
from .qwen2 import FastQwen2Model
from .qwen3 import FastQwen3Model
from .qwen3_moe import FastQwen3MoeModel
from .granite import FastGraniteModel
from .sentence_transformer import FastSentenceTransformer

try:
from .falcon_h1 import FastFalconH1Model
except:
# transformers_version < 4.53.0 does not have falcon_h1 so silently skip it for now
pass
from .dpo import PatchDPOTrainer, PatchKTOTrainer
from ._utils import is_bfloat16_supported, is_vLLM_available, __version__
from .rl import PatchFastRL, vLLMSamplingParams
else:
from .mlx_model import FastMLXModel
3 changes: 3 additions & 0 deletions unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,9 @@ def patch_mistral_nemo_config(config):
else:
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu")
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "xpu")
else:
torch_amp_custom_fwd = None
torch_amp_custom_bwd = None
# =============================================

# =============================================
Expand Down
178 changes: 178 additions & 0 deletions unsloth/models/mlx_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import os
import json
from typing import Optional, Dict, Any, Union, Tuple
from dataclasses import dataclass

from mlx_lm import load
from mlx_lm.tuner import (
train,
TrainingArgs,
datasets,
linear_to_lora_layers,
)
import mlx.optimizers as optim
from mlx.utils import tree_flatten

from ..device_type import DEVICE_TYPE


@dataclass
class MLXTrainingArguments:
"""training arguments for MLX models."""

adapter_file: str = "adapters.safetensors"
max_seq_length: int = 2048
grad_checkpoint: bool = True
grad_accumulation_steps: int = 1
iters: int = 100
batch_size: int = 4
val_batches: int = 10

def to_dict(self) -> Dict[str, Any]:
return {
"adapter_file": self.adapter_file,
"max_seq_length": self.max_seq_length,
"grad_checkpoint": self.grad_checkpoint,
"grad_accumulation_steps": self.grad_accumulation_steps,
"iters": self.iters,
"batch_size": self.batch_size,
"val_batches": self.val_batches,
}


class MLXLoraConfig:
def __init__(
self,
rank: int = 8,
scale: float = 20.0,
dropout: float = 0.0,
num_layers: int = 8,
):
self.rank = rank
self.scale = scale
self.dropout = dropout
self.num_layers = num_layers

def to_dict(self) -> Dict[str, Any]:
return {
"num_layers": self.num_layers,
"lora_parameters": {
"rank": self.rank,
"scale": self.scale,
"dropout": self.dropout,
},
}

def save(self, adapter_path: str):
os.makedirs(adapter_path, exist_ok = True)
config_path = os.path.join(adapter_path, "adapter_config.json")
with open(config_path, "w") as f:
json.dump(self.to_dict(), f, indent = 4)


class MLXTrainer:
def prepare_model_for_training(
self,
model: Any,
lora_config: Optional[MLXLoraConfig] = None,
) -> Any:
if lora_config is None:
lora_config = MLXLoraConfig()

model.freeze()

linear_to_lora_layers(
model,
lora_config.num_layers,
lora_config.to_dict()["lora_parameters"],
)

num_train_params = sum(
v.size for _, v in tree_flatten(model.trainable_parameters())
)
print(f"number of trainable parameters: {num_train_params}")

model.train()

return model

def _train(
self,
model: Any,
training_args: Union[MLXTrainingArguments, Dict[str, Any]],
train_dataset: Any,
val_dataset: Any = None,
learning_rate: float = 1e-5,
):
if isinstance(training_args, MLXTrainingArguments):
args_dict = training_args.to_dict()
else:
args_dict = training_args

args = TrainingArgs(**args_dict)

optimizer = optim.Adam(learning_rate = learning_rate)

train_set = datasets.CacheDataset(train_dataset)
val_set = datasets.CacheDataset(val_dataset) if val_dataset else None

train(
model = model,
args = args,
optimizer = optimizer,
train_dataset = train_set,
val_dataset = val_set,
)


class FastMLXModel:
@staticmethod
def from_pretrained(
model_name: str,
**kwargs,
) -> Tuple[Any, Any]:
print(f"Unsloth: Loading model with MLX: {model_name}")

model, tokenizer = load(model_name)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Forward from_pretrained options into MLX loader

from_pretrained accepts arbitrary keyword arguments but drops them when calling load, so caller-supplied options (for example adapter_path when reopening a LoRA checkpoint) are silently ignored and the base model is loaded instead. This can produce incorrect inference/training continuation while appearing to succeed, because the API signature suggests those parameters are supported.

Useful? React with 👍 / 👎.

return model, tokenizer

@staticmethod
def for_inference(
model_name: str,
adapter_path: Optional[str] = None,
) -> Any:
if adapter_path:
model, _ = load(model_name, adapter_path = adapter_path)
else:
model, _ = load(model_name)

return model

@staticmethod
def train(
model: Any,
train_set: Any,
val_set: Any,
lora_config: Optional[MLXLoraConfig] = None,
iterations: int = 100,
learning_rate: float = 1e-5,
):
if DEVICE_TYPE != "mps":
raise RuntimeError("This function requires running on Apple Silicon")

trainer = MLXTrainer()

if lora_config is None:
lora_config = MLXLoraConfig()

trainer.prepare_model_for_training(model, lora_config)

trainer._train(
model = model,
training_args = MLXTrainingArguments(iters = iterations),
train_dataset = train_set,
val_dataset = val_set,
learning_rate = learning_rate,
)

return model
Loading