Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/axolotl/loaders/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
processor_kwargs = {}
if cfg.revision_of_model:
processor_kwargs["revision"] = cfg.revision_of_model
if cfg.processor_kwargs:
processor_kwargs.update(cfg.processor_kwargs)
Comment thread
thad0ctor marked this conversation as resolved.

if cfg.tokenizer_use_mistral_common:

Expand Down
22 changes: 22 additions & 0 deletions src/axolotl/utils/schemas/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ class ModelInputConfig(BaseModel):
processor_type: str | None = Field(
default=None, json_schema_extra={"description": "transformers processor class"}
)
processor_kwargs: dict[str, Any] | None = Field(
default=None,
json_schema_extra={
"description": "kwargs forwarded to the processor's from_pretrained(), overriding processor config (e.g. image_seq_length, min_pixels, etc.)."
},
)
Comment thread
thad0ctor marked this conversation as resolved.
tokenizer_save_jinja_files: bool | None = Field(
default=True, # match the default behavior from transformers
json_schema_extra={
Expand Down Expand Up @@ -107,6 +113,22 @@ def hint_trust_remote_code(cls, trust_remote_code):
)
return trust_remote_code

@field_validator("processor_kwargs")
@classmethod
def reject_reserved_processor_kwargs(cls, processor_kwargs):
    """Refuse ``processor_kwargs`` entries that use a reserved key.

    Falsy input (``None`` or an empty mapping) is returned untouched.
    Otherwise the keys are compared against ``revision`` and
    ``trust_remote_code``.

    Raises:
        ValueError: if any reserved key is present; the message lists the
            offending keys in sorted order.
    """
    clashing = (
        {"revision", "trust_remote_code"}.intersection(processor_kwargs)
        if processor_kwargs
        else set()
    )
    if not clashing:
        return processor_kwargs

    offenders = sorted(clashing)
    raise ValueError(
        f"Do not set reserved keys {offenders} inside `processor_kwargs`; "
        "use the top-level `revision_of_model` / `trust_remote_code` "
        "config keys instead."
    )

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also add a validation check that rejects `processor_kwargs` when `cfg.tokenizer_use_mistral_common` is enabled, since the two are incompatible.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added via 96d8536



class ModelOutputConfig(BaseModel):
"""model save configuration subset"""
Expand Down
5 changes: 5 additions & 0 deletions src/axolotl/utils/schemas/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,11 @@ def check_mistral_common_incompatible_options(cls, data):
"Setting chat_template is not supported with mistral-common tokenizer"
)

if data.get("processor_kwargs"):
raise ValueError(
"processor_kwargs is not supported with mistral-common tokenizer"
)

return data

@model_validator(mode="before")
Expand Down
105 changes: 105 additions & 0 deletions tests/test_revision_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,108 @@ def test_load_processor_omits_revision_when_unset(self, mock_auto_processor):

call_kwargs = mock_auto_processor.from_pretrained.call_args
assert "revision" not in call_kwargs.kwargs

@patch("axolotl.loaders.processor.AutoProcessor")
def test_load_processor_forwards_processor_kwargs(self, mock_auto_processor):
    """Entries of ``cfg.processor_kwargs`` show up as ``from_pretrained`` kwargs."""
    from axolotl.loaders.processor import load_processor

    extra_kwargs = {"image_seq_length": 1120, "max_soft_tokens": 1120}

    fake_processor = MagicMock()
    fake_processor.size = {}
    mock_auto_processor.from_pretrained.return_value = fake_processor

    cfg = DictDefault(
        {
            "processor_config": "some-model",
            "trust_remote_code": False,
            # Hand over a copy so the expectations below stay independent.
            "processor_kwargs": dict(extra_kwargs),
        }
    )

    load_processor(cfg, MagicMock(spec=PreTrainedTokenizerBase))

    forwarded = mock_auto_processor.from_pretrained.call_args.kwargs
    for key, expected in extra_kwargs.items():
        assert forwarded.get(key) == expected

@patch("axolotl.loaders.processor.AutoProcessor")
def test_load_processor_omits_processor_kwargs_when_unset(
    self, mock_auto_processor
):
    """Without ``processor_kwargs`` in the config, no override keys are passed."""
    from axolotl.loaders.processor import load_processor

    fake_processor = MagicMock()
    fake_processor.size = {}
    mock_auto_processor.from_pretrained.return_value = fake_processor

    cfg = DictDefault(
        {
            "processor_config": "some-model",
            "trust_remote_code": False,
        }
    )

    load_processor(cfg, MagicMock(spec=PreTrainedTokenizerBase))

    forwarded = mock_auto_processor.from_pretrained.call_args.kwargs
    for absent_key in ("image_seq_length", "max_soft_tokens"):
        assert absent_key not in forwarded

def test_processor_kwargs_schema_rejects_revision(self):
    """A ``revision`` key inside ``processor_kwargs`` raises ``ValueError``."""
    import pytest

    from axolotl.utils.schemas.model import ModelInputConfig

    forbidden = {"revision": "abc123"}
    with pytest.raises(ValueError, match="revision"):
        ModelInputConfig(base_model="some-model", processor_kwargs=forbidden)

def test_processor_kwargs_schema_rejects_trust_remote_code(self):
    """A ``trust_remote_code`` key inside ``processor_kwargs`` raises ``ValueError``."""
    import pytest

    from axolotl.utils.schemas.model import ModelInputConfig

    forbidden = {"trust_remote_code": True}
    with pytest.raises(ValueError, match="trust_remote_code"):
        ModelInputConfig(base_model="some-model", processor_kwargs=forbidden)

def test_processor_kwargs_schema_accepts_valid_keys(self):
    """Non-reserved keys pass validation and are stored unchanged."""
    from axolotl.utils.schemas.model import ModelInputConfig

    allowed = {"image_seq_length": 1120, "max_soft_tokens": 1120}
    parsed = ModelInputConfig(
        base_model="some-model",
        processor_kwargs=dict(allowed),
    )

    assert parsed.processor_kwargs == allowed

def test_processor_kwargs_schema_accepts_none_and_empty(self):
    """Both an omitted field and an empty dict are accepted as-is."""
    from axolotl.utils.schemas.model import ModelInputConfig

    unset = ModelInputConfig(base_model="x")
    assert unset.processor_kwargs is None

    empty = ModelInputConfig(base_model="x", processor_kwargs={})
    assert empty.processor_kwargs == {}

def test_processor_kwargs_incompatible_with_mistral_common(self, min_base_cfg):
    """Config validation fails when both options are set together."""
    import pytest

    from axolotl.utils.config import validate_config
    from axolotl.utils.dict import DictDefault

    overrides = DictDefault(
        tokenizer_use_mistral_common=True,
        processor_kwargs={"image_seq_length": 1120},
    )
    cfg = min_base_cfg | overrides

    with pytest.raises(ValueError, match="processor_kwargs"):
        validate_config(cfg)
Loading