diff --git a/docs/guides/sft.md b/docs/guides/sft.md index 8a67da85e8..54ebe4d341 100644 --- a/docs/guides/sft.md +++ b/docs/guides/sft.md @@ -27,7 +27,7 @@ uv run examples/run_sft.py \ SFT datasets in Reinforcer are encapsulated using classes. Each SFT data class is expected to have the following attributes: 1. `formatted_ds`: The dictionary of formatted datasets. This dictionary should contain `train` and `validation` splits, and each split should conform to the format described below. - 2. `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset as well as the `custom_template` for this dataset. More on custom templates below. + 2. `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset. SFT datasets are expected to follow the HuggingFace chat format. Refer to the [chat dataset document](../design-docs/chat-datasets.md) for details. If your data is not in the correct format, simply write a preprocessing script to convert the data into this format. [data/hf_datasets/squad.py](../../nemo_reinforcer/data/hf_datasets/squad.py) has an example: @@ -51,17 +51,21 @@ def format_squad(data): } ``` -Reinforcer SFT uses HuggingFace chat templates to format the individual examples. If you would like to use a custom template, create a string template in [jinja format](https://huggingface.co/docs/transformers/v4.34.0/en/chat_templating#how-do-i-create-a-chat-template) and pass it to the dataset's `TaskDataSpec`. For example, +Reinforcer SFT uses HuggingFace chat templates to format the individual examples. Three types of chat templates are supported, which can be configured via `tokenizer.chat_template` in your yaml config (see [sft.yaml](../../examples/configs/sft.yaml) for an example): + +1. Apply the tokenizer's default chat template. To use the tokenizer's default, either omit `tokenizer.chat_template` from the config altogether, or set `tokenizer.chat_template="default"`. 
+2. Use a "passthrough" template which simply concatenates all messages. This is desirable if the chat template has been applied to your dataset as an offline preprocessing step. In this case, you should set `tokenizer.chat_template` to None as follows: + ```yaml + tokenizer: + chat_template: NULL + ``` +3. Use a custom template: If you would like to use a custom template, create a string template in [jinja format](https://huggingface.co/docs/transformers/v4.34.0/en/chat_templating#how-do-i-create-a-chat-template), and add that string to the config. For example, + + ```yaml + tokenizer: + chat_template: "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer: '}}{%- elif message['role'] == 'assistant' %}{{message['content'].strip()}}{%- endif %}{% endfor %}" + ``` -```python -custom_template = ( - "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer: '}}{%- elif message['role'] == 'assistant' %}{{message['content'].strip()}}{%- endif %}{% endfor %}" -) -task_spec = TaskDataSpec( - task_name="squad", - custom_template=custom_template, -) -``` By default, NeMo-Reinforcer has support for `Squad` and `OpenAssistant` datasets. Both of these datasets are downloaded from HuggingFace and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk. 
diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 7a256621e4..ac37b0f0ac 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -25,7 +25,8 @@ checkpointing: policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" - tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default + tokenizer: + name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default train_global_batch_size: 512 train_micro_batch_size: 4 generation_batch_size: 32 # Only used when generating using HF backend diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml index 46c8855c6e..56d65a465c 100644 --- a/examples/configs/grpo_math_8B.yaml +++ b/examples/configs/grpo_math_8B.yaml @@ -7,7 +7,8 @@ grpo: policy: model_name: "meta-llama/Llama-3.1-8B-Instruct" - tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default + tokenizer: + name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default train_global_batch_size: 512 train_micro_batch_size: 1 generation_batch_size: 32 # Only used when generating using HF backend diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index bb8467165f..dad99b2479 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -18,7 +18,9 @@ checkpointing: policy: model_name: "meta-llama/Llama-3.2-1B" - tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default + tokenizer: + name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default + chat_template: "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' 
Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}" train_global_batch_size: 32 train_micro_batch_size: 1 max_total_sequence_length: 1024 @@ -35,6 +37,8 @@ policy: data: max_input_seq_length: ${policy.max_total_sequence_length} dataset_name: "squad" + add_bos: true + add_eos: true logger: log_dir: "logs" # Base directory for all logs diff --git a/examples/run_grpo_math.py b/examples/run_grpo_math.py index c6dea5609c..e147d167f9 100644 --- a/examples/run_grpo_math.py +++ b/examples/run_grpo_math.py @@ -64,7 +64,6 @@ def openinstructmath2_data_processor( problem = user_message[0]["content"] extra_env_info = {"ground_truth": user_message[1]["content"]} - template = task_data_spec.custom_template message_log: LLMMessageLogType = [] user_message = { "role": "user", @@ -72,7 +71,6 @@ def openinstructmath2_data_processor( } message = tokenizer.apply_chat_template( [user_message], - chat_template=template, tokenize=False, add_generation_prompt=True, add_special_tokens=False, @@ -254,7 +252,7 @@ def main(): init_ray() # setup tokenizer - tokenizer = get_tokenizer(config["policy"]["model_name"]) + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) config["policy"]["generation"] = configure_generation_config( config["policy"]["generation"], tokenizer ) diff --git a/examples/run_sft.py b/examples/run_sft.py index 00d2a93900..875aa9a000 100644 --- a/examples/run_sft.py +++ b/examples/run_sft.py @@ -15,6 +15,7 @@ import argparse import os import pprint +from functools import partial from typing import Dict, Any from omegaconf import OmegaConf @@ -56,10 +57,16 @@ def sft_preprocessor( tokenizer, max_seq_length: int, idx: int, + add_bos: bool = True, + add_eos: bool = True, ) -> DatumSpec: """Process a datum dictionary for SFT training.""" message_log = get_formatted_message_log( - datum_dict["messages"], tokenizer, task_data_spec + datum_dict["messages"], + tokenizer, + task_data_spec, + add_bos_token=add_bos, + 
add_eos_token=add_eos, ) length = sum(len(m["token_ids"]) for m in message_log) @@ -90,6 +97,13 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig): data = hf_datasets.OasstDataset(output_dir="/tmp/open_assistant") elif data_cls == "squad": data = hf_datasets.SquadDataset() + elif data_cls == "prompt_response_dataset": + data = hf_datasets.PromptResponseDataset( + data_config["train_data_path"], + data_config["val_data_path"], + data_config["input_key"], + data_config["output_key"], + ) else: raise ValueError(f"Unknown dataset class: {data_cls}") print( @@ -104,7 +118,11 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig): train_dataset, tokenizer, sft_task_spec, - sft_preprocessor, + partial( + sft_preprocessor, + add_bos=data_config["add_bos"], + add_eos=data_config["add_eos"], + ), max_seq_length=data_config["max_input_seq_length"], ) @@ -112,7 +130,11 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig): val_dataset, tokenizer, sft_task_spec, - sft_preprocessor, + partial( + sft_preprocessor, + add_bos=data_config.get("add_bos", True), + add_eos=data_config.get("add_eos", True), + ), max_seq_length=data_config["max_input_seq_length"], ) @@ -151,7 +173,7 @@ def main(): init_ray() # setup tokenizer - tokenizer = get_tokenizer(config["policy"]["model_name"]) + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) # setup data ( @@ -170,8 +192,7 @@ def main(): checkpointer, sft_save_state, master_config, - ) = setup(config, dataset, val_dataset) - + ) = setup(config, tokenizer, dataset, val_dataset) sft_train( policy, train_dataloader, diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 0eda853375..eb0947011f 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -236,6 +236,7 @@ def setup( policy = HfPolicy( cluster=cluster, config=policy_config, + tokenizer=tokenizer, weights_path=Path(last_checkpoint_path) / "policy" / "weights" if 
last_checkpoint_path else None, @@ -628,6 +629,9 @@ def grpo_train( optimizer_path=os.path.join( checkpoint_path, "policy", "optimizer" ), + tokenizer_path=os.path.join( + checkpoint_path, "policy", "tokenizer" + ), save_hf=is_last_checkpoint, ) torch.save( diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index b5bb41aec5..45f4f08575 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from transformers import AutoTokenizer from pathlib import Path from typing import Optional, Tuple, TypedDict @@ -76,6 +77,7 @@ class MasterConfig(TypedDict): # ======================================================= def setup( master_config: MasterConfig, + tokenizer: AutoTokenizer, train_dataset: AllTaskProcessedDataset, val_dataset: AllTaskProcessedDataset, ) -> Tuple[ @@ -175,6 +177,7 @@ def setup( policy = HfPolicy( cluster=cluster, config=policy_config, + tokenizer=tokenizer, weights_path=Path(last_checkpoint_path) / "policy" / "weights" if last_checkpoint_path else None, @@ -416,6 +419,9 @@ def sft_train( optimizer_path=os.path.join( checkpoint_path, "policy", "optimizer" ), + tokenizer_path=os.path.join( + checkpoint_path, "policy", "tokenizer" + ), save_hf=is_last_checkpoint, ) torch.save( diff --git a/nemo_reinforcer/algorithms/utils.py b/nemo_reinforcer/algorithms/utils.py index a3c42e2a19..3cfd8569a9 100644 --- a/nemo_reinforcer/algorithms/utils.py +++ b/nemo_reinforcer/algorithms/utils.py @@ -20,6 +20,9 @@ from torch.masked import as_masked_tensor from transformers import AutoTokenizer +from nemo_reinforcer.data import hf_datasets +from nemo_reinforcer.models.policy import TokenizerConfig + def calculate_kl_penalty_joschu2020( logprobs_policy: torch.Tensor, logprobs_reference: torch.Tensor @@ -133,9 +136,77 @@ def set_seed(seed: int): torch.cuda.manual_seed_all(seed) -def 
get_tokenizer(model_name: str) -> AutoTokenizer: - """Get the tokenizer and set pad token to eos token if it is not already set.""" - tokenizer = AutoTokenizer.from_pretrained(model_name) +def get_tokenizer(tokenizer_config: TokenizerConfig) -> AutoTokenizer: + """Get the tokenizer and set pad token to eos token if it is not already set. + + This function initializes a tokenizer from the Hugging Face transformers library + and configures it with appropriate chat templates and padding tokens. + + Args: + tokenizer_config: A dictionary containing tokenizer configuration. + Required keys: + - name: The name or path of the pretrained tokenizer + Optional keys: + - chat_template: The chat template to use. Can be: + - None: Uses a passthrough template that just returns message content + - "default": Uses the tokenizer's default template + - A custom jinja2 template string + If not specified, the tokenizer's default template will be used. + + Returns: + AutoTokenizer: The configured tokenizer instance + + Examples: + ```{doctest} + >>> from transformers import AutoTokenizer + >>> from nemo_reinforcer.algorithms.utils import get_tokenizer + >>> # not specifying a chat template uses the tokenizer's default + >>> config = {"name": "meta-llama/Llama-3.2-1B-Instruct"} + >>> tokenizer = get_tokenizer(config) + No chat template provided, using tokenizer's default + >>> messages = [ + ... {"role": "system", "content": "You are a helpful AI assistant."}, + ... {"role": "user", "content": "Hello!"} + ... ] + >>> formatted = tokenizer.apply_chat_template(messages, tokenize=False) + >>> assert formatted == AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct").apply_chat_template(messages, tokenize=False) + + >>> # Using a passthrough template + >>> config = { + ... "name": "meta-llama/Llama-3.2-1B-Instruct", + ... "chat_template": None + ... 
} + >>> tokenizer = get_tokenizer(config) + Using passthrough chat template + >>> formatted = tokenizer.apply_chat_template(messages, tokenize=False) + >>> assert formatted == "".join(msg["content"] for msg in messages) + + >>> # Using a custom template + >>> config = { + ... "name": "meta-llama/Llama-3.2-1B-Instruct", + ... "chat_template": "{% for message in messages %}{{ ' START: ' + message['content'] + ' END.' }}{% endfor %}" + ... } + >>> tokenizer = get_tokenizer(config) + Using custom chat template + >>> formatted = tokenizer.apply_chat_template(messages, tokenize=False) + >>> assert formatted == " START: You are a helpful AI assistant. END. START: Hello! END." + ``` + """ + tokenizer = AutoTokenizer.from_pretrained(tokenizer_config["name"]) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token + if "chat_template" in tokenizer_config: + if tokenizer_config["chat_template"] is None: + print("Using passthrough chat template") + tokenizer.chat_template = ( + hf_datasets.COMMON_CHAT_TEMPLATES.passthrough_prompt_response + ) + elif tokenizer_config["chat_template"].lower() == "default": + print("Using tokenizer's default chat template") + else: + print("Using custom chat template") + tokenizer.chat_template = tokenizer_config["chat_template"] + else: + print("No chat template provided, using tokenizer's default") + return tokenizer diff --git a/nemo_reinforcer/data/__init__.py b/nemo_reinforcer/data/__init__.py index 09eaf35fb5..eba1eaf7ce 100644 --- a/nemo_reinforcer/data/__init__.py +++ b/nemo_reinforcer/data/__init__.py @@ -21,6 +21,8 @@ class DataConfig(TypedDict): system_prompt_file: Optional[str] dataset_name: str val_dataset_name: Optional[str] + add_bos: Optional[bool] + add_eos: Optional[bool] class MathDataConfig(DataConfig): diff --git a/nemo_reinforcer/data/hf_datasets/__init__.py b/nemo_reinforcer/data/hf_datasets/__init__.py index df1227140d..919f1a494e 100644 --- a/nemo_reinforcer/data/hf_datasets/__init__.py +++ 
b/nemo_reinforcer/data/hf_datasets/__init__.py @@ -12,7 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from nemo_reinforcer.data.hf_datasets.prompt_response_dataset import ( + PromptResponseDataset, +) from nemo_reinforcer.data.hf_datasets.oasst import OasstDataset from nemo_reinforcer.data.hf_datasets.squad import SquadDataset +from nemo_reinforcer.data.hf_datasets.chat_templates import COMMON_CHAT_TEMPLATES -__all__ = ["OasstDataset", "SquadDataset"] +__all__ = [ + "OasstDataset", + "PromptResponseDataset", + "SquadDataset", + "COMMON_CHAT_TEMPLATES", +] diff --git a/nemo_reinforcer/data/hf_datasets/interfaces.py b/nemo_reinforcer/data/hf_datasets/chat_templates.py similarity index 67% rename from nemo_reinforcer/data/hf_datasets/interfaces.py rename to nemo_reinforcer/data/hf_datasets/chat_templates.py index 63be96c5a7..8e7ad957a6 100644 --- a/nemo_reinforcer/data/hf_datasets/interfaces.py +++ b/nemo_reinforcer/data/hf_datasets/chat_templates.py @@ -12,28 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, Any, Optional -from nemo_reinforcer.data.interfaces import TaskDataSpec - - +## a reference to frequently used chat templates for convenience class COMMON_CHAT_TEMPLATES: ### simple template which prepends a role header to the content simple_role_header = "{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" - -class HfDataset: - """Interface for HuggingFace datasets.""" - - formatted_ds: Dict[str, Any] - - def __init__( - self, - dataset_name: str, - custom_template: Optional[ - str - ] = None, ## "None" means use HuggingFace's tokenizer's template - ): - self.task_spec = TaskDataSpec( - task_name=dataset_name, - custom_template=custom_template, - ) + ### passthrough template which just concatenates the content of the messages with no special tokens + passthrough_prompt_response = ( + "{% for message in messages %}{{ message['content'] }}{% endfor %}" + ) diff --git a/nemo_reinforcer/data/hf_datasets/oasst.py b/nemo_reinforcer/data/hf_datasets/oasst.py index 4525c0fa92..3d10c46e37 100644 --- a/nemo_reinforcer/data/hf_datasets/oasst.py +++ b/nemo_reinforcer/data/hf_datasets/oasst.py @@ -20,7 +20,8 @@ import copy from dataclasses import dataclass from typing import Optional -from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset, COMMON_CHAT_TEMPLATES + +from nemo_reinforcer.data.interfaces import TaskDataSpec SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n" @@ -118,11 +119,9 @@ def download_and_process_oasst(output_directory=".", seed=42, split_ratio=0.95): return formatted_ds -@dataclass -class OasstDataset(HfDataset): +class OasstDataset: def __init__(self, output_dir: str = "."): self.formatted_ds = download_and_process_oasst(output_dir) - super().__init__( - dataset_name="oasst", - custom_template=COMMON_CHAT_TEMPLATES.simple_role_header, + self.task_spec = TaskDataSpec( + task_name="OASST", ) diff --git a/nemo_reinforcer/data/hf_datasets/openmathinstruct2.py b/nemo_reinforcer/data/hf_datasets/openmathinstruct2.py index c3c5126263..b5a0bfa1bc 100644 --- a/nemo_reinforcer/data/hf_datasets/openmathinstruct2.py +++ b/nemo_reinforcer/data/hf_datasets/openmathinstruct2.py @@ -14,9 +14,10 @@ from typing import Optional from datasets import load_dataset -from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset from dataclasses import dataclass +from nemo_reinforcer.data.interfaces import TaskDataSpec + def format_math(data): return { @@ -61,8 +62,7 @@ def prepare_openinstructmath2_dataset(split: str = "train_1M", seed=42, test_siz } -@dataclass -class OpenMathInstruct2Dataset(HfDataset): +class OpenMathInstruct2Dataset: def __init__( self, split: str = "train_1M", seed: int = 42, test_size: float = 0.05 ): @@ -82,6 +82,6 @@ def __init__( split=split, seed=seed, test_size=test_size ) - super().__init__( - dataset_name="OpenMathInstruct-2", + self.task_spec = TaskDataSpec( + task_name="OpenMathInstruct-2", ) diff --git a/nemo_reinforcer/data/hf_datasets/prompt_response_dataset.py b/nemo_reinforcer/data/hf_datasets/prompt_response_dataset.py new file mode 100644 index 0000000000..928af4fdff --- /dev/null +++ b/nemo_reinforcer/data/hf_datasets/prompt_response_dataset.py @@ -0,0 +1,51 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datasets import load_dataset +from nemo_reinforcer.data.interfaces import TaskDataSpec + + +class PromptResponseDataset: + def __init__( + self, + train_ds_path: str, + val_ds_path: str, + input_key: str = "input", + output_key: str = "output", + ): + train_original_dataset = load_dataset("json", data_files=train_ds_path)["train"] + val_original_dataset = load_dataset("json", data_files=val_ds_path)["train"] + + self.input_key = input_key + self.output_key = output_key + + formatted_train_dataset = train_original_dataset.map(self.add_messages_key) + formatted_val_dataset = val_original_dataset.map(self.add_messages_key) + + self.formatted_ds = { + "train": formatted_train_dataset, + "validation": formatted_val_dataset, + } + + self.task_spec = TaskDataSpec( + "json_dataset", + ) + + def add_messages_key(self, example): + return { + "messages": [ + {"role": "user", "content": example[self.input_key]}, + {"role": "assistant", "content": example[self.output_key]}, + ] + } diff --git a/nemo_reinforcer/data/hf_datasets/squad.py b/nemo_reinforcer/data/hf_datasets/squad.py index 4ba3e87949..3bc257c88a 100644 --- a/nemo_reinforcer/data/hf_datasets/squad.py +++ b/nemo_reinforcer/data/hf_datasets/squad.py @@ -14,7 +14,8 @@ from typing import Optional from datasets import load_dataset -from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset + +from nemo_reinforcer.data.interfaces import TaskDataSpec def 
format_squad(data): @@ -36,14 +37,10 @@ def format_squad(data): } -class SquadDataset(HfDataset): +class SquadDataset: def __init__(self): original_ds = load_dataset("rajpurkar/squad") self.formatted_ds = original_ds.map(format_squad) - - custom_template = "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}" - - super().__init__( - dataset_name="squad", - custom_template=custom_template, + self.task_spec = TaskDataSpec( + task_name="SQuAD", ) diff --git a/nemo_reinforcer/data/interfaces.py b/nemo_reinforcer/data/interfaces.py index ade6f145e8..6ae1152be7 100644 --- a/nemo_reinforcer/data/interfaces.py +++ b/nemo_reinforcer/data/interfaces.py @@ -42,7 +42,6 @@ class TaskDataSpec: prompt_file: Optional[os.PathLike] = None system_prompt_file: Optional[Union[str, os.PathLike]] = None - custom_template: Optional[Union[str, os.PathLike]] = None def __post_init__(self): def load_prompt_file( @@ -66,7 +65,6 @@ def copy_defaults(self, from_spec: "TaskDataSpec"): default_attrs = { "system_prompt": from_spec.system_prompt, "prompt": from_spec.prompt, - "custom_template": from_spec.custom_template, } for attr_name, default_value in default_attrs.items(): diff --git a/nemo_reinforcer/data/llm_message_utils.py b/nemo_reinforcer/data/llm_message_utils.py index 93f6ddc50b..13fc7dbd30 100644 --- a/nemo_reinforcer/data/llm_message_utils.py +++ b/nemo_reinforcer/data/llm_message_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import warnings from typing import Dict, List import torch @@ -342,6 +343,8 @@ def get_formatted_message_log( message_log: LLMMessageLogType, tokenizer, task_data_spec: TaskDataSpec, + add_bos_token: bool = True, + add_eos_token: bool = True, ) -> LLMMessageLogType: """Format and tokenize chat messages using the specified template. @@ -355,12 +358,10 @@ def get_formatted_message_log( """ new_message_log = [] prev_formatted_message = "" - template = task_data_spec.custom_template for i, message in enumerate(message_log): formatted_message = tokenizer.apply_chat_template( message_log[: i + 1], - chat_template=template, add_generation_prompt=False, tokenize=False, add_special_tokens=False, @@ -375,16 +376,24 @@ def get_formatted_message_log( ## pull out the chunk corresponding to the current message message_chunk = formatted_message[prev_message_len_no_eos:] - if tokenizer.bos_token is not None: - if i == 0 and not message_chunk.startswith(tokenizer.bos_token): - message_chunk = tokenizer.bos_token + message_chunk + if i == 0: + if add_bos_token: + if tokenizer.bos_token is None: + warnings.warn( + "add_bos_token is True but the tokenizer does not have a BOS token. Skipping BOS token addition." + ) + elif not message_chunk.startswith(tokenizer.bos_token): + message_chunk = tokenizer.bos_token + message_chunk if i == len(message_log) - 1: message_chunk = message_chunk.rstrip("\n") - if tokenizer.eos_token is not None and not message_chunk.endswith( - tokenizer.eos_token - ): - message_chunk += tokenizer.eos_token + if add_eos_token: + if tokenizer.eos_token is None: + warnings.warn( + "add_eos_token is True but the tokenizer does not have an EOS token. Skipping EOS token addition." 
+ ) + elif not message_chunk.endswith(tokenizer.eos_token): + message_chunk += tokenizer.eos_token new_message = message.copy() new_message["token_ids"] = tokenizer( diff --git a/nemo_reinforcer/models/policy/__init__.py b/nemo_reinforcer/models/policy/__init__.py index 24390b9670..43c9422e2a 100644 --- a/nemo_reinforcer/models/policy/__init__.py +++ b/nemo_reinforcer/models/policy/__init__.py @@ -17,9 +17,14 @@ from nemo_reinforcer.models.generation.interfaces import GenerationConfig +class TokenizerConfig(TypedDict): + name: str + chat_template: str + + class PolicyConfig(TypedDict): model_name: str - tokenizer_name: str + tokenizer: TokenizerConfig train_global_batch_size: int train_micro_batch_size: int learning_rate: float diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 9ba65aca38..0141a382ef 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -26,7 +26,7 @@ MixedPrecision, ) from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, AutoTokenizer from nemo_reinforcer.algorithms.interfaces import LossFunction from nemo_reinforcer.algorithms.utils import get_tokenizer @@ -68,6 +68,7 @@ def __repr__(self): def __init__( self, config: PolicyConfig, + tokenizer: AutoTokenizer, weights_path: Optional[str] = None, optimizer_path: Optional[str] = None, init_optimizer: bool = True, @@ -79,7 +80,6 @@ def __init__( rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() model_name = self.cfg["model_name"] - tokenizer_name = self.cfg["tokenizer_name"] if self.cfg["precision"] == "float32": self.dtype = torch.float32 elif self.cfg["precision"] == "bfloat16": @@ -101,7 +101,8 @@ def __init__( ) else: self.reference_model = None - self.tokenizer = get_tokenizer(tokenizer_name) + + self.tokenizer = tokenizer # 
------------------------------------------------ # 3) Move to GPU + Composable FSDP @@ -825,6 +826,7 @@ def save_checkpoint( self, weights_path: str, optimizer_path: Optional[str] = None, + tokenizer_path: Optional[str] = None, save_torch_dist: bool = True, save_hf: bool = False, ): @@ -856,6 +858,8 @@ def save_checkpoint( optimizer=self.optimizer if optimizer_path else None, scheduler=self.scheduler if optimizer_path else None, optimizer_path=optimizer_path, + tokenizer=self.tokenizer if tokenizer_path else None, + tokenizer_path=tokenizer_path, save_torch_dist=save_torch_dist, save_hf=save_hf, ) @@ -881,6 +885,7 @@ def __init__( self, cluster: RayVirtualCluster, config: PolicyConfig, + tokenizer: AutoTokenizer, name_prefix: str = "hf_policy", workers_per_node: Optional[Union[int, List[int]]] = None, init_optimizer: bool = True, @@ -896,6 +901,7 @@ def __init__( worker_builder = RayWorkerBuilder( HfPolicyWorker, config, + tokenizer=tokenizer, init_optimizer=init_optimizer, weights_path=weights_path, optimizer_path=optimizer_path, @@ -1087,6 +1093,7 @@ def save_checkpoint( self, weights_path: str, optimizer_path: Optional[str] = None, + tokenizer_path: Optional[str] = None, save_torch_dist: bool = True, save_hf: bool = False, ): @@ -1095,6 +1102,7 @@ def save_checkpoint( "save_checkpoint", weights_path, optimizer_path, + tokenizer_path, save_torch_dist, save_hf, respect_tied_workers=True, diff --git a/nemo_reinforcer/utils/native_checkpoint.py b/nemo_reinforcer/utils/native_checkpoint.py index aa6c4f9f01..d163b2afd0 100644 --- a/nemo_reinforcer/utils/native_checkpoint.py +++ b/nemo_reinforcer/utils/native_checkpoint.py @@ -137,6 +137,8 @@ def save_checkpoint( optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[Any] = None, optimizer_path: Optional[str] = None, + tokenizer: Optional[Any] = None, + tokenizer_path: Optional[str] = None, save_torch_dist: bool = True, save_hf: bool = False, ) -> None: @@ -175,6 +177,13 @@ def save_checkpoint( 
optimizer_state = {"optim": OptimizerState(model, optimizer, scheduler)} dcp.save(optimizer_state, checkpoint_id=optimizer_path) + if tokenizer is not None: + if tokenizer_path is None: + raise ValueError( + "tokenizer_path must be provided when saving tokenizer state" + ) + tokenizer.save_pretrained(tokenizer_path) + def load_checkpoint( model, diff --git a/tests/unit/algorithms/test_utils.py b/tests/unit/algorithms/test_utils.py new file mode 100755 index 0000000000..c3cc381fe9 --- /dev/null +++ b/tests/unit/algorithms/test_utils.py @@ -0,0 +1,127 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from datetime import datetime +from transformers import AutoTokenizer +from nemo_reinforcer.algorithms.utils import get_tokenizer +from nemo_reinforcer.data.hf_datasets.chat_templates import COMMON_CHAT_TEMPLATES + + +@pytest.fixture +def conversation_messages(): + """Fixture providing a multi-turn conversation for testing chat templates""" + return [ + {"role": "system", "content": "You are a helpful AI assistant."}, + {"role": "user", "content": "What's the weather like today?"}, + { + "role": "assistant", + "content": "I don't have access to real-time weather data.", + }, + {"role": "user", "content": "Can you help me with something else then?"}, + {"role": "assistant", "content": "Of course! 
What would you like help with?"}, + ] + + +def get_expected_llama_format(messages): + """Generate the expected output format for Llama's chat template""" + # Extract the date from the formatted output + # Get current date + current_date = datetime.now() + formatted_date = current_date.strftime("%d %b %Y") + + # Extract system message if present + if messages[0]["role"] == "system": + system_message = messages[0]["content"].strip() + messages = messages[1:] + else: + system_message = "" + + # Start with BOS token and system header + expected = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n" + expected += "Cutting Knowledge Date: December 2023\n" + expected += f"Today Date: {formatted_date}\n\n" + expected += f"{system_message}<|eot_id|>" + + # Add each message + for message in messages: + if message["role"] not in ["ipython", "tool"]: + expected += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n" + expected += f"{message['content'].strip()}<|eot_id|>" + + return expected + + +def get_format_with_simple_role_header(messages): + message = "<|begin_of_text|>" + for msg in messages: + message += ( + "<|start_header_id|>" + + msg["role"] + + "<|end_header_id|>\n\n" + + msg["content"].strip() + + "<|eot_id|>" + ) + return message + + +def test_get_tokenizer_no_chat_template(conversation_messages): + """Test get_tokenizer when no chat template is specified in config""" + config = {"name": "meta-llama/Llama-3.2-1B-Instruct"} + tokenizer = get_tokenizer(config) + + # Verify that the tokenizer's default template is used + formatted = tokenizer.apply_chat_template(conversation_messages, tokenize=False) + + expected = get_expected_llama_format(conversation_messages) + assert formatted == expected + + +def test_get_tokenizer_default_chat_template(conversation_messages): + """Test get_tokenizer when chat_template is 'default' in config""" + config = {"name": "meta-llama/Llama-3.2-1B-Instruct", "chat_template": "default"} + tokenizer = 
get_tokenizer(config) + + # Verify that the tokenizer's default template is used + formatted = tokenizer.apply_chat_template(conversation_messages, tokenize=False) + expected = get_expected_llama_format(conversation_messages) + assert formatted == expected + + +def test_get_tokenizer_null_chat_template(conversation_messages): + """Test get_tokenizer when chat_template is None in config""" + config = {"name": "meta-llama/Llama-3.2-1B-Instruct", "chat_template": None} + tokenizer = get_tokenizer(config) + + # Verify that the passthrough template is used + formatted = tokenizer.apply_chat_template(conversation_messages, tokenize=False) + + expected = "".join(msg["content"] for msg in conversation_messages) + + assert formatted == expected + + +def test_get_tokenizer_custom_jinja_template(conversation_messages): + """Test get_tokenizer when a custom jinja template is specified""" + custom_template = COMMON_CHAT_TEMPLATES.simple_role_header + config = { + "name": "meta-llama/Llama-3.2-1B-Instruct", + "chat_template": custom_template, + } + tokenizer = get_tokenizer(config) + + # Verify that the custom template is used + formatted = tokenizer.apply_chat_template(conversation_messages, tokenize=False) + expected = get_format_with_simple_role_header(conversation_messages) + assert formatted == expected diff --git a/tests/unit/data/hf_datasets/test_prompt_response.py b/tests/unit/data/hf_datasets/test_prompt_response.py new file mode 100644 index 0000000000..12b98d7fb2 --- /dev/null +++ b/tests/unit/data/hf_datasets/test_prompt_response.py @@ -0,0 +1,104 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import tempfile +import json +from nemo_reinforcer.data.hf_datasets.chat_templates import COMMON_CHAT_TEMPLATES +from nemo_reinforcer.data.hf_datasets.prompt_response_dataset import ( + PromptResponseDataset, +) +from transformers import AutoTokenizer + + +@pytest.fixture +def sample_data(request): + input_key = request.param[0] + output_key = request.param[1] + + train_data = [ + {input_key: "Hello", output_key: "Hi there!"}, + {input_key: "How are you?", output_key: "I'm good, thanks!"}, + ] + val_data = [ + {input_key: "What's up?", output_key: "Not much!"}, + {input_key: "Bye", output_key: "Goodbye!"}, + ] + + # Create temporary files for train and validation data + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as train_file: + json.dump(train_data, train_file) + train_path = train_file.name + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".json", delete=False + ) as val_file: + json.dump(val_data, val_file) + val_path = val_file.name + + return train_path, val_path + + +@pytest.mark.parametrize("sample_data", [("input", "output")], indirect=True) +def test_dataset_initialization(sample_data): + train_path, val_path = sample_data + dataset = PromptResponseDataset(train_path, val_path) + + assert dataset.input_key == "input" + assert dataset.output_key == "output" + assert "train" in dataset.formatted_ds + assert "validation" in dataset.formatted_ds + + +@pytest.mark.parametrize("sample_data", [("question", "answer")], indirect=True) +def test_custom_keys(sample_data): + train_path, val_path = 
sample_data + dataset = PromptResponseDataset( + train_path, val_path, input_key="question", output_key="answer" + ) + + assert dataset.input_key == "question" + assert dataset.output_key == "answer" + + +@pytest.mark.parametrize("sample_data", [("question", "answer")], indirect=True) +def test_message_formatting(sample_data): + train_path, val_path = sample_data + dataset = PromptResponseDataset( + train_path, val_path, input_key="question", output_key="answer" + ) + + first_example = dataset.formatted_ds["train"][0] + + assert first_example["messages"][0]["role"] == "user" + assert first_example["messages"][0]["content"] == "Hello" + assert first_example["messages"][1]["role"] == "assistant" + assert first_example["messages"][1]["content"] == "Hi there!" + + chat_template = COMMON_CHAT_TEMPLATES.passthrough_prompt_response + tokenizer = AutoTokenizer.from_pretrained("Meta-Llama/Meta-Llama-3-8B-Instruct") + + combined_message = tokenizer.apply_chat_template( + first_example["messages"], + chat_template=chat_template, + tokenize=False, + add_generation_prompt=False, + add_special_tokens=False, + ) + + assert combined_message == "".join( + message["content"] for message in first_example["messages"] + ) diff --git a/tests/unit/data/test_hf_datasets.py b/tests/unit/data/hf_datasets/test_squad.py similarity index 80% rename from tests/unit/data/test_hf_datasets.py rename to tests/unit/data/hf_datasets/test_squad.py index 5fe634a2a7..47511363b5 100644 --- a/tests/unit/data/test_hf_datasets.py +++ b/tests/unit/data/hf_datasets/test_squad.py @@ -14,6 +14,7 @@ import pytest from transformers import AutoTokenizer +from nemo_reinforcer.data.hf_datasets.chat_templates import COMMON_CHAT_TEMPLATES from nemo_reinforcer.data.hf_datasets.squad import SquadDataset @@ -31,10 +32,12 @@ def test_squad_dataset(): assert example["messages"][1]["role"] == "user" assert example["messages"][2]["role"] == "assistant" + template = "{% for message in messages %}{%- if message['role'] == 
'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}" + ## check that applying chat template works as expected default_templated = tokenizer.apply_chat_template( example["messages"], - chat_template=squad_dataset.task_spec.custom_template, + chat_template=template, tokenize=False, add_generation_prompt=False, add_special_tokens=False, diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 946a6bf3b2..d86f4983cc 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -30,7 +30,9 @@ basic_vllm_test_config: VllmConfig = { "backend": "vllm", "model_name": "meta-llama/Llama-3.2-1B", # Small model for testing - "tokenizer_name": "meta-llama/Llama-3.2-1B", + "tokenizer": { + "name": "meta-llama/Llama-3.2-1B", + }, "dtype": "bfloat16", "max_new_tokens": 10, "temperature": 1.0, @@ -48,7 +50,9 @@ # Create HF-specific config with required parameters basic_hf_test_config: PolicyConfig = { "model_name": basic_vllm_test_config["model_name"], - "tokenizer_name": basic_vllm_test_config["tokenizer_name"], + "tokenizer": { + "name": basic_vllm_test_config["tokenizer"]["name"], + }, # Required training parameters "train_global_batch_size": 1, "train_micro_batch_size": 1, @@ -87,8 +91,7 @@ def cluster(): @pytest.fixture(scope="function") def tokenizer(): """Initialize tokenizer for the test model.""" - model_name = basic_vllm_test_config["model_name"] - tokenizer = get_tokenizer(model_name) + tokenizer = get_tokenizer(basic_vllm_test_config["tokenizer"]) return tokenizer @@ -251,7 +254,7 @@ def test_vllm_worker_seed_behavior(cluster, tokenizer): from nemo_reinforcer.models.policy.hf_policy import HfPolicy hf_config = basic_hf_test_config.copy() - 
hf_policy = HfPolicy(cluster, hf_config) + hf_policy = HfPolicy(cluster, hf_config, tokenizer) print(f"refitting vllm policy...") ipc_handles = hf_policy.get_weights_ipc_handles() @@ -415,7 +418,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): vllm_policy.finish_generation() print("Creating HF policy...") - hf_policy = HfPolicy(cluster, hf_config) + hf_policy = HfPolicy(cluster, hf_config, tokenizer) print(f"refitting vllm policy...") ipc_handles = hf_policy.get_weights_ipc_handles() @@ -653,7 +656,7 @@ def test_vllm_weight_update_and_prefix_cache_reset( hf_policy = None try: print(f"Creating HF policy for TP={tensor_parallel_size}...") - hf_policy = HfPolicy(cluster, hf_config) + hf_policy = HfPolicy(cluster, hf_config, tokenizer) print(f"Creating vLLM policy for TP={tensor_parallel_size}...") vllm_policy = VllmGeneration(cluster, vllm_config) @@ -761,7 +764,7 @@ def test_vllm_generation_with_stop(cluster, test_input_data, tokenizer, is_eval) print("Creating HF policy...") hf_config = basic_hf_test_config.copy() - hf_policy = HfPolicy(cluster, hf_config) + hf_policy = HfPolicy(cluster, hf_config, tokenizer) print(f"refitting vllm policy...") ipc_handles = hf_policy.get_weights_ipc_handles() diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py index 449ed016d5..f56d3b8df3 100644 --- a/tests/unit/models/policy/test_hf_ray_policy.py +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -29,7 +29,7 @@ basic_llama_test_config: PolicyConfig = { "model_name": "meta-llama/Llama-3.2-1B", - "tokenizer_name": "meta-llama/Llama-3.2-1B", + "tokenizer": {"name": "meta-llama/Llama-3.2-1B"}, "generation_batch_size": 1, # Small batch size for testing "train_global_batch_size": 4, "train_micro_batch_size": 1, @@ -74,8 +74,7 @@ def gc_collect(): @pytest.fixture(scope="function") def tokenizer(): """Initialize tokenizer for the test model.""" - model_name = basic_llama_test_config["model_name"] - tokenizer = 
get_tokenizer(model_name) + tokenizer = get_tokenizer(basic_llama_test_config["tokenizer"]) return tokenizer @@ -148,7 +147,7 @@ def policy_setup(tokenizer): config["generation"] = configure_generation_config(config["generation"], tokenizer) print("Creating HfPolicy...") - policy = HfPolicy(cluster=cluster, config=config) + policy = HfPolicy(cluster=cluster, config=config, tokenizer=tokenizer) yield policy, cluster @@ -240,7 +239,7 @@ def test_hf_policy_init(policy_setup): @pytest.fixture -def training_setup(): +def training_setup(tokenizer): """Setup and teardown specifically for training tests.""" policy = None cluster = None @@ -263,7 +262,12 @@ def training_setup(): config = basic_llama_test_config print("Creating training HfPolicy...") - policy = HfPolicy(cluster=cluster, config=config, init_reference_model=False) + policy = HfPolicy( + cluster=cluster, + config=config, + init_reference_model=False, + tokenizer=tokenizer, + ) # Create a test batch print("Creating test batch...") @@ -366,7 +370,10 @@ def generation_setup(request, test_input_data, tokenizer): print("Creating generation HfPolicy...") policy = HfPolicy( - cluster=cluster, config=config, init_reference_model=request.param + cluster=cluster, + config=config, + tokenizer=tokenizer, + init_reference_model=request.param, ) # Create a test batch @@ -584,7 +591,7 @@ def test_hf_policy_generation_with_stop(test_input_data, tokenizer): ) # Create policy - policy = HfPolicy(cluster=cluster, config=config) + policy = HfPolicy(cluster=cluster, config=config, tokenizer=tokenizer) # Call prepare_for_generation if available print("Preparing for generation...") diff --git a/tests/unit/utils/test_native_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py index a5b9c9d637..7da9ee6947 100755 --- a/tests/unit/utils/test_native_checkpoint.py +++ b/tests/unit/utils/test_native_checkpoint.py @@ -17,6 +17,7 @@ import torch from tempfile import TemporaryDirectory +from nemo_reinforcer.algorithms.utils import 
get_tokenizer from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster from nemo_reinforcer.models.policy.hf_policy import HfPolicy @@ -33,7 +34,9 @@ # Define basic test config simple_policy_config = { "model_name": "meta-llama/Llama-3.2-1B", # "hf-internal-testing/tiny-random-Gemma3ForCausalLM", - "tokenizer_name": "meta-llama/Llama-3.2-1B", # "hf-internal-testing/tiny-random-Gemma3ForCausalLM", + "tokenizer": { + "name": "meta-llama/Llama-3.2-1B", + }, "train_global_batch_size": 4, "train_micro_batch_size": 1, "logprob_batch_size": 1, @@ -86,10 +89,7 @@ def cluster(): @pytest.fixture(scope="function") def tokenizer(): """Initialize tokenizer for the test model.""" - tokenizer_name = simple_policy_config["tokenizer_name"] - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + tokenizer = get_tokenizer(simple_policy_config["tokenizer"]) return tokenizer @@ -98,6 +98,7 @@ def policy(cluster, tokenizer): """Initialize the policy.""" policy = HfPolicy( cluster=cluster, + tokenizer=tokenizer, config=simple_policy_config, init_optimizer=True, init_reference_model=False, @@ -280,6 +281,7 @@ def test_save_and_load_hf_checkpoint(policy): os.path.join(tmp_dir, "test_hf_and_dcp"), save_hf=True, save_torch_dist=True, + tokenizer_path=os.path.join(tmp_dir, "test_hf_and_dcp_tokenizer"), ) ## make sure we save both HF and DCP checkpoints @@ -297,6 +299,12 @@ def test_save_and_load_hf_checkpoint(policy): "model.safetensors.index.json", } + assert set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp_tokenizer"))) == { + "tokenizer_config.json", + "tokenizer.json", + "special_tokens_map.json", + } + converted_model = AutoModelForCausalLM.from_pretrained( os.path.join(tmp_dir, "test_hf_and_dcp-hf") )