@@ -13,6 +13,7 @@ class SpecialVocab:
     merges: list[str]
     add_special_token: dict[str, bool]
     special_token_ids: dict[str, int]
+    chat_template: str | None
 
     def __init__(
         self, path: str | os.PathLike[str], load_merges: bool = False,
@@ -24,6 +25,7 @@ def __init__(
         self.n_vocab = n_vocab
         self.load_merges = load_merges
         self.merges = []
+        self.chat_template = None
         if special_token_types is not None:
             self.special_token_types = special_token_types
         else:
@@ -67,6 +69,10 @@ def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
             if not quiet:
                 print(f'gguf: Setting add_{typ}_token to {value}')
             add_handler(value)
+        if self.chat_template is not None:
+            if not quiet:
+                print(f'gguf: Setting chat_template to {self.chat_template}')
+            gw.add_chat_template(self.chat_template)
 
     def _load(self, path: Path) -> None:
         self._try_load_from_tokenizer_json(path)
@@ -132,6 +138,14 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool:
             return True
         with open(tokenizer_config_file, encoding = 'utf-8') as f:
             tokenizer_config = json.load(f)
+        chat_template = tokenizer_config.get('chat_template')
+        if chat_template is None or isinstance(chat_template, str):
+            self.chat_template = chat_template
+        else:
+            print(
+                f'gguf: WARNING: Bad type for chat_template field in {tokenizer_config_file!r} - ignoring',
+                file = sys.stderr
+            )
         for typ in self.special_token_types:
             add_entry = tokenizer_config.get(f'add_{typ}_token')
             if isinstance(add_entry, bool):
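For context, a minimal sketch of how a conversion script might exercise the new field; the model path, output filename, and arch value below are placeholders and are not part of this change:

# Hypothetical usage sketch, assuming the gguf-py package layout used by
# llama.cpp's conversion scripts; the paths and the 'llama' arch are placeholders.
import gguf

writer = gguf.GGUFWriter('model.gguf', arch='llama')
special_vocab = gguf.SpecialVocab('path/to/hf-model', load_merges=True)

# If tokenizer_config.json contained a string chat_template, it is now kept on
# the instance and written out alongside the other special-token metadata.
if special_vocab.chat_template is not None:
    print('chat template found:', special_vocab.chat_template[:40])
special_vocab.add_to_gguf(writer)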