Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/axolotl/loaders/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
if cfg.processor_type:
processor_cls = getattr(transformers, cfg.processor_type)

# Build common kwargs for processor loading
processor_kwargs = {}
if cfg.revision_of_model:
processor_kwargs["revision"] = cfg.revision_of_model

if cfg.tokenizer_use_mistral_common:

def _patch_mistralcommontokenizer():
Expand All @@ -40,6 +45,7 @@ def _patch_mistralcommontokenizer():
if processor_cls == VoxtralProcessor:
return VoxtralProcessor.from_pretrained(
cfg.processor_config,
**processor_kwargs,
)

from axolotl.utils.mistral import Mistral3Processor
Expand All @@ -48,10 +54,12 @@ def _patch_mistralcommontokenizer():
tokenizer=tokenizer,
)

processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
processor_kwargs["tokenizer"] = tokenizer

processor = processor_cls.from_pretrained(
cfg.processor_config,
trust_remote_code=cfg.trust_remote_code or False,
tokenizer=tokenizer,
**processor_kwargs,
)

# Attempt to load image size from processor if available
Expand Down
22 changes: 18 additions & 4 deletions src/axolotl/loaders/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@


def modify_tokenizer_files(
tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
tokenizer_path: str,
token_mappings: dict[int, str],
output_dir: str,
revision: str = "main",
) -> str:
"""
Modify tokenizer files to replace added_tokens strings, save to output directory,
Expand All @@ -41,6 +44,7 @@ def modify_tokenizer_files(
tokenizer_path: Path or name of the original tokenizer
token_mappings: Dict mapping {token_id (int): new_token_string}
output_dir: Directory to save the modified tokenizer
revision: Model revision/branch/tag/commit to load from (HF Hub)

Returns:
Path to the modified tokenizer directory
Expand All @@ -53,7 +57,9 @@ def modify_tokenizer_files(

if is_local_main_process():
# Load the tokenizer
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
temp_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, use_fast=True, revision=revision
)

# Save the tokenizer to the output directory
temp_tokenizer.save_pretrained(tokenizer_dir)
Expand Down Expand Up @@ -134,7 +140,10 @@ def _load_mistral_common_tokenizer(cfg: DictDefault):
from axolotl.utils.mistral import HFMistralTokenizer

# Load the HF-compatible wrapper around MistralTokenizer
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
kwargs = {}
if cfg.revision_of_model:
kwargs["revision"] = cfg.revision_of_model
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config, **kwargs)

return tokenizer

Expand All @@ -150,6 +159,8 @@ def _load_mistral_common_tokenizer(cfg: DictDefault):
if cfg.tokenizer_legacy is not None:
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
if cfg.revision_of_model:
tokenizer_kwargs["revision"] = cfg.revision_of_model

tokenizer_cls = AutoTokenizer
if cfg.tokenizer_type:
Expand All @@ -161,8 +172,11 @@ def _load_mistral_common_tokenizer(cfg: DictDefault):
# Apply token string overrides if specified
if cfg.added_tokens_overrides:
# Modify tokenizer files and get path to modified tokenizer
modify_kwargs = {"output_dir": cfg.output_dir}
if cfg.revision_of_model:
modify_kwargs["revision"] = cfg.revision_of_model
tokenizer_path = modify_tokenizer_files(
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
tokenizer_path, cfg.added_tokens_overrides, **modify_kwargs
)

tokenizer = tokenizer_cls.from_pretrained(
Expand Down
135 changes: 135 additions & 0 deletions tests/test_revision_parameter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""Tests for revision_of_model being passed to tokenizer and processor loaders."""

from unittest.mock import MagicMock, patch

from transformers import PreTrainedTokenizerBase

from axolotl.utils.dict import DictDefault


class TestRevisionParameter:
    """Verify that ``revision_of_model`` reaches the tokenizer and processor loaders."""

    @patch("axolotl.loaders.tokenizer.load_model_config")
    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
    @patch(
        "axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches"
    )
    def test_load_tokenizer_passes_revision(
        self, _mock_patches, mock_auto_tokenizer, _mock_load_config
    ):
        """A configured revision appears in the ``from_pretrained`` keyword args."""
        from axolotl.loaders.tokenizer import load_tokenizer

        fake_tokenizer = MagicMock()
        fake_tokenizer.__class__.__name__ = "MockTokenizer"
        mock_auto_tokenizer.from_pretrained.return_value = fake_tokenizer

        load_tokenizer(
            DictDefault(
                {
                    "tokenizer_config": "some-model",
                    "revision_of_model": "abc123",
                }
            )
        )

        # Only the most recent from_pretrained call is inspected.
        recorded = mock_auto_tokenizer.from_pretrained.call_args
        assert recorded.kwargs.get("revision") == "abc123"

    @patch("axolotl.loaders.tokenizer.load_model_config")
    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
    @patch(
        "axolotl.loaders.patch_manager.PatchManager.apply_pre_tokenizer_load_patches"
    )
    def test_load_tokenizer_omits_revision_when_unset(
        self, _mock_patches, mock_auto_tokenizer, _mock_load_config
    ):
        """Without ``revision_of_model`` no ``revision`` keyword is passed at all."""
        from axolotl.loaders.tokenizer import load_tokenizer

        fake_tokenizer = MagicMock()
        fake_tokenizer.__class__.__name__ = "MockTokenizer"
        mock_auto_tokenizer.from_pretrained.return_value = fake_tokenizer

        load_tokenizer(
            DictDefault(
                {
                    "tokenizer_config": "some-model",
                }
            )
        )

        recorded = mock_auto_tokenizer.from_pretrained.call_args
        assert "revision" not in recorded.kwargs

    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
    @patch("axolotl.loaders.tokenizer.is_local_main_process", return_value=True)
    @patch("axolotl.loaders.tokenizer.barrier")
    def test_modify_tokenizer_files_passes_revision(
        self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir
    ):
        """An explicit ``revision`` argument is used when loading the tokenizer."""
        from axolotl.loaders.tokenizer import modify_tokenizer_files

        mock_auto_tokenizer.from_pretrained.return_value = MagicMock()

        modify_tokenizer_files("some-model", {}, output_dir=temp_dir, revision="abc123")

        recorded = mock_auto_tokenizer.from_pretrained.call_args
        assert recorded.kwargs.get("revision") == "abc123"

    @patch("axolotl.loaders.tokenizer.AutoTokenizer")
    @patch("axolotl.loaders.tokenizer.is_local_main_process", return_value=True)
    @patch("axolotl.loaders.tokenizer.barrier")
    def test_modify_tokenizer_files_defaults_revision_to_main(
        self, _mock_barrier, _mock_main, mock_auto_tokenizer, temp_dir
    ):
        """Omitting ``revision`` results in ``"main"`` being passed through."""
        from axolotl.loaders.tokenizer import modify_tokenizer_files

        mock_auto_tokenizer.from_pretrained.return_value = MagicMock()

        modify_tokenizer_files("some-model", {}, output_dir=temp_dir)

        recorded = mock_auto_tokenizer.from_pretrained.call_args
        assert recorded.kwargs.get("revision") == "main"

    @patch("axolotl.loaders.processor.AutoProcessor")
    def test_load_processor_passes_revision(self, mock_auto_processor):
        """A configured revision appears in the processor ``from_pretrained`` kwargs."""
        from axolotl.loaders.processor import load_processor

        # The stand-in processor is given an empty ``size`` mapping.
        mock_auto_processor.from_pretrained.return_value = MagicMock(size={})

        settings = DictDefault(
            {
                "processor_config": "some-model",
                "revision_of_model": "abc123",
                "trust_remote_code": False,
            }
        )
        load_processor(settings, MagicMock(spec=PreTrainedTokenizerBase))

        recorded = mock_auto_processor.from_pretrained.call_args
        assert recorded.kwargs.get("revision") == "abc123"

    @patch("axolotl.loaders.processor.AutoProcessor")
    def test_load_processor_omits_revision_when_unset(self, mock_auto_processor):
        """Without ``revision_of_model`` the processor gets no ``revision`` keyword."""
        from axolotl.loaders.processor import load_processor

        mock_auto_processor.from_pretrained.return_value = MagicMock(size={})

        settings = DictDefault(
            {
                "processor_config": "some-model",
                "trust_remote_code": False,
            }
        )
        load_processor(settings, MagicMock(spec=PreTrainedTokenizerBase))

        recorded = mock_auto_processor.from_pretrained.call_args
        assert "revision" not in recorded.kwargs