433 changes: 412 additions & 21 deletions open_instruct/dataset_transformation.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion open_instruct/dpo_tune_cache.py
@@ -988,7 +988,7 @@ def load_model():
accelerator.wait_for_everyone()

if args.output_dir is not None:
save_with_accelerate(accelerator, model, tokenizer, args.output_dir, args.use_lora)
save_with_accelerate(accelerator, model, tokenizer, args.output_dir, args.use_lora, chat_template_name=tc.chat_template_name)

# remove all checkpoints to save space
if accelerator.is_local_main_process:
2 changes: 1 addition & 1 deletion open_instruct/finetune.py
@@ -974,7 +974,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
accelerator.wait_for_everyone()

if args.output_dir is not None:
save_with_accelerate(accelerator, model, tokenizer, args.output_dir, args.use_lora)
save_with_accelerate(accelerator, model, tokenizer, args.output_dir, args.use_lora, chat_template_name=tc.chat_template_name)

# remove all checkpoints to save space
if args.clean_checkpoints_at_end and accelerator.is_local_main_process:
34 changes: 29 additions & 5 deletions open_instruct/grpo_fast.py
@@ -94,6 +94,7 @@
apply_verifiable_reward,
disable_dropout_in_model,
entropy_from_logits,
get_olmo3_generation_config,
log_softmax_and_gather,
print_rich_single_line_metrics,
print_rich_table,
@@ -983,8 +984,12 @@ def save_checkpoint_state(self, checkpoint_state_dir: str, client_state: Dict[st
checkpoint_state_dir, args.gs_checkpoint_state_dir
)

def save_model(self, output_dir: str) -> None:
def save_model(self, output_dir: str, chat_template_name: str, tokenizer: PreTrainedTokenizer) -> None:
model_to_save = self.model
if "olmo" in chat_template_name:
# New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|>
model_to_save.generation_config = get_olmo3_generation_config(tokenizer)

if self.rank == 0:
os.makedirs(output_dir, exist_ok=True)

@@ -1774,6 +1779,7 @@ def one_training_step(
train_dataset,
writer,
wandb_url,
chat_template_name,
):
"""Train the model for one step."""
update_ref_policy_future = []
@@ -1820,7 +1826,12 @@ def one_training_step(
checkpoint_dir = f"{args.output_dir}_checkpoints"
step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
logger.info(f"Saving model at step {training_step} to {step_dir}")
ray.get([policy_group.models[i].save_model.remote(step_dir) for i in range(args.world_size)])
ray.get(
[
policy_group.models[i].save_model.remote(step_dir, chat_template_name, tokenizer)
for i in range(args.world_size)
]
)
if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job():
leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}"
for i in range(args.world_size):
@@ -1917,11 +1928,23 @@ def maybe_evaluate(
logger.warning("[Main Thread] 🙈 Evaluation responses not received")


def save_final_model(args: Args, policy_group: ModelGroup, training_step: int, wandb_url: str):
def save_final_model(
args: Args,
policy_group: ModelGroup,
tokenizer: PreTrainedTokenizer,
training_step: int,
wandb_url: str,
chat_template_name: str,
):
"""Save the final model and launch evaluation jobs if configured."""
logger.info(f"Saving final model at step {training_step} to {args.output_dir}")
with Timer("[Main Thread] 🗡️ Saving model"):
ray.get([policy_group.models[i].save_model.remote(args.output_dir) for i in range(args.world_size)])
ray.get(
[
policy_group.models[i].save_model.remote(args.output_dir, chat_template_name, tokenizer)
for i in range(args.world_size)
]
)
if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job():
leaderboard_name = args.hf_repo_revision
for i in range(args.world_size):
@@ -2189,6 +2212,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
train_dataset,
writer,
wandb_url,
tc.chat_template_name,
)

maybe_evaluate(
@@ -2204,7 +2228,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
writer,
)

save_final_model(args, policy_group, training_step, wandb_url)
save_final_model(args, policy_group, tokenizer, training_step, wandb_url, tc.chat_template_name)

except Exception as e:
logger.error(f"Training error occurred: {str(e)}\n{traceback.format_exc()}")
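A quick reviewer sketch of the new save path in grpo_fast.py: the driver now forwards `tc.chat_template_name` and the tokenizer to every worker's `save_model` through Ray. The actor class, worker count, and paths below are simplified stand-ins for illustration, not the PR's real classes.

```python
# Reviewer sketch only: simplified stand-ins for the fan-out save pattern that
# grpo_fast.py now uses. RolloutWorker is NOT the PR's actual actor class.
import ray

ray.init(ignore_reinit_error=True)


@ray.remote
class RolloutWorker:
    def save_model(self, output_dir: str, chat_template_name: str, tokenizer) -> None:
        # In the PR, an "olmo" template name swaps in get_olmo3_generation_config(tokenizer)
        # before the weights are written out.
        print(f"rank save -> {output_dir} (template: {chat_template_name})")


workers = [RolloutWorker.remote() for _ in range(2)]  # stands in for policy_group.models
chat_template_name = "olmo"                           # would come from tc.chat_template_name
tokenizer = None                                      # stand-in for the real PreTrainedTokenizer

# Same blocking pattern as one_training_step / save_final_model: wait until every rank has saved.
ray.get([w.save_model.remote("/tmp/step_10", chat_template_name, tokenizer) for w in workers])
```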
14 changes: 11 additions & 3 deletions open_instruct/grpo_vllm_thread_ray_gtrl.py
@@ -102,6 +102,7 @@
disable_dropout_in_model,
exact_div,
first_true_indices,
get_olmo3_generation_config,
get_reward,
log_softmax_and_gather,
print_rich_single_line_metrics,
@@ -791,6 +792,7 @@ def train(
train_dataset: Dataset,
eval_dataset: Dataset,
tokenizer: PreTrainedTokenizer,
tc: TokenizerConfig,
vllm_engines: List[ray.actor.ActorHandle],
metrics_queue: RayQueue,
data_collator: Callable,
@@ -1378,7 +1380,7 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
checkpoint_dir = f"{args.output_dir}_checkpoints"
step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
print(f"Saving model at step {training_step} to {step_dir}")
self.save_model(self.model, step_dir)
self.save_model(self.model, tc.chat_template_name, tokenizer, step_dir)
if args.try_launch_beaker_eval_jobs_on_weka:
leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}"
if self.rank == 0 and is_beaker_job():
@@ -1404,7 +1406,7 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
print(f"Eval future {eval_futures[0]} is done")
eval_futures.popleft()
print(f"Saving final model at step {training_step} to {args.output_dir}")
self.save_model(self.model, args.output_dir)
self.save_model(self.model, tc.chat_template_name, tokenizer, args.output_dir)
if args.try_launch_beaker_eval_jobs_on_weka:
leaderboard_name = args.hf_repo_revision
if self.rank == 0 and is_beaker_job():
@@ -1438,14 +1440,20 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True)
print("finished training")

def save_model(self, model_to_save: PreTrainedModel, output_dir: str) -> None:
def save_model(
self, model_to_save: PreTrainedModel, chat_template_name: str, tokenizer: PreTrainedTokenizer, output_dir: str
) -> None:
if self.rank == 0:
os.makedirs(output_dir, exist_ok=True)

# save model weights for ZeRO2/3
if hasattr(model_to_save, "module"):
model_to_save = model_to_save.module

if "olmo" in chat_template_name:
# New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|>
model_to_save.generation_config = get_olmo3_generation_config(tokenizer)

# gather parameters
output_state_dict = {}
for k, v in model_to_save.named_parameters():
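The `save_model` bodies above stop at the `# gather parameters` loop because the diff view truncates them. As background, the usual DeepSpeed idiom for materializing ZeRO-3 sharded weights on one rank looks roughly like the sketch below; the helper name and the `ds_id` check are assumptions for illustration and are not taken from this PR.

```python
# Background sketch, not code from this PR: collect the full (un-sharded) state
# dict on rank 0 before writing a checkpoint under ZeRO-2/3.
import deepspeed
import torch


def gather_full_state_dict(model: torch.nn.Module, rank: int) -> dict:
    state = {}
    for name, param in model.named_parameters():
        # GatheredParameters temporarily all-gathers a sharded weight so the full
        # tensor is visible inside the context manager. The ds_id attribute check
        # is an assumption about how ZeRO-3 partitioned params are detected.
        with deepspeed.zero.GatheredParameters([param], enabled=hasattr(param, "ds_id")):
            if rank == 0:
                state[name] = param.data.detach().cpu().clone()
    return state
```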
20 changes: 17 additions & 3 deletions open_instruct/model_utils.py
@@ -403,21 +403,35 @@ def batch_generation(
return torch.cat(query_responses, 0), torch.cat(logitss, 0)


def get_olmo3_generation_config(tokenizer):
return transformers.GenerationConfig(
temperature=None,
top_p=None,
eos_token_id=[tokenizer.convert_tokens_to_ids("<|im_end|>"), tokenizer.convert_tokens_to_ids("<|endoftext|>")],
)


def save_with_accelerate(
accelerator: Accelerator,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizer,
output_dir: str,
use_lora: bool = False,
model_attribute_to_save: Optional[str] = None,
chat_template_name: str = "tulu",
) -> None:
"""`model_attribute_to_save` is used to save PPO's policy instead of the full model"""
# set the generation config to an empty setting to be safe.
# we usually do greedy decoding for generation, so this should be okay.
# otherwise, we get an error thrown at save time.
model.generation_config = transformers.GenerationConfig(
temperature=None, top_p=None, eos_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id
)
if "olmo" in chat_template_name:
# New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|>
logger.info(f"Detected olmo chat template: {chat_template_name}, updating model generation config.")
model.generation_config = get_olmo3_generation_config(tokenizer)
else:
model.generation_config = transformers.GenerationConfig(
temperature=None, top_p=None, eos_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id
)

unwrapped_model: PreTrainedModel = accelerator.unwrap_model(model)
if model_attribute_to_save is not None:
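For reviewers who want to see what the `olmo` branch bakes into a checkpoint, the sketch below builds the equivalent `GenerationConfig` by hand and writes it out. The token ids and output path are placeholders (the PR resolves the real ids via `tokenizer.convert_tokens_to_ids`), so this illustrates the intent rather than reproducing code from the PR.

```python
# Illustration only: the generation config that the "olmo" branch attaches to
# the model before saving. Token ids and the output path are placeholders.
import transformers

IM_END_ID = 100257     # placeholder id for <|im_end|>
ENDOFTEXT_ID = 100256  # placeholder id for <|endoftext|>

gen_config = transformers.GenerationConfig(
    temperature=None,                        # keep sampling defaults out of the checkpoint
    top_p=None,
    eos_token_id=[IM_END_ID, ENDOFTEXT_ID],  # two stop tokens; no bos_token_id is set
)
gen_config.save_pretrained("/tmp/example_checkpoint")  # writes generation_config.json
```

With two ids in `eos_token_id`, downstream generation stops on whichever of `<|im_end|>` or `<|endoftext|>` appears first, which matches the comment repeated throughout the PR about the new chat template.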
21 changes: 18 additions & 3 deletions open_instruct/ppo_fast.py
@@ -101,6 +101,7 @@
apply_verifiable_reward,
disable_dropout_in_model,
entropy_from_logits,
get_olmo3_generation_config,
log_softmax_and_gather,
print_rich_single_line_metrics,
print_rich_table,
@@ -1074,7 +1075,7 @@ def train(
self.offload_to_cpu(self.model)
return metrics_list

def save_model(self, output_dir: str) -> None:
def save_model(self, output_dir: str, chat_template_name: str, tokenizer: PreTrainedTokenizer) -> None:
model_to_save = self.model
if self.rank == 0:
os.makedirs(output_dir, exist_ok=True)
@@ -1083,6 +1084,10 @@ def save_model(self, output_dir: str) -> None:
if hasattr(model_to_save, "module"):
model_to_save = model_to_save.module

if "olmo" in chat_template_name:
# New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|>
model_to_save.generation_config = get_olmo3_generation_config(tokenizer)

# gather parameters
output_state_dict = {}
for k, v in model_to_save.named_parameters():
@@ -1819,7 +1824,12 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn:
checkpoint_dir = f"{args.output_dir}_checkpoints"
step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
print(f"Saving model at step {training_step} to {step_dir}")
ray.get([policy_group.models[i].save_model.remote(step_dir) for i in range(args.world_size)])
ray.get(
[
policy_group.models[i].save_model.remote(step_dir, tc.chat_template_name, tokenizer)
for i in range(args.world_size)
]
)
if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job():
leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}"
for i in range(args.world_size):
@@ -1889,7 +1899,12 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn:

print(f"Saving final model at step {training_step} to {args.output_dir}")
with Timer("[Main Thread] 🗡️ Saving model"):
ray.get([policy_group.models[i].save_model.remote(args.output_dir) for i in range(args.world_size)])
ray.get(
[
policy_group.models[i].save_model.remote(args.output_dir, tc.chat_template_name, tokenizer)
for i in range(args.world_size)
]
)
if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job():
leaderboard_name = args.hf_repo_revision
for i in range(args.world_size):
9 changes: 8 additions & 1 deletion open_instruct/ppo_vllm_thread_ray_gtrl.py
@@ -100,6 +100,7 @@
disable_dropout_in_model,
exact_div,
first_true_indices,
get_olmo3_generation_config,
get_reward,
log_softmax_and_gather,
print_rich_single_line_metrics,
@@ -1513,14 +1514,20 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True)
print("finished training")

def save_model(self, model_to_save: PreTrainedModel, output_dir: str) -> None:
def save_model(
self, model_to_save: PreTrainedModel, chat_template_name: str, tokenizer: PreTrainedTokenizer, output_dir: str
) -> None:
if self.rank == 0:
os.makedirs(output_dir, exist_ok=True)

# save model weights for ZeRO2/3
if hasattr(model_to_save, "module"):
model_to_save = model_to_save.module

if "olmo" in chat_template_name:
# New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|>
model_to_save.generation_config = get_olmo3_generation_config(tokenizer)

# gather parameters
output_state_dict = {}
for k, v in model_to_save.named_parameters():
2 changes: 1 addition & 1 deletion open_instruct/reward_modeling.py
@@ -424,7 +424,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):

# save model
os.makedirs(os.path.dirname(args.output_dir), exist_ok=True)
save_with_accelerate(accelerator, model, tokenizer, args.output_dir)
save_with_accelerate(accelerator, model, tokenizer, args.output_dir, chat_template_name=tc.chat_template_name)
if args.push_to_hub:
push_folder_to_hub(accelerator, args.output_dir, args.hf_repo_id, args.hf_repo_revision)
