diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index 1c6e29da25..4782ed58ae 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -700,6 +700,11 @@ def main(): raise ValueError("Must provide model_name_or_path to load a pretrained CausalLM model.") if model.config.model_type == "llama": + if model.generation_config.pad_token_id is None: + if isinstance(model.generation_config.eos_token_id, int): + model.generation_config.pad_token_id = model.generation_config.eos_token_id + elif isinstance(model.generation_config.eos_token_id, list): + model.generation_config.pad_token_id = model.generation_config.eos_token_id[0] if model_args.attn_softmax_bf16: model.generation_config.attn_softmax_bf16 = True if model_args.use_flash_attention: @@ -717,7 +722,10 @@ def main(): if hasattr(model.generation_config, "pad_token_id") and model.generation_config.pad_token_id is not None: tokenizer.pad_token_id = model.generation_config.pad_token_id if hasattr(model.generation_config, "eos_token_id") and model.generation_config.eos_token_id is not None: - tokenizer.eos_token_id = model.generation_config.eos_token_id + if isinstance(model.generation_config.eos_token_id, int): + tokenizer.eos_token_id = model.generation_config.eos_token_id + elif isinstance(model.generation_config.eos_token_id, list): + tokenizer.eos_token_id = model.generation_config.eos_token_id[0] if hasattr(model.generation_config, "bos_token_id") and model.generation_config.bos_token_id is not None: tokenizer.bos_token_id = model.generation_config.bos_token_id