Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions examples/devstral/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Finetune Devstral with Axolotl

Devstral Small is a 24B parameter opensource model from MistralAI found on HuggingFace [Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505). This guide shows how to fine-tune it with Axolotl with multi-turn conversations with proper masking.

The model was fine-tuned ontop of [Mistral-Small-3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503) without the vision layer and has a context of upto 128k tokens.

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Devstral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).

Here is an example of how to install from main for pip:

```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0+)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl

pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'

# Install the latest mistral-common from source
pip3 uninstall mistral-common
pip3 install git+https://github.com/mistralai/mistral-common.git@039465d

```

2. Run the finetuning example:

```bash
axolotl train examples/devstral/devstral-small-qlora.yml
```

This config uses about 21GB VRAM.

Let us know how it goes. Happy finetuning! 🚀

### TIPS

- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).

## Optimization Guides

- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)
- [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels)

## Limitations

We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.

In addition, we do not support overriding tokens yet.

## Related Resources

- [MistralAI Devstral Blog](https://mistral.ai/news/devstral)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)


## Future Work

- Add parity to Preference Tuning, RL, Multi-modal, etc.
- Add parity to other tokenizer configs like overriding tokens.
64 changes: 64 additions & 0 deletions examples/devstral/devstral-small-qlora.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
base_model: mistralai/Devstral-Small-2505

# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true

load_in_8bit: false
load_in_4bit: true

plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_linear: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_ratio: 0.05
evals_per_epoch: 4
saves_per_epoch: 1

weight_decay: 0.0
special_tokens:
14 changes: 4 additions & 10 deletions examples/magistral/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,10 @@ git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl

pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,mistral]'
pip3 install --no-build-isolation -e '.[flash-attn]'
```

2. Download the example config:

```bash
axolotl fetch examples
```

3. Run the finetuning example:
2. Run the finetuning example:

```bash
axolotl train examples/magistral/magistral-small-qlora.yaml
Expand All @@ -42,7 +36,7 @@ Let us know how it goes. Happy finetuning! 🚀
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).

## Optimization Guides

Expand All @@ -54,7 +48,7 @@ Let us know how it goes. Happy finetuning! 🚀

We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.

The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet.
In addition, we do not support overriding tokens yet.

## Related Resources

Expand Down
7 changes: 0 additions & 7 deletions src/axolotl/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,6 @@ def process(self, dataset):
features = dataset.features.keys()
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())

# Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
LOG.info(
"Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
)
num_proc = 1

map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
Expand Down
24 changes: 8 additions & 16 deletions src/axolotl/prompt_strategies/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,13 +681,14 @@ def get_conversation_thread(self, prompt):
for message in messages:
transformed_message = self.transform_message(message)

turn = {
**transformed_message,
"training": message.get(self.prompter.message_field_training),
"training_detail": message.get(
self.prompter.message_field_training_detail
),
}
turn = transformed_message

training = message.get(self.prompter.message_field_training)
training_detail = message.get(self.prompter.message_field_training_detail)
if training is not None:
turn["training"] = training
if training_detail is not None:
turn["training_detail"] = training_detail

turns.append(turn)

Expand Down Expand Up @@ -859,15 +860,6 @@ def __init__(
# TODO: address this in the future with mistral-specific checks
# self._validate_eot_and_eos_tokens()

@property
def supports_multiprocessing(self) -> bool:
"""
Whether this tokenizing strategy supports multiprocessing.
mistral_common tokenizers cannot be pickled for multiprocessing.
"""

return False

def find_first_eot_token(self, input_ids, start_idx):
"""Find the first EOT token in the input_ids starting from start_idx."""
# mistral-common tokenizer does not support eot_tokens
Expand Down
8 changes: 0 additions & 8 deletions src/axolotl/prompt_tokenizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,6 @@ def tokenize_prompt(self, prompt):
def supports_batched(self):
return False

@property
def supports_multiprocessing(self):
"""
Whether this tokenizing strategy supports multiprocessing.
Should return False if the tokenizer has unpicklable objects.
"""
return True

def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding:
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/utils/collators/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __call__(self, features, return_tensors=None):
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=return_tensors,
)
if not has_attn_mask:
if not has_attn_mask and "attention_mask" in features:
del features["attention_mask"]

# prepare decoder_input_ids
Expand Down
Loading
Loading