|
29 | 29 | sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) |
30 | 30 | import gguf |
31 | 31 | from gguf.vocab import MistralTokenizerType, MistralVocab |
32 | | -from mistral_common.tokens.tokenizers.base import TokenizerVersion |
33 | | -from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN, DATASET_STD |
34 | | -from mistral_common.tokens.tokenizers.tekken import Tekkenizer |
35 | | -from mistral_common.tokens.tokenizers.sentencepiece import ( |
36 | | - SentencePieceTokenizer, |
37 | | -) |
| 32 | + |
| 33 | +try: |
| 34 | + from mistral_common.tokens.tokenizers.base import TokenizerVersion # pyright: ignore[reportMissingImports] |
| 35 | + from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # pyright: ignore[reportMissingImports] |
| 36 | + from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports] |
| 37 | + from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports] |
| 38 | + SentencePieceTokenizer, |
| 39 | + ) |
| 40 | + |
| 41 | + _mistral_common_installed = True |
| 42 | + _mistral_import_error_msg = "" |
| 43 | +except ImportError: |
| 44 | + _MISTRAL_COMMON_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) |
| 45 | + _MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) |
| 46 | + |
| 47 | + _mistral_common_installed = False |
| 48 | + TokenizerVersion = None |
| 49 | + Tekkenizer = None |
| 50 | + SentencePieceTokenizer = None |
| 51 | + _mistral_import_error_msg = ( |
| 52 | + "Mistral format requires `mistral-common` to be installed. Please run " |
| 53 | + "`pip install mistral-common[image,audio]` to install it." |
| 54 | + ) |
38 | 55 |
|
39 | 56 |
|
40 | 57 | logger = logging.getLogger("hf-to-gguf") |
@@ -107,6 +124,9 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, |
107 | 124 | type(self) is MmprojModel: |
108 | 125 | raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") |
109 | 126 |
|
| 127 | + if self.is_mistral_format and not _mistral_common_installed: |
| 128 | + raise ImportError(_mistral_import_error_msg) |
| 129 | + |
110 | 130 | self.dir_model = dir_model |
111 | 131 | self.ftype = ftype |
112 | 132 | self.fname_out = fname_out |
@@ -1363,8 +1383,8 @@ def set_gguf_parameters(self): |
1363 | 1383 | self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"])) |
1364 | 1384 |
|
1365 | 1385 | # preprocessor config |
1366 | | - image_mean = DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"] |
1367 | | - image_std = DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"] |
| 1386 | + image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"] |
| 1387 | + image_std = _MISTRAL_COMMON_DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"] |
1368 | 1388 |
|
1369 | 1389 | self.gguf_writer.add_vision_image_mean(image_mean) |
1370 | 1390 | self.gguf_writer.add_vision_image_std(image_std) |
@@ -2033,6 +2053,9 @@ def __init__(self, *args, **kwargs): |
2033 | 2053 | self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32) |
2034 | 2054 |
|
2035 | 2055 | def _set_vocab_mistral(self): |
| 2056 | + if not _mistral_common_installed: |
| 2057 | + raise ImportError(_mistral_import_error_msg) |
| 2058 | + |
2036 | 2059 | vocab = MistralVocab(self.dir_model) |
2037 | 2060 | logger.info( |
2038 | 2061 | f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}." |
@@ -9212,7 +9235,7 @@ class MistralModel(LlamaModel): |
9212 | 9235 |
|
9213 | 9236 | @staticmethod |
9214 | 9237 | def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool): |
9215 | | - assert TokenizerVersion is not None, "mistral_common is not installed" |
| 9238 | + assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg |
9216 | 9239 | assert isinstance(vocab.tokenizer, (Tekkenizer, SentencePieceTokenizer)), ( |
9217 | 9240 | f"Expected Tekkenizer or SentencePieceTokenizer, got {type(vocab.tokenizer)}" |
9218 | 9241 | ) |
@@ -9594,6 +9617,8 @@ def main() -> None: |
9594 | 9617 | fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-") |
9595 | 9618 |
|
9596 | 9619 | is_mistral_format = args.mistral_format |
| 9620 | + if is_mistral_format and not _mistral_common_installed: |
| 9621 | + raise ImportError(_mistral_import_error_msg) |
9597 | 9622 | disable_mistral_community_chat_template = args.disable_mistral_community_chat_template |
9598 | 9623 |
|
9599 | 9624 | with torch.inference_mode(): |
|
0 commit comments