Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion examples/language-modeling/run_lora_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,11 @@ def main():
raise ValueError("Must provide model_name_or_path to load a pretrained CausalLM model.")

if model.config.model_type == "llama":
if model.generation_config.pad_token_id is None:
if isinstance(model.generation_config.eos_token_id, int):
model.generation_config.pad_token_id = model.generation_config.eos_token_id
elif isinstance(model.generation_config.eos_token_id, list):
model.generation_config.pad_token_id = model.generation_config.eos_token_id[0]
if model_args.attn_softmax_bf16:
model.generation_config.attn_softmax_bf16 = True
if model_args.use_flash_attention:
Expand All @@ -717,7 +722,10 @@ def main():
if hasattr(model.generation_config, "pad_token_id") and model.generation_config.pad_token_id is not None:
tokenizer.pad_token_id = model.generation_config.pad_token_id
if hasattr(model.generation_config, "eos_token_id") and model.generation_config.eos_token_id is not None:
tokenizer.eos_token_id = model.generation_config.eos_token_id
if isinstance(model.generation_config.eos_token_id, int):
tokenizer.eos_token_id = model.generation_config.eos_token_id
elif isinstance(model.generation_config.eos_token_id, list):
tokenizer.eos_token_id = model.generation_config.eos_token_id[0]
if hasattr(model.generation_config, "bos_token_id") and model.generation_config.bos_token_id is not None:
tokenizer.bos_token_id = model.generation_config.bos_token_id

Expand Down