diff --git a/tests/build/test_launch_script.py b/tests/build/test_launch_script.py
index 824b7125c..421849b1f 100644
--- a/tests/build/test_launch_script.py
+++ b/tests/build/test_launch_script.py
@@ -61,7 +61,6 @@
"prompt_tuning_init": "RANDOM",
"num_virtual_tokens": 8,
"prompt_tuning_init_text": "hello",
- "tokenizer_name_or_path": MODEL_NAME,
"save_strategy": "epoch",
"output_dir": "tmp",
},
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index eb2ca855a..26067124a 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -176,11 +176,9 @@ def test_run_causallm_pt_and_inference():
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
- # tokenizer_name_or_path from model arguments is passed
- # while preparing the prompt tuning config which
- # defaults to model_name_or_path if not explicitly set.
+
_validate_adapter_config(
- adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path
+ adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path
)
# Load the model
@@ -214,11 +212,8 @@ def test_run_causallm_pt_and_inference_with_formatting_data():
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
- # tokenizer_name_or_path from model arguments is passed
- # while preparing the prompt tuning config which
- # defaults to model_name_or_path if not explicitly set.
_validate_adapter_config(
- adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path
+ adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path
)
# Load the model
@@ -250,11 +245,8 @@ def test_run_causallm_pt_and_inference_JSON_file_formatter():
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
- # tokenizer_name_or_path from model arguments is passed
- # while preparing the prompt tuning config which
- # defaults to model_name_or_path if not explicitly set.
_validate_adapter_config(
- adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path
+ adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path
)
# Load the model
@@ -285,11 +277,8 @@ def test_run_causallm_pt_init_text():
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
- # tokenizer_name_or_path from model arguments is passed
- # while preparing the prompt tuning config which
- # defaults to model_name_or_path if not explicitly set.
_validate_adapter_config(
- adapter_config, "PROMPT_TUNING", MODEL_ARGS.tokenizer_name_or_path
+ adapter_config, "PROMPT_TUNING", MODEL_ARGS.model_name_or_path
)
@@ -349,6 +338,20 @@ def test_run_causallm_pt_with_validation_data_formatting():
_validate_training(tempdir, check_eval=True)
+def test_run_causallm_pt_with_custom_tokenizer():
+ """Check that prompt tuning fails when a custom tokenizer without a pad token is used"""
+ with tempfile.TemporaryDirectory() as tempdir:
+ train_args = copy.deepcopy(TRAIN_ARGS)
+ model_args = copy.deepcopy(MODEL_ARGS)
+ model_args.tokenizer_name_or_path = model_args.model_name_or_path
+ train_args.output_dir = tempdir
+ train_args.eval_strategy = "epoch"
+ data_args = copy.deepcopy(DATA_ARGS)
+ data_args.validation_data_path = TWITTER_COMPLAINTS_DATA
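+ # With a custom tokenizer, special tokens (including the pad token) are not added, so training should raise a ValueError.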
+ with pytest.raises(ValueError):
+ sft_trainer.train(model_args, data_args, train_args, PEFT_PT_ARGS)
+
+
############################# Lora Tests #############################
target_modules_val_map = [
diff --git a/tuning/config/configs.py b/tuning/config/configs.py
index 92fb4f8f8..c08c90b12 100644
--- a/tuning/config/configs.py
+++ b/tuning/config/configs.py
@@ -51,15 +51,15 @@ class ModelArguments:
tokenizer_name_or_path: Optional[str] = field(
default=None,
metadata={
- "help": "Path to custom tokenizer.\
- If not provided it defaults to model_name_or_path"
+ "help": "Path to custom tokenizer. \
+ If not provided, it defaults to model_name_or_path \
+ and special tokens are added as needed for specific tokenizer classes. \
+ If tokenizer_name_or_path is provided, special tokens are not added; \
+ this applies to prompt tuning as well, where the same fallback \
+ to model_name_or_path is used."
},
)
- def __post_init__(self):
- if not self.tokenizer_name_or_path:
- self.tokenizer_name_or_path = self.model_name_or_path
-
@dataclass
class DataArguments:
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index d889c67e7..30e768ce4 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -190,32 +190,46 @@ def train(
# TODO: Move these to a config as well
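+ # Use the custom tokenizer when one is given; otherwise fall back to the model's own tokenizer.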
tokenizer = AutoTokenizer.from_pretrained(
- model_args.tokenizer_name_or_path, cache_dir=train_args.cache_dir, use_fast=True
+ (
+ model_args.tokenizer_name_or_path
+ if model_args.tokenizer_name_or_path
+ else model_args.model_name_or_path
+ ),
+ cache_dir=train_args.cache_dir,
+ use_fast=True,
)
# Calculate and save additional metrics to track later.
additional_metrics["model_load_time"] = time.time() - model_load_time
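+ # The prompt tuning config also records a tokenizer path; apply the same fallback to model_name_or_path.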
peft_config = get_hf_peft_config(
- task_type, peft_config, model_args.tokenizer_name_or_path
+ task_type,
+ peft_config,
+ (
+ model_args.tokenizer_name_or_path
+ if model_args.tokenizer_name_or_path
+ else model_args.model_name_or_path
+ ),
)
- # TODO: understand if we need to hardcode these here or just use defaults in model
- if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
- tokenizer.add_special_tokens(
- {
- "bos_token": "<s>",
- "eos_token": "</s>",
- "unk_token": "<unk>",
- "pad_token": "<pad>",
- }
- )
- elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
- tokenizer.add_special_tokens(
- {
- "pad_token": "<pad>",
- }
- )
+ # add special tokens only when a custom tokenizer is not passed
+ if not model_args.tokenizer_name_or_path:
+ # TODO: understand if we need to hardcode these here or just use defaults in model
+ if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
+ tokenizer.add_special_tokens(
+ {
+ "bos_token": "<s>",
+ "eos_token": "</s>",
+ "unk_token": "<unk>",
+ "pad_token": "<pad>",
+ }
+ )
+ elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
+ tokenizer.add_special_tokens(
+ {
+ "pad_token": "<pad>",
+ }
+ )
max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
logger.info("Max sequence length is %s", max_seq_length)
@@ -228,20 +242,22 @@ def train(
tokenizer.model_max_length,
)
- # TODO: we need to change this, perhaps follow what open instruct does?
+ # add special tokens only when a custom tokenizer is not passed
special_tokens_dict = {}
- if tokenizer.pad_token is None:
- logger.warning("PAD token set to default, missing in tokenizer")
- special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
- if tokenizer.eos_token is None:
- logger.warning("EOS token set to default, missing in tokenizer")
- special_tokens_dict["eos_token"] = configs.DEFAULT_EOS_TOKEN
- if tokenizer.bos_token is None:
- logger.warning("BOS token set to default, missing in tokenizer")
- special_tokens_dict["bos_token"] = configs.DEFAULT_BOS_TOKEN
- if tokenizer.unk_token is None:
- logger.warning("UNK token set to default, missing in tokenizer")
- special_tokens_dict["unk_token"] = configs.DEFAULT_UNK_TOKEN
+ if not model_args.tokenizer_name_or_path:
+ # TODO: we need to change this, perhaps follow what open instruct does?
+ if tokenizer.pad_token is None:
+ logger.warning("PAD token set to default, missing in tokenizer")
+ special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
+ if tokenizer.eos_token is None:
+ logger.warning("EOS token set to default, missing in tokenizer")
+ special_tokens_dict["eos_token"] = configs.DEFAULT_EOS_TOKEN
+ if tokenizer.bos_token is None:
+ logger.warning("BOS token set to default, missing in tokenizer")
+ special_tokens_dict["bos_token"] = configs.DEFAULT_BOS_TOKEN
+ if tokenizer.unk_token is None:
+ logger.warning("UNK token set to default, missing in tokenizer")
+ special_tokens_dict["unk_token"] = configs.DEFAULT_UNK_TOKEN
# TODO: lower priority but understand if resizing impacts inference quality and why its needed.
# It makes sense if we manipulate tokenizer that we also save it and provide it to inference.