Skip to content

feat: add mlx model and trainer#4258

Open
danielhanchen wants to merge 2 commits into
mainfrom
dh/recover-3856-mlx
Open

feat: add mlx model and trainer#4258
danielhanchen wants to merge 2 commits into
mainfrom
dh/recover-3856-mlx

Conversation

@danielhanchen
Copy link
Copy Markdown
Member

Replacement for #3856 due to Studio rebasing

hello everyone!

this PR aims to integrate mlx support in unsloth with minimal changes.

  1. ive tried to keep the PR as compact as possible and make use of the existing mlx utilities.
  2. ive also had to make some patches on the unsloth-zoo code files, should i raise a seperate PR for that?

im attaching below a sample alpaca training run script to get this working.

from unsloth.models.mlx_model import FastMLXModel
model, tokenizer = FastMLXModel.from_pretrained("mlx-community/Llama-3.2-3B-Instruct-4bit")


from datasets import load_dataset
dataset = load_dataset("mlabonne/FineTome-Alpaca-100k", split="train")

system_message = """You are an assistant."""
def create_conversation(sample):
  return {
    "messages": [
      {"role": "system", "content": system_message},
      {"role": "user", "content": sample["instruction"]}, # human
      {"role": "assistant", "content": sample["output"]} # model
    ]
  }

dataset = dataset.map(create_conversation, remove_columns=dataset.features,batched=False)
dataset = dataset.train_test_split(0.1)

from mlx_lm.tuner import datasets

configs = {
    "mask_prompt": False,
    "prompt_feature": "prompt",
    "text_feature": "text",
    "completion_feature": "completion",
    "chat_feature": "messages",
}

train_set = datasets.create_dataset(
    dataset["train"],
    tokenizer,
    configs
)

val_set = datasets.create_dataset(
    dataset["test"],
    tokenizer,
    configs
)


FastMLXModel.train(
    model,
    train_set,
    val_set,
    iterations = 2
)

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 1bf8b6c685

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread unsloth/__init__.py
Comment on lines +311 to +313
elif DEVICE_TYPE != "mps":
from .models import *
from .models import __version__
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Move top-level exports out of device setup elif chain

This elif DEVICE_TYPE != "mps" block is attached to the earlier if/elif chain that handles CUDA/HIP/XPU initialization, so it is skipped on all those supported devices; as a result, the package no longer imports FastLanguageModel/other top-level symbols (breaking common usage like from unsloth import FastLanguageModel) and also skips _patch_trl_trainer() on those environments. This is a regression for standard GPU users introduced by the new conditional structure.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 76054a562a

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

) -> Tuple[Any, Any]:
print(f"Unsloth: Loading model with MLX: {model_name}")

model, tokenizer = load(model_name)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Forward from_pretrained options into MLX loader

from_pretrained accepts arbitrary keyword arguments but drops them when calling load, so caller-supplied options (for example adapter_path when reopening a LoRA checkpoint) are silently ignored and the base model is loaded instead. This can produce incorrect inference/training continuation while appearing to succeed, because the API signature suggests those parameters are supported.

Useful? React with 👍 / 👎.

@nidhishgajjar

This comment was marked as low quality.

1 similar comment
@nidhishgajjar

This comment was marked as low quality.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants