From abe6f42bf0e3298b9a94322d7d207b1e16492f73 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 8 Jul 2025 15:30:32 +0700 Subject: [PATCH 01/13] fix: do not add training and training_detail block by default --- src/axolotl/prompt_strategies/chat_template.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 6d2a048b2b..ec13acc530 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -681,13 +681,14 @@ def get_conversation_thread(self, prompt): for message in messages: transformed_message = self.transform_message(message) - turn = { - **transformed_message, - "training": message.get(self.prompter.message_field_training), - "training_detail": message.get( - self.prompter.message_field_training_detail - ), - } + turn = transformed_message + + training = message.get(self.prompter.message_field_training) + training_detail = message.get(self.prompter.message_field_training_detail) + if training is not None: + turn["training"] = training + if training_detail is not None: + turn["training_detail"] = training_detail turns.append(turn) From bb937c84499bb7d636f78cf6bec465bb086613c4 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 8 Jul 2025 15:31:12 +0700 Subject: [PATCH 02/13] fixed: magistral docs --- examples/magistral/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/magistral/README.md b/examples/magistral/README.md index a2b09ab700..9fc3adf83b 100644 --- a/examples/magistral/README.md +++ b/examples/magistral/README.md @@ -18,7 +18,7 @@ git clone https://github.com/axolotl-ai-cloud/axolotl.git cd axolotl pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja -pip3 install --no-build-isolation -e '.[flash-attn,mistral]' +pip3 install --no-build-isolation -e '.[flash-attn]' ``` 2. Download the example config: From 801cf749af17962f517dddc1b5dac2410a3326ee Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 8 Jul 2025 15:36:47 +0700 Subject: [PATCH 03/13] fix: address pad adding new fields and use built-in from_openai --- src/axolotl/utils/mistral_tokenizer.py | 183 ++++++++----------------- 1 file changed, 56 insertions(+), 127 deletions(-) diff --git a/src/axolotl/utils/mistral_tokenizer.py b/src/axolotl/utils/mistral_tokenizer.py index 1ba824938e..b1055bb7a1 100644 --- a/src/axolotl/utils/mistral_tokenizer.py +++ b/src/axolotl/utils/mistral_tokenizer.py @@ -3,10 +3,11 @@ import math import os from shutil import copyfile -from typing import TYPE_CHECKING, Optional +from typing import Optional import numpy as np from huggingface_hub import hf_hub_download +from mistral_common.protocol.instruct.request import ChatCompletionRequest from mistral_common.tokens.tokenizers.mistral import MistralTokenizer from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer from torch import Tensor @@ -14,9 +15,6 @@ from axolotl.utils.collators.core import IGNORE_INDEX -if TYPE_CHECKING: - from mistral_common.protocol.instruct.request import ChatCompletionRequest - def _get_file_path(path_or_repo_id: str, filename: str) -> str: """Get the file path from local or HF Hub""" @@ -259,75 +257,6 @@ def decode( token_ids, special_token_policy=SpecialTokenPolicy.KEEP ) - def _create_mistral_chat_completion_request( - self, conversation: list[dict], tools: list[dict] | None = None - ) -> "ChatCompletionRequest": - from mistral_common.protocol.instruct.messages import ( - AssistantMessage, - SystemMessage, - ToolMessage, - UserMessage, - ) - from mistral_common.protocol.instruct.request import ChatCompletionRequest - from mistral_common.protocol.instruct.tool_calls import Function, Tool - - messages: list[UserMessage | AssistantMessage | ToolMessage | SystemMessage] = ( - [] - ) - for turn in conversation: - role = turn.get("role") - - if role == "user": - messages.append(UserMessage(content=turn["content"])) - elif role == "assistant": - messages.append( - AssistantMessage( - content=turn.get("content"), - tool_calls=turn.get("tool_calls"), - ) - ) - elif role == "tool": - messages.append( - ToolMessage( - content=turn.get("content"), - tool_call_id=turn.get("tool_call_id"), - name=turn.get("name"), - ) - ) - elif role == "system": - messages.append(SystemMessage(content=turn["content"])) - else: - raise ValueError( - f"Unknown role for use with mistral-common tokenizer: {turn['role']}" - ) - - tool_calls: list[Tool] = [] - if tools: - # convert to Tool - for tool in tools: - if tool["type"] != "function": - continue - - function = tool["function"] - - tool_calls.append( - Tool( - function=Function( - name=function["name"], - description=function["description"], - # set parameters to empty dict if not provided - parameters=function.get("parameters", {}), - ) - ) - ) - - chat_completion: ChatCompletionRequest = ChatCompletionRequest( - messages=messages, - tools=tool_calls, - ) - - return chat_completion - def apply_chat_template( self, messages: list[dict], @@ -342,8 +271,8 @@ def apply_chat_template( if add_generation_prompt: raise NotImplementedError("add_generation_prompt not supported yet") - chat_completion: ChatCompletionRequest = ( - self._create_mistral_chat_completion_request(messages, tools) + chat_completion: ChatCompletionRequest = ChatCompletionRequest.from_openai( + messages, tools ) tokens: list[int] = self._mistral.encode_chat_completion(chat_completion).tokens @@ -408,13 +337,16 @@ def pad( padding_value=IGNORE_INDEX, ) - attention_mask = torch.nn.utils.rnn.pad_sequence( - [torch.tensor(x["attention_mask"], dtype=torch.long) for x in features], - batch_first=True, - padding_value=0, - ) + attention_mask = None + if "attention_mask" in features[0]: + attention_mask = torch.nn.utils.rnn.pad_sequence( + [torch.tensor(x["attention_mask"], dtype=torch.long) for x in features], + batch_first=True, + padding_value=0, + ) # Handle position_ids - pad with sequential values for right padding, 0s for left padding + position_ids = None if "position_ids" in features[0]: if self.padding_side == "left": # Likely not needed, but keeping for now @@ -443,22 +375,15 @@ def pad( pos_seq = torch.cat([pos_seq, pad_positions]) position_ids_list.append(pos_seq) position_ids = torch.stack(position_ids_list) - else: - # Create position_ids if not present - seq_len = input_ids.size(1) - position_ids = ( - torch.arange(seq_len, dtype=torch.long) - .unsqueeze(0) - .expand(input_ids.size(0), -1) - ) # Ensure all tensors have the same sequence length - max_seq_len = max( - input_ids.size(1), - labels.size(1), - attention_mask.size(1), - position_ids.size(1), - ) + # Check attention mask and position ids if they are present + tensor_lengths = [input_ids.size(1), labels.size(1)] + if attention_mask is not None: + tensor_lengths.append(attention_mask.size(1)) + if position_ids is not None: + tensor_lengths.append(position_ids.size(1)) + max_seq_len = max(tensor_lengths) # TODO: check if trimming is needed? and correct. @@ -492,44 +417,48 @@ def pad( elif labels.size(1) > max_seq_len: labels = labels[:, :max_seq_len] - if attention_mask.size(1) < max_seq_len: - pad_len = max_seq_len - attention_mask.size(1) - if self.padding_side == "right": - attention_mask = F.pad(attention_mask, (0, pad_len), value=0) - else: - attention_mask = F.pad(attention_mask, (pad_len, 0), value=0) - elif attention_mask.size(1) > max_seq_len: - attention_mask = attention_mask[:, :max_seq_len] - - if position_ids.size(1) < max_seq_len: - pad_len = max_seq_len - position_ids.size(1) - if self.padding_side == "right": - batch_size = position_ids.size(0) - new_position_ids = [] - for i in range(batch_size): - seq = position_ids[i] - if len(seq) > 0: - # get last position and pad with sequential values - last_pos = seq[-1].item() - pad_positions = torch.arange( - last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long - ) - new_seq = torch.cat([seq, pad_positions]) - else: - new_seq = torch.arange(pad_len, dtype=torch.long) - new_position_ids.append(new_seq) - position_ids = torch.stack(new_position_ids) - else: - position_ids = F.pad(position_ids, (pad_len, 0), value=0) - elif position_ids.size(1) > max_seq_len: - position_ids = position_ids[:, :max_seq_len] + if attention_mask is not None: + if attention_mask.size(1) < max_seq_len: + pad_len = max_seq_len - attention_mask.size(1) + if self.padding_side == "right": + attention_mask = F.pad(attention_mask, (0, pad_len), value=0) + else: + attention_mask = F.pad(attention_mask, (pad_len, 0), value=0) + elif attention_mask.size(1) > max_seq_len: + attention_mask = attention_mask[:, :max_seq_len] + + if position_ids is not None: + if position_ids.size(1) < max_seq_len: + pad_len = max_seq_len - position_ids.size(1) + if self.padding_side == "right": + batch_size = position_ids.size(0) + new_position_ids = [] + for i in range(batch_size): + seq = position_ids[i] + if len(seq) > 0: + # get last position and pad with sequential values + last_pos = seq[-1].item() + pad_positions = torch.arange( + last_pos + 1, last_pos + 1 + pad_len, dtype=torch.long + ) + new_seq = torch.cat([seq, pad_positions]) + else: + new_seq = torch.arange(pad_len, dtype=torch.long) + new_position_ids.append(new_seq) + position_ids = torch.stack(new_position_ids) + else: + position_ids = F.pad(position_ids, (pad_len, 0), value=0) + elif position_ids.size(1) > max_seq_len: + position_ids = position_ids[:, :max_seq_len] final_batch = { "input_ids": input_ids, "labels": labels, - "attention_mask": attention_mask, - "position_ids": position_ids, } + if attention_mask is not None: + final_batch["attention_mask"] = attention_mask + if position_ids is not None: + final_batch["position_ids"] = position_ids # Handle non-sequence fields (raise error) sequence_fields = {"input_ids", "labels", "attention_mask", "position_ids"} From f9851d06739dcd3934585a8de24a7333f5d8ee98 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 8 Jul 2025 15:37:02 +0700 Subject: [PATCH 04/13] feat: try enable multiprocessing --- src/axolotl/prompt_strategies/chat_template.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index ec13acc530..831a6e9bd1 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -860,14 +860,14 @@ def __init__( # TODO: address this in the future with mistral-specific checks # self._validate_eot_and_eos_tokens() - @property - def supports_multiprocessing(self) -> bool: - """ - Whether this tokenizing strategy supports multiprocessing. - mistral_common tokenizers cannot be pickled for multiprocessing. - """ - - return False + # @property + # def supports_multiprocessing(self) -> bool: + # """ + # Whether this tokenizing strategy supports multiprocessing. + # mistral_common tokenizers cannot be pickled for multiprocessing. + # """ + + # return False def find_first_eot_token(self, input_ids, start_idx): """Find the first EOT token in the input_ids starting from start_idx.""" From 51cc7d5b6d8bff78ead7fb9a2d2fadc1babe693a Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 8 Jul 2025 15:51:08 +0700 Subject: [PATCH 05/13] fix: check for keys before deleting attn_mask --- src/axolotl/utils/collators/batching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index a28f360be3..25a871b2ba 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -108,7 +108,7 @@ def __call__(self, features, return_tensors=None): pad_to_multiple_of=self.pad_to_multiple_of, return_tensors=return_tensors, ) - if not has_attn_mask: + if not has_attn_mask and "attention_mask" in features: del features["attention_mask"] # prepare decoder_input_ids From 151709ddb1b99ee52f2d5ff36df946bcc7d0b223 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 8 Jul 2025 16:11:11 +0700 Subject: [PATCH 06/13] feat: add mistral pad test --- src/axolotl/utils/mistral_tokenizer.py | 2 +- .../test_chat_templates_mistral.py | 136 ++++++++++++++++++ 2 files changed, 137 insertions(+), 1 deletion(-) diff --git a/src/axolotl/utils/mistral_tokenizer.py b/src/axolotl/utils/mistral_tokenizer.py index b1055bb7a1..95c87a8226 100644 --- a/src/axolotl/utils/mistral_tokenizer.py +++ b/src/axolotl/utils/mistral_tokenizer.py @@ -474,7 +474,7 @@ def pad( result = {} for k, v in final_batch.items(): if isinstance(v, torch.Tensor): - result[k] = v.numpy().astype(np.long) + result[k] = v.numpy().astype(np.int64) else: result[k] = v return result diff --git a/tests/prompt_strategies/test_chat_templates_mistral.py b/tests/prompt_strategies/test_chat_templates_mistral.py index 3c60a15c27..3567cf58d9 100644 --- a/tests/prompt_strategies/test_chat_templates_mistral.py +++ b/tests/prompt_strategies/test_chat_templates_mistral.py @@ -286,5 +286,141 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"): assert res == ["Hello", ",", " how", " are", " you", "?"] +def test_mistral_tokenizer_pad_method(magistral_tokenizer: "HFMistralTokenizer"): + """Test the pad method with various field combinations.""" + from axolotl.utils.collators.core import IGNORE_INDEX + + magistral_pad_token_id = 11 # taken from tokenizer.pad_token_id + + # Test padding with input_ids and labels only + features = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6]}, + {"input_ids": [7, 8], "labels": [9, 10]}, + ] + + result = magistral_tokenizer.pad(features, padding=True, return_tensors="pt") + + # Check that input_ids are padded correctly + assert result["input_ids"].shape == (2, 3) + assert result["input_ids"].tolist() == [[1, 2, 3], [7, 8, magistral_pad_token_id]] + + # Check that labels are padded correctly + assert result["labels"].shape == (2, 3) + assert result["labels"].tolist() == [[4, 5, 6], [9, 10, IGNORE_INDEX]] + + # Check that attention_mask and position_ids are NOT created + assert "attention_mask" not in result + assert "position_ids" not in result + + # Test padding with attention_mask + features_with_attention = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "attention_mask": [1, 1, 1]}, + {"input_ids": [7, 8], "labels": [9, 10], "attention_mask": [1, 1]}, + ] + + result = magistral_tokenizer.pad( + features_with_attention, padding=True, return_tensors="pt" + ) + + # Check that attention_mask is padded correctly + assert result["attention_mask"].shape == (2, 3) + assert result["attention_mask"].tolist() == [[1, 1, 1], [1, 1, 0]] + + # Test padding with position_ids + features_with_position = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "position_ids": [0, 1, 2]}, + {"input_ids": [7, 8], "labels": [9, 10], "position_ids": [0, 1]}, + ] + + result = magistral_tokenizer.pad( + features_with_position, padding=True, return_tensors="pt" + ) + + # Check that position_ids are padded correctly (continuing sequence) + assert result["position_ids"].shape == (2, 3) + assert result["position_ids"].tolist() == [[0, 1, 2], [0, 1, 2]] + + # Test padding with all fields + features_all = [ + { + "input_ids": [1, 2, 3], + "labels": [4, 5, 6], + "attention_mask": [1, 1, 1], + "position_ids": [0, 1, 2], + }, + { + "input_ids": [7, 8], + "labels": [9, 10], + "attention_mask": [1, 1], + "position_ids": [0, 1], + }, + ] + + result = magistral_tokenizer.pad(features_all, padding=True, return_tensors="pt") + + # All fields should be present and correctly padded + assert "input_ids" in result + assert "labels" in result + assert "attention_mask" in result + assert "position_ids" in result + + # Test padding with all sequences same length + features_same_length = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6]}, + {"input_ids": [7, 8, 9], "labels": [10, 11, 12]}, + ] + + result = magistral_tokenizer.pad( + features_same_length, padding=True, return_tensors="pt" + ) + + # Check match when no padding is needed + assert result["input_ids"][0].tolist() == features_same_length[0]["input_ids"] + assert result["labels"][0].tolist() == features_same_length[0]["labels"] + + assert result["input_ids"][1].tolist() == features_same_length[1]["input_ids"] + assert result["labels"][1].tolist() == features_same_length[1]["labels"] + + # Test padding with max_length parameter + result = magistral_tokenizer.pad( + features, padding="max_length", max_length=5, return_tensors="pt" + ) + + # Should pad to max_length + assert result["input_ids"].shape == (2, 5) + assert result["labels"].shape == (2, 5) + + # Test numpy return type + result = magistral_tokenizer.pad(features, padding=True, return_tensors="np") + + # Should return numpy arrays + import numpy as np + + assert isinstance(result["input_ids"], np.ndarray) + assert isinstance(result["labels"], np.ndarray) + + # Test unsupported field rejection + features_unsupported = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "unsupported_field": [7, 8, 9]}, + ] + + try: + magistral_tokenizer.pad(features_unsupported, padding=True, return_tensors="pt") + assert False, "Should have raised NotImplementedError" + except NotImplementedError as e: + assert "unsupported_field" in str(e) + + # Test token_type_ids rejection + features_token_type = [ + {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "token_type_ids": [0, 0, 0]}, + ] + + try: + magistral_tokenizer.pad(features_token_type, padding=True, return_tensors="pt") + assert False, "Should have raised ValueError" + except ValueError as e: + assert "token_type_ids is not supported" in str(e) + + if __name__ == "__main__": unittest.main() From 0e6dac41450fa928096536db78fc171fceb96f0c Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 8 Jul 2025 17:47:03 +0700 Subject: [PATCH 07/13] feat: add tool calling test --- .../test_chat_templates_mistral.py | 312 ++++++++++++++++++ 1 file changed, 312 insertions(+) diff --git a/tests/prompt_strategies/test_chat_templates_mistral.py b/tests/prompt_strategies/test_chat_templates_mistral.py index 3567cf58d9..7e01070b40 100644 --- a/tests/prompt_strategies/test_chat_templates_mistral.py +++ b/tests/prompt_strategies/test_chat_templates_mistral.py @@ -422,5 +422,317 @@ def test_mistral_tokenizer_pad_method(magistral_tokenizer: "HFMistralTokenizer") assert "token_type_ids is not supported" in str(e) +def test_mistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"): + """Test comprehensive tool calling scenarios with the Mistral tokenizer.""" + from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy + + strategy = MistralStrategy( + MistralPrompter( + magistral_tokenizer, + chat_template=None, + message_property_mappings={"role": "role", "content": "content"}, + ), + tokenizer=magistral_tokenizer, + train_on_inputs=False, + train_on_eos="turn", + sequence_len=512, + roles_to_train=["assistant"], + ) + + # Test basic tool calling with single function + basic_tool_calling = { + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + }, + "required": ["location"], + }, + }, + }, + ], + "messages": [ + { + "role": "user", + "content": "What's the weather like in San Francisco?", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call12345", + "type": "function", + "function": { + "name": "get_weather", + "arguments": { + "location": "San Francisco, CA", + }, + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call12345", + "name": "get_weather", + "content": "Sunny, 72°F", + }, + { + "role": "assistant", + "content": "The weather in San Francisco is sunny and 72°F.", + }, + ], + } + + res = strategy.tokenize_prompt(basic_tool_calling) + + # Basic validation + assert "input_ids" in res + assert "labels" in res + assert len(res["input_ids"]) > 0 + assert len(res["labels"]) == len(res["input_ids"]) + + # Decode and verify structure + decoded = magistral_tokenizer.decode(res["input_ids"]) + assert ( + '[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}}, "required": ["location"]}}}][/AVAILABLE_TOOLS]' + in decoded + ) + assert ( + '[TOOL_CALLS]get_weather[CALL_ID]call12345[ARGS]{"location": "San Francisco, CA"}' + in decoded + ) + assert "[TOOL_RESULTS]call12345[TOOL_CONTENT]Sunny, 72°F[/TOOL_RESULTS]" in decoded + assert "The weather in San Francisco is sunny and 72°F." in decoded + + # Test multiple tool calls in sequence + multi_tool_calling = { + "tools": [ + { + "type": "function", + "function": { + "name": "add_numbers", + "description": "Add two numbers together", + "parameters": { + "type": "object", + "properties": { + "a": {"type": "number", "description": "First number"}, + "b": {"type": "number", "description": "Second number"}, + }, + "required": ["a", "b"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "multiply_numbers", + "description": "Multiply two numbers", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "number", "description": "First number"}, + "y": {"type": "number", "description": "Second number"}, + }, + "required": ["x", "y"], + }, + }, + }, + ], + "messages": [ + { + "role": "user", + "content": "Add 5 and 3, then multiply the result by 2", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call12345", + "type": "function", + "function": { + "name": "add_numbers", + "arguments": {"a": 5, "b": 3}, + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call12345", + "name": "add_numbers", + "content": "8", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call23456", + "type": "function", + "function": { + "name": "multiply_numbers", + "arguments": {"x": 8, "y": 2}, + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call23456", + "name": "multiply_numbers", + "content": "16", + }, + { + "role": "assistant", + "content": "The result is 16. I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16.", + }, + ], + } + + res = strategy.tokenize_prompt(multi_tool_calling) + + # Validation + assert len(res["input_ids"]) > 0 + assert len(res["labels"]) == len(res["input_ids"]) + + decoded = magistral_tokenizer.decode(res["input_ids"]) + assert ( + '[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "add_numbers", "description": "Add two numbers together", "parameters": {"type": "object", "properties": {"a": {"type": "number", "description": "First number"}, "b": {"type": "number", "description": "Second number"}}, "required": ["a", "b"]}}}, {"type": "function", "function": {"name": "multiply_numbers", "description": "Multiply two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "First number"}, "y": {"type": "number", "description": "Second number"}}, "required": ["x", "y"]}}}][/AVAILABLE_TOOLS]' + in decoded + ) + assert ( + '[TOOL_CALLS]add_numbers[CALL_ID]call12345[ARGS]{"a": 5, "b": 3}' in decoded + ) + assert "[TOOL_RESULTS]call12345[TOOL_CONTENT]8[/TOOL_RESULTS]" in decoded + assert ( + '[TOOL_CALLS]multiply_numbers[CALL_ID]call23456[ARGS]{"x": 8, "y": 2}' + in decoded + ) + assert "[TOOL_RESULTS]call23456[TOOL_CONTENT]16[/TOOL_RESULTS]" in decoded + assert ( + "The result is 16. I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16." + in decoded + ) + + # Test tool calling with system message + system_tool_calling = { + "tools": [ + { + "type": "function", + "function": { + "name": "search_database", + "description": "Search for information in database", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + }, + "required": ["query"], + }, + }, + }, + ], + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant with access to a database.", + }, + { + "role": "user", + "content": "Find information about Python programming", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "search123", + "type": "function", + "function": { + "name": "search_database", + "arguments": {"query": "Python programming"}, + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "search123", + "name": "search_database", + "content": "Python is a high-level programming language known for its simplicity.", + }, + { + "role": "assistant", + "content": "Based on the database search, Python is a high-level programming language known for its simplicity and readability.", + }, + ], + } + + res = strategy.tokenize_prompt(system_tool_calling) + + # Validation + assert len(res["input_ids"]) > 0 + assert len(res["labels"]) == len(res["input_ids"]) + + decoded = magistral_tokenizer.decode(res["input_ids"]) + + assert ( + '[SYSTEM_PROMPT]You are a helpful assistant with access to a database.[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "search_database", "description": "Search for information in database", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}}, "required": ["query"]}}}][/AVAILABLE_TOOLS]' + in decoded + ) + + # Test error handling - missing tool response + incomplete_tool_calling = { + "tools": [ + { + "type": "function", + "function": { + "name": "get_time", + "description": "Get current time", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ], + "messages": [ + { + "role": "user", + "content": "What time is it?", + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "time12345", + "type": "function", + "function": { + "name": "get_time", + "arguments": {}, + }, + } + ], + }, + { + "role": "assistant", + "content": "The current time is 12:00 PM.", + }, + ], + } + + from mistral_common.exceptions import InvalidMessageStructureException + + try: + strategy.tokenize_prompt(incomplete_tool_calling) + except InvalidMessageStructureException as e: + assert "Not the same number of function calls and responses" in str(e) + + if __name__ == "__main__": unittest.main() From 1a3c96a8c0b0f93b12a88d260d270f0fbf689354 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 8 Jul 2025 18:48:49 +0700 Subject: [PATCH 08/13] feat: add devstral tokenizer tests --- tests/prompt_strategies/conftest.py | 8 +++ .../test_chat_templates_mistral.py | 64 ++++++++++++------- 2 files changed, 49 insertions(+), 23 deletions(-) diff --git a/tests/prompt_strategies/conftest.py b/tests/prompt_strategies/conftest.py index d440565d2e..60b14d6523 100644 --- a/tests/prompt_strategies/conftest.py +++ b/tests/prompt_strategies/conftest.py @@ -164,6 +164,14 @@ def fixture_magistral_tokenizer(): return tokenizer +@pytest.fixture(name="devstral_tokenizer") +def fixture_devstral_tokenizer(): + from axolotl.utils.mistral_tokenizer import HFMistralTokenizer + + tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Devstral-Small-2505") + return tokenizer + + @pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja") def fixture_mistralv03_chat_template_jinja_w_system() -> str: return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n' diff --git a/tests/prompt_strategies/test_chat_templates_mistral.py b/tests/prompt_strategies/test_chat_templates_mistral.py index 7e01070b40..31d81dd8d7 100644 --- a/tests/prompt_strategies/test_chat_templates_mistral.py +++ b/tests/prompt_strategies/test_chat_templates_mistral.py @@ -3,32 +3,50 @@ import unittest from typing import TYPE_CHECKING +import pytest + if TYPE_CHECKING: from axolotl.utils.mistral_tokenizer import HFMistralTokenizer -def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"): +# fmt: off +@pytest.mark.parametrize( + ("tokenizer_str", "assistant_toolcall_ids"), + ( + ("magistral_tokenizer", (9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2)), + ("devstral_tokenizer", (9, 1091, 19227, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 61906, 2811, 16753, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 4179, 1429, 1327, 2811, 1429, 19881, 1049, 1050, 1051, 1052, 1053, 1034, 27028, 2)) + ) +) +# fmt: on +def test_mistral_chat_template( + tokenizer_str: str, + assistant_toolcall_ids: tuple[int, ...], + request: pytest.FixtureRequest, +): + """Test chat template with the Magistral/Devstral tokenizer""" # pylint: disable=duplicate-code from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy + tokenizer: HFMistralTokenizer = request.getfixturevalue(tokenizer_str) + # check bos, eos, pad, unk are accessible properties - assert magistral_tokenizer.bos_token_id == 1 - assert magistral_tokenizer.eos_token_id == 2 - assert magistral_tokenizer.pad_token_id == 11 - assert magistral_tokenizer.unk_token_id == 0 + assert tokenizer.bos_token_id == 1 + assert tokenizer.eos_token_id == 2 + assert tokenizer.pad_token_id == 11 + assert tokenizer.unk_token_id == 0 - assert magistral_tokenizer.pad_token == "" - assert magistral_tokenizer.eos_token == "" - assert magistral_tokenizer.bos_token == "" - assert magistral_tokenizer.unk_token == "" + assert tokenizer.pad_token == "" + assert tokenizer.eos_token == "" + assert tokenizer.bos_token == "" + assert tokenizer.unk_token == "" strategy = MistralStrategy( MistralPrompter( - magistral_tokenizer, + tokenizer, chat_template=None, message_property_mappings={"role": "role", "content": "content"}, ), - tokenizer=magistral_tokenizer, + tokenizer=tokenizer, train_on_inputs=False, train_on_eos="turn", sequence_len=512, @@ -219,7 +237,7 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"): 1, # bos 5, 1091, 19227, 4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 14653, 2811, 1429, 10639, 2130, 1261, 2951, 1307, 1747, 1278, 60092, 1307, 1261, 2782, 1455, 1584, 4289, 2224, 1261, 4265, 6139, 39249, 1429, 26204, 2811, 16753, 4994, 2811, 1429, 6371, 1897, 1429, 48649, 2811, 16753, 12856, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 2782, 1317, 3081, 60092, 1307, 2613, 4179, 1429, 33319, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 9229, 6139, 1394, 1278, 60092, 2613, 47579, 1429, 15760, 2811, 12161, 12856, 1897, 1429, 33319, 4964, 2821, 27028, 6, # tool prompt 3, 46634, 1044, 1710, 1636, 5628, 1639, 1261, 44433, 1307, 2606, 1317, 5388, 1420, 54191, 2424, 1286, 8967, 1063, 15621, 1044, 2549, 30305, 2196, 3560, 1044, 1321, 2606, 1710, 1362, 2016, 8605, 2015, 1317, 5524, 118931, 2036, 32951, 1063, 1362, 2933, 2269, 12106, 1408, 101987, 1044, 6939, 1044, 1321, 9216, 1455, 2084, 3180, 1278, 8967, 119141, 1689, 5935, 1033, 4, # user - 9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2, # assistant tool calling + *assistant_toolcall_ids, # assistant tool calling 7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8, # tool result 1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant 2 # eos @@ -229,7 +247,7 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"): -100, # bos -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool prompt -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # user prompt - 9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2, # assistant tool calling + *assistant_toolcall_ids, # assistant tool calling -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool result 1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant 2 # eos @@ -237,7 +255,7 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"): # fmt: on # test chat template with tokenize=False - res = magistral_tokenizer.apply_chat_template( + res = tokenizer.apply_chat_template( [ {"role": "user", "content": "Hello, how are you?"}, {"role": "assistant", "content": "I'm doing great, thank you!"}, @@ -248,7 +266,7 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"): assert res == "[INST]Hello, how are you?[/INST]I'm doing great, thank you!" # test encode - res = magistral_tokenizer.encode("Hello, how are you?", add_special_tokens=True) + res = tokenizer.encode("Hello, how are you?", add_special_tokens=True) assert res == [ 1, # bos 22177, # Hello @@ -261,16 +279,16 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"): ] # test decode no skip special tokens - decoded_res = magistral_tokenizer.decode(res, skip_special_tokens=False) + decoded_res = tokenizer.decode(res, skip_special_tokens=False) assert decoded_res == "Hello, how are you?" # test decode skip special tokens - decoded_res = magistral_tokenizer.decode(res, skip_special_tokens=True) + decoded_res = tokenizer.decode(res, skip_special_tokens=True) assert decoded_res == "Hello, how are you?" # test encode no special tokens - res = magistral_tokenizer.encode("Hello, how are you?", add_special_tokens=False) + res = tokenizer.encode("Hello, how are you?", add_special_tokens=False) assert res == [ 22177, # Hello 1044, # , @@ -281,13 +299,13 @@ def test_magistral_chat_template(magistral_tokenizer: "HFMistralTokenizer"): ] # test convert ids to tokens - res = magistral_tokenizer.convert_ids_to_tokens(res) + res = tokenizer.convert_ids_to_tokens(res) # spacing are needed as we are converting without decoding assert res == ["Hello", ",", " how", " are", " you", "?"] -def test_mistral_tokenizer_pad_method(magistral_tokenizer: "HFMistralTokenizer"): - """Test the pad method with various field combinations.""" +def test_magistral_tokenizer_pad_method(magistral_tokenizer: "HFMistralTokenizer"): + """Test the MistralTokenizer pad method""" from axolotl.utils.collators.core import IGNORE_INDEX magistral_pad_token_id = 11 # taken from tokenizer.pad_token_id @@ -422,8 +440,8 @@ def test_mistral_tokenizer_pad_method(magistral_tokenizer: "HFMistralTokenizer") assert "token_type_ids is not supported" in str(e) -def test_mistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"): - """Test comprehensive tool calling scenarios with the Mistral tokenizer.""" +def test_magistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"): + """Test tool calling with the Magistral tokenizer""" from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy strategy = MistralStrategy( From b5c0cbfd684a00aff552aad82f24fe371e12e241 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 8 Jul 2025 18:51:56 +0700 Subject: [PATCH 09/13] fix: comma format --- tests/prompt_strategies/test_chat_templates_mistral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/prompt_strategies/test_chat_templates_mistral.py b/tests/prompt_strategies/test_chat_templates_mistral.py index 31d81dd8d7..21848e6ea2 100644 --- a/tests/prompt_strategies/test_chat_templates_mistral.py +++ b/tests/prompt_strategies/test_chat_templates_mistral.py @@ -14,7 +14,7 @@ ("tokenizer_str", "assistant_toolcall_ids"), ( ("magistral_tokenizer", (9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2)), - ("devstral_tokenizer", (9, 1091, 19227, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 61906, 2811, 16753, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 4179, 1429, 1327, 2811, 1429, 19881, 1049, 1050, 1051, 1052, 1053, 1034, 27028, 2)) + ("devstral_tokenizer", (9, 1091, 19227, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 61906, 2811, 16753, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 4179, 1429, 1327, 2811, 1429, 19881, 1049, 1050, 1051, 1052, 1053, 1034, 27028, 2)), ) ) # fmt: on From c817e8409475310d6a07d529b6bf4148a86e14f2 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 8 Jul 2025 20:45:42 +0700 Subject: [PATCH 10/13] chore: remove unused support_preprocessing as tokenizer is pickable now --- src/axolotl/datasets.py | 7 ------- src/axolotl/prompt_strategies/chat_template.py | 9 --------- src/axolotl/prompt_tokenizers.py | 8 -------- 3 files changed, 24 deletions(-) diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 28182b16f5..7c112c59e7 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -48,13 +48,6 @@ def process(self, dataset): features = dataset.features.keys() num_proc = min(64, self.process_count if self.process_count else os.cpu_count()) - # Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common) - if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True): - LOG.info( - "Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)" - ) - num_proc = 1 - map_kwargs = {} if self.prompt_tokenizer.supports_batched: map_kwargs["batched"] = True diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 831a6e9bd1..a9d26a650e 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -860,15 +860,6 @@ def __init__( # TODO: address this in the future with mistral-specific checks # self._validate_eot_and_eos_tokens() - # @property - # def supports_multiprocessing(self) -> bool: - # """ - # Whether this tokenizing strategy supports multiprocessing. - # mistral_common tokenizers cannot be pickled for multiprocessing. - # """ - - # return False - def find_first_eot_token(self, input_ids, start_idx): """Find the first EOT token in the input_ids starting from start_idx.""" # mistral-common tokenizer does not support eot_tokens diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index aae778ae8a..9ca645de3c 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -70,14 +70,6 @@ def tokenize_prompt(self, prompt): def supports_batched(self): return False - @property - def supports_multiprocessing(self): - """ - Whether this tokenizing strategy supports multiprocessing. - Should return False if the tokenizer has unpicklable objects. - """ - return True - def _tokenize( self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False ) -> BatchEncoding: From d04cafb94b23c1b83b8b85012dc606ad0a6fdcc1 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 8 Jul 2025 20:50:52 +0700 Subject: [PATCH 11/13] chore: update magistral doc --- examples/magistral/README.md | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/examples/magistral/README.md b/examples/magistral/README.md index 9fc3adf83b..0c39c061b2 100644 --- a/examples/magistral/README.md +++ b/examples/magistral/README.md @@ -21,13 +21,7 @@ pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja pip3 install --no-build-isolation -e '.[flash-attn]' ``` -2. Download the example config: - -```bash -axolotl fetch examples -``` - -3. Run the finetuning example: +2. Run the finetuning example: ```bash axolotl train examples/magistral/magistral-small-qlora.yaml @@ -42,7 +36,7 @@ Let us know how it goes. Happy finetuning! 🚀 - For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`. - You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. - Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). -- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). ## Optimization Guides @@ -54,7 +48,7 @@ Let us know how it goes. Happy finetuning! 🚀 We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only. -The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet. +In addition, we do not support overriding tokens yet. ## Related Resources From d88afa860d300431087f4ef662dca65ec4a5525b Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 8 Jul 2025 20:52:09 +0700 Subject: [PATCH 12/13] feat: add devstral readme and example --- examples/devstral/README.md | 69 ++++++++++++++++++++++ examples/devstral/devstral-small-qlora.yml | 64 ++++++++++++++++++++ 2 files changed, 133 insertions(+) create mode 100644 examples/devstral/README.md create mode 100644 examples/devstral/devstral-small-qlora.yml diff --git a/examples/devstral/README.md b/examples/devstral/README.md new file mode 100644 index 0000000000..9dc5377bc6 --- /dev/null +++ b/examples/devstral/README.md @@ -0,0 +1,69 @@ +# Finetune Devstral with Axolotl + +Devstral Small is a 24B parameter opensource model from MistralAI found on HuggingFace [Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505). This guide shows how to fine-tune it with Axolotl with multi-turn conversations with proper masking. + +The model was fine-tuned ontop of [Mistral-Small-3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503) without the vision layer and has a context of upto 128k tokens. + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Devstral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html). + + Here is an example of how to install from main for pip: + +```bash +# Ensure you have Pytorch installed (Pytorch 2.6.0+) +git clone https://github.com/axolotl-ai-cloud/axolotl.git +cd axolotl + +pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja +pip3 install --no-build-isolation -e '.[flash-attn]' + +# Install the latest mistral-common from source +pip3 uninstall mistral-common +pip3 install git+https://github.com/mistralai/mistral-common.git@039465d + +``` + +2. Run the finetuning example: + +```bash +axolotl train examples/devstral/devstral-small-qlora.yml +``` + +This config uses about 21GB VRAM. + +Let us know how it goes. Happy finetuning! 🚀 + +### TIPS + +- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config. +- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). + +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) +- [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) +- [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels) + +## Limitations + +We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only. + +In addition, we do not support overriding tokens yet. + +## Related Resources + +- [MistralAI Devstral Blog](https://mistral.ai/news/devstral) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) + + +## Future Work + +- Add parity to Preference Tuning, RL, Multi-modal, etc. +- Add parity to other tokenizer configs like overriding tokens. diff --git a/examples/devstral/devstral-small-qlora.yml b/examples/devstral/devstral-small-qlora.yml new file mode 100644 index 0000000000..d2c5930e3c --- /dev/null +++ b/examples/devstral/devstral-small-qlora.yml @@ -0,0 +1,64 @@ +base_model: mistralai/Devstral-Small-2505 + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +# Enable to use mistral-common tokenizer +tokenizer_use_mistral_common: true + +load_in_8bit: false +load_in_4bit: true + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/qlora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true +pad_to_sequence_len: true + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0 +lora_target_linear: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +loss_watchdog_threshold: 5.0 +loss_watchdog_patience: 3 + +warmup_ratio: 0.05 +evals_per_epoch: 4 +saves_per_epoch: 1 + +weight_decay: 0.0 +special_tokens: From 87e99d60509756bb04fd9e5c75bdc849115d3df8 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 8 Jul 2025 21:08:25 +0700 Subject: [PATCH 13/13] chore: refactor error handling --- tests/prompt_strategies/test_chat_templates_mistral.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/prompt_strategies/test_chat_templates_mistral.py b/tests/prompt_strategies/test_chat_templates_mistral.py index 21848e6ea2..dcf5138d38 100644 --- a/tests/prompt_strategies/test_chat_templates_mistral.py +++ b/tests/prompt_strategies/test_chat_templates_mistral.py @@ -422,22 +422,16 @@ def test_magistral_tokenizer_pad_method(magistral_tokenizer: "HFMistralTokenizer {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "unsupported_field": [7, 8, 9]}, ] - try: + with pytest.raises(NotImplementedError, match="unsupported_field"): magistral_tokenizer.pad(features_unsupported, padding=True, return_tensors="pt") - assert False, "Should have raised NotImplementedError" - except NotImplementedError as e: - assert "unsupported_field" in str(e) # Test token_type_ids rejection features_token_type = [ {"input_ids": [1, 2, 3], "labels": [4, 5, 6], "token_type_ids": [0, 0, 0]}, ] - try: + with pytest.raises(ValueError, match="token_type_ids is not supported"): magistral_tokenizer.pad(features_token_type, padding=True, return_tensors="pt") - assert False, "Should have raised ValueError" - except ValueError as e: - assert "token_type_ids is not supported" in str(e) def test_magistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"):