diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index 0b16be0725..1c6e29da25 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -700,10 +700,6 @@ def main(): raise ValueError("Must provide model_name_or_path to load a pretrained CausalLM model.") if model.config.model_type == "llama": - # unwind broken decapoda-research config - model.generation_config.pad_token_id = 0 - model.generation_config.bos_token_id = 1 - model.generation_config.eos_token_id = 2 if model_args.attn_softmax_bf16: model.generation_config.attn_softmax_bf16 = True if model_args.use_flash_attention: diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 2b0fe0d328..cb734071b0 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -525,16 +525,22 @@ def setup_tokenizer(args, model, assistant_model): tokenizer.padding_side = "left" if model.config.model_type == "llama": - # unwind broken decapoda-research config - model.generation_config.pad_token_id = 0 - model.generation_config.bos_token_id = 1 - model.generation_config.eos_token_id = 2 + if model.generation_config.pad_token_id is None: + if isinstance(model.generation_config.eos_token_id, int): + model.generation_config.pad_token_id = model.generation_config.eos_token_id + elif isinstance(model.generation_config.eos_token_id, list): + model.generation_config.pad_token_id = model.generation_config.eos_token_id[0] if assistant_model is not None: - assistant_model.generation_config.pad_token_id = 0 - assistant_model.generation_config.bos_token_id = 1 - assistant_model.generation_config.eos_token_id = 2 + if assistant_model.generation_config.pad_token_id is None: + if isinstance(assistant_model.generation_config.eos_token_id, int): + assistant_model.generation_config.pad_token_id = assistant_model.generation_config.eos_token_id + elif isinstance(assistant_model.generation_config.eos_token_id, list): + assistant_model.generation_config.pad_token_id = assistant_model.generation_config.eos_token_id[0] tokenizer.bos_token_id = model.generation_config.bos_token_id - tokenizer.eos_token_id = model.generation_config.eos_token_id + if isinstance(model.generation_config.eos_token_id, int): + tokenizer.eos_token_id = model.generation_config.eos_token_id + elif isinstance(model.generation_config.eos_token_id, list): + tokenizer.eos_token_id = model.generation_config.eos_token_id[0] tokenizer.pad_token_id = model.generation_config.pad_token_id tokenizer.pad_token = tokenizer.decode(tokenizer.pad_token_id) tokenizer.eos_token = tokenizer.decode(tokenizer.eos_token_id)