Skip to content

feat: add mlx model and trainer#3856

Closed
JINO-ROHIT wants to merge 3212 commits into
unslothai:mainfrom
JINO-ROHIT:mlx-support
Closed

feat: add mlx model and trainer#3856
JINO-ROHIT wants to merge 3212 commits into
unslothai:mainfrom
JINO-ROHIT:mlx-support

Conversation

@JINO-ROHIT

@JINO-ROHIT JINO-ROHIT commented Jan 6, 2026

Copy link
Copy Markdown

hello everyone!

this PR aims to integrate mlx support in unsloth with minimal changes.

  1. ive tried to keep the PR as compact as possible and make use of the existing mlx utilities.
  2. ive also had to make some patches on the unsloth-zoo code files, should i raise a seperate PR for that?

im attaching below a sample alpaca training run script to get this working.

from unsloth.models.mlx_model import FastMLXModel
model, tokenizer = FastMLXModel.from_pretrained("mlx-community/Llama-3.2-3B-Instruct-4bit")


from datasets import load_dataset
dataset = load_dataset("mlabonne/FineTome-Alpaca-100k", split="train")

system_message = """You are an assistant."""
def create_conversation(sample):
  return {
    "messages": [
      {"role": "system", "content": system_message},
      {"role": "user", "content": sample["instruction"]}, # human
      {"role": "assistant", "content": sample["output"]} # model
    ]
  }

dataset = dataset.map(create_conversation, remove_columns=dataset.features,batched=False)
dataset = dataset.train_test_split(0.1)

from mlx_lm.tuner import datasets

configs = {
    "mask_prompt": False,
    "prompt_feature": "prompt",
    "text_feature": "text",
    "completion_feature": "completion",
    "chat_feature": "messages",
}

train_set = datasets.create_dataset(
    dataset["train"],
    tokenizer,
    configs
)

val_set = datasets.create_dataset(
    dataset["test"],
    tokenizer,
    configs
)


FastMLXModel.train(
    model,
    train_set,
    val_set,
    iterations = 2
)

metascroy and others added 30 commits November 14, 2025 03:08
* up

* up

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix qwen3 vl gradient accumulation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update unsloth/models/_utils.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.14.4 → v0.14.5](astral-sh/ruff-pre-commit@v0.14.4...v0.14.5)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* README Link Fixes

* Update README.md

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Add an int64 path for mlp kernels

* move constant expressions to globals

* fix name
* Remove grpo requirement bs=num_generations

* Update rl.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Enable FP8 + RL training for bf16 models

**Summary:** Enable FP8 + RL training using TorchAO for 1.33x faster training and 42% less model memory usage:
- We quantize the frozen LoRA weights into fp8 and keep the LoRA adapters in bf16
- We leverage TorchAO's `Float8Tensor`, which calls into fbgemm's fp8 x fp8 rowwise matmul kernel
- For now, we need to do an offline quantization first, because vllm doesn't support on-the-fly quantization for torchao yet  (this is in progress: vllm-project/vllm#26327)

**Example usage:**
```
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = True,  # set this to True
)

\# the rest is the same as before
model = FastLanguageModel.get_peft_model(...)
```

**Initial results:**
```
\# fp8
{'train_runtime': 1725.4337, 'train_samples_per_second': 0.232, 'train_steps_per_second': 0.058, 'train_loss': 0.00015715716748673002, 'epoch': 0.01}

\# bf16
{'train_runtime': 2297.8145, 'train_samples_per_second': 0.174, 'train_steps_per_second': 0.044, 'train_loss': 0.00016081033063528594, 'epoch': 0.01}
```

<img width="1199" height="448" alt="Screenshot 2025-11-11 at 4 10 50 PM" src="https://github.com/user-attachments/assets/b6304afd-89e9-42b1-8064-775807e17b23" />

Test script: https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423

**Requires:**
- pytorch/ao#3158 (torchao nightly or 0.15.0+)
- unslothai/unsloth-zoo#351

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* _get_inference_mode_context_manager

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* Update utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
* make loading gpt-oss-BF16 faster. Linked to unsloth-zoo PR unslothai#314

* fix model loading and clean merged model directory

* revert default quant

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert mapper.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Add 128x128 PerBlock FP8 + RL

**Summary:** Following unslothai#3440,
this PR extends torchao FP8 + RL support to also handle 128x128
PerBlock granularity (in addition to PerRow).

**Example usage:**

```
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = "block",  # or "row" or True
)
```

**Initial results:** TBD

**Note:**
- Requires pytorch/ao#3370

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…hai#3168)

* change windows to remove windows-triton for intel xpu

* add changes for different platform

* Update pyproject.toml

* update mode windows

* Update pyproject.toml

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update pyproject.toml

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update pyproject.toml

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update pyproject.toml

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update pyproject.toml

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update pyproject.toml

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update pyproject.toml

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update pyproject.toml

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.14.5 → v0.14.6](astral-sh/ruff-pre-commit@v0.14.5...v0.14.6)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@gemini-code-assist

Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @JINO-ROHIT, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands Unsloth's hardware compatibility by integrating MLX support, primarily benefiting users on Apple Silicon devices. The changes enable the loading and fine-tuning of models using the MLX framework, providing an optimized experience for Mac users. This integration is achieved through new dedicated MLX model and trainer classes, alongside conditional logic to ensure seamless operation across different hardware platforms.

Highlights

  • MLX Integration: Introduced support for MLX (Machine Learning eXchange) models and training, specifically targeting Apple Silicon (MPS) devices.
  • New MLX Model and Trainer: Added a new file unsloth/models/mlx_model.py which defines FastMLXModel for loading and training MLX-compatible models, along with MLXTrainer, MLXTrainingArguments, and MLXLoraConfig.
  • Conditional Imports and Device Detection: Modified core Unsloth files (__init__.py, device_type.py, models/__init__.py) to conditionally import MLX-related modules and correctly detect and handle MPS devices, ensuring MLX functionality is enabled only when appropriate.
  • Dependency Management: Updated pyproject.toml to include mlx and mlx-lm as dependencies for macOS (arm64) platforms, enabling MLX support for Apple Silicon users.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces MLX support to unsloth, enabling training on Apple Silicon. The changes are well-structured and mostly confined to conditional imports and a new mlx_model.py module. I've identified a few areas for improvement to enhance code clarity and reduce redundancy. Specifically, I've suggested removing a redundant conditional block in device_type.py, simplifying a dataclass-to-dictionary conversion, and removing a duplicate check for lora_config in the new MLX model file. Overall, this is a great addition to the library.

Comment thread unsloth/device_type.py
Comment on lines +69 to +70
elif DEVICE_TYPE_TORCH == "mps":
DEVICE_TYPE_TORCH = "mps"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This elif block is redundant. Since DEVICE_TYPE_TORCH is initialized with the value of DEVICE_TYPE, if DEVICE_TYPE_TORCH is 'mps', this block just assigns 'mps' back to it. This block can be safely removed to improve code clarity. Also, there is a trailing whitespace on line 70.

Comment on lines +31 to +40
def to_dict(self) -> Dict[str, Any]:
return {
"adapter_file": self.adapter_file,
"max_seq_length": self.max_seq_length,
"grad_checkpoint": self.grad_checkpoint,
"grad_accumulation_steps": self.grad_accumulation_steps,
"iters": self.iters,
"batch_size": self.batch_size,
"val_batches": self.val_batches,
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The to_dict method can be simplified by using dataclasses.asdict. This will make the code more concise and less prone to errors if new fields are added to the dataclass, as you won't need to update this method manually.

To use it, you'll also need to update the import on line 4 to:

from dataclasses import dataclass, asdict
    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

Comment on lines +168 to +169
trainer.prepare_model_for_training(model, lora_config)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This check for lora_config being None is redundant. The prepare_model_for_training method already handles the case where lora_config is None by creating a default MLXLoraConfig. Removing this duplication will make the code cleaner.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c623f4cd3d

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread unsloth/__init__.py
Comment on lines +273 to +275
elif DEVICE_TYPE != "mps":
from .models import *
from .models import __version__

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P0 Badge Base modules never imported on CUDA/HIP/XPU

The core imports are now guarded by elif DEVICE_TYPE != "mps":, but that condition is part of the same if/elif chain as the preceding CUDA/HIP/XPU branches. Because one of those earlier branches always matches on supported GPUs, this block never executes, so import unsloth no longer brings in FastLanguageModel, __version__, trainer/chat helpers, or the TRL patch on any CUDA/ROCm/Intel system—the chain exits before reaching these imports—leading to AttributeError/missing functionality for all existing users.

Useful? React with 👍 / 👎.

@shimmyshimmer

Copy link
Copy Markdown
Member

Thank you so much for your PR we'll take a look asap! :)

Comment thread unsloth/device_type.py
)
raise NotImplementedError(
"Unsloth currently only works on NVIDIA, AMD and Intel GPUs."
"Unsloth currently only works on NVIDIA, AMD, Intel GPUs, MAC Silicon and MLX."

@JustinWick JustinWick Jan 7, 2026

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't that be M-Series Apple Silicon, Apple Macintosh Silicon or Apple Mac Silicon? I don't think it's ever branded "MAC".

@JINO-ROHIT

Copy link
Copy Markdown
Author

hello @shimmyshimmer sorry did you get some time to review this? id love to iterate on feedbacks and push this PR forward

@shimmyshimmer

Copy link
Copy Markdown
Member

hello @shimmyshimmer sorry did you get some time to review this? id love to iterate on feedbacks and push this PR forward

While the PR itself is fantastic (thank you!!), there arent any optimizations at the moment, were discussing whether we should proceed as is or, spend more time on optimizations

@b-straub

b-straub commented Jan 12, 2026

Copy link
Copy Markdown

Being only an interested consumer of Unsloth models I might be wrong, but to my knowledge the fine-tuning performance gains are also heavily related to the custom Triton kernels. Shouldn’t the optimization strategies behind those kernels be reimplemented for MLX when possible? (I understand Triton kernels can’t be directly ported since they compile to CUDA PTX, but the underlying approaches like fused attention and memory-efficient backward passes could potentially be implemented using MLX’s Metal primitives.)

@JINO-ROHIT

Copy link
Copy Markdown
Author

hello @shimmyshimmer sorry did you get some time to review this? id love to iterate on feedbacks and push this PR forward

While the PR itself is fantastic (thank you!!), there arent any optimizations at the moment, were discussing whether we should proceed as is or, spend more time on optimizations

of course, ill wait to hear on further updates

@JINO-ROHIT

Copy link
Copy Markdown
Author

Being only an interested consumer of Unsloth models I might be wrong, but to my knowledge the fine-tuning performance gains are also heavily related to the custom Triton kernels. Shouldn’t the optimization strategies behind those kernels be reimplemented for MLX when possible? (I understand Triton kernels can’t be directly ported since they compile to CUDA PTX, but the underlying approaches like fused attention and memory-efficient backward passes could potentially be implemented using MLX’s Metal primitives.)

i think mlx should already have its own set of pre-existing optimizations, but sure we could also possibly look into further improvements, although im not quite sure how difficult/easy itd be to make new metal kernels and integrate within.

@b-straub

Copy link
Copy Markdown

think mlx should already have its own set of pre-existing optimizations, but sure we could also possibly look into further improvements, although im not quite sure how difficult/easy itd be to make new metal kernels and integrate within.

At least should be easier now https://ml-explore.github.io/mlx/build/html/dev/custom_metal_kernels.html

@JINO-ROHIT

Copy link
Copy Markdown
Author

sure, i meant the overall complexity of metal programming + wrtiting optimized kernels on top of it

@Manan17

Manan17 commented Jan 14, 2026

Copy link
Copy Markdown
Collaborator

@JINO-ROHIT can you please upload the changes you made in unsloth_zoo in order to run this.

@JINO-ROHIT

Copy link
Copy Markdown
Author

sure ill do it

@xpatronum

Copy link
Copy Markdown

@JINO-ROHIT any updates, plz ?

@braduck

braduck commented Mar 5, 2026

Copy link
Copy Markdown

Who can Resolve the conflicts and merge?

@negative0

Copy link
Copy Markdown

Eagerly waiting for this 🥹

@danielhanchen

Copy link
Copy Markdown
Member

Sorry we had to rebase the PR - apologies!

@step21

step21 commented Mar 18, 2026

Copy link
Copy Markdown

Wouldn't it make more sense to base this just on gpu support in f.e. pytorch then on MLX? this would enable usage of the same models.

@RukshanJS

Copy link
Copy Markdown

what happened with this?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.