20 changes: 11 additions & 9 deletions open_instruct/grpo_fast.py
@@ -200,10 +200,10 @@ class Args:
num_training_steps: Optional[int] = None
"""RUNTIME VALUE: The number of training_steps to train"""
num_evals: int = 10
"""The number of evaluations to run throughout training"""
eval_freq: Optional[int] = None
"""RUNTIME VALUE: The frequency of evaluation steps"""
save_freq: int = -1
"""this sets how many in-loop evals we do during training. in-loop evals reuse the generation/reward verifier setup."""
local_eval_freq: Optional[int] = None
"""this controls the number of in-loop evals, which reuses the generation/reward verifier setup. don't set this directly, but set via num_evals."""
save_freq: int = 200
"""How many train steps to save the model"""
allow_world_padding: bool = False
"""Whether to allow world padding. This is useful for model sweeps, but wastes compute."""
@@ -1102,7 +1102,7 @@ def vllm_generate_thread(
num_training_steps: int,
eval_prompt_token_ids: Optional[List[int]],
evaluation_inference_results_Q: Queue,
-eval_freq: int,
+local_eval_freq: int,
resume_training_step: int = 1,
tool_use: bool = False,
):
@@ -1164,7 +1164,7 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
inference_results_Q.put((response_ids, finish_reasons, masks, info))

# Evaluate the model
-if eval_prompt_token_ids is not None and (training_step - 1) % eval_freq == 0:
+if eval_prompt_token_ids is not None and (training_step - 1) % local_eval_freq == 0:
response_ids, finish_reasons, masks, info = generate_with_engines(
eval_prompt_token_ids, eval_generation_config
)
@@ -1450,7 +1450,9 @@ def setup_runtime_variables(args: Args) -> Args:
args.num_training_steps = args.total_episodes // (
args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
)
-args.eval_freq = max(1, args.num_training_steps // args.num_evals)
+if args.local_eval_freq is not None:
+    raise ValueError("local_eval_freq should not be set manually; it will be computed automatically")
+args.local_eval_freq = max(1, args.num_training_steps // args.num_evals)
args.try_launch_beaker_eval_jobs_on_weka = args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job()
if args.push_to_hub:
if args.hf_repo_id is None: # auto-generate one
@@ -1778,7 +1780,7 @@ def maybe_evaluate(
try:
# timeout 0.01 if this is the last training step or we're not evaluating
# otherwise, wait to get the last evaluation generations (long timeout just in case)
-timeout = 0.01 if (training_step < args.num_training_steps or args.eval_freq < 0) else 100
+timeout = 0.01 if (training_step < args.num_training_steps or args.local_eval_freq < 0) else 100
eval_responses, eval_finish_reasons, masks, eval_infos = evaluation_inference_results_Q.get(timeout=timeout)
logger.info("[Main Thread] 📊 Evaluation responses received")

@@ -2010,7 +2012,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
args.num_training_steps,
eval_prompt_token_ids,
evaluation_inference_results_Q,
-args.eval_freq,
+args.local_eval_freq,
resume_training_step,
args.tool_use,
),
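For reference, here is a minimal standalone sketch of the eval scheduling these trainers use: num_evals is the user-facing knob, and the runtime value local_eval_freq is derived from it. The helper names below are illustrative only, not part of open_instruct; the formula and the trigger condition mirror the diff above.

def compute_local_eval_freq(num_training_steps: int, num_evals: int) -> int:
    # at most num_evals in-loop evals, spaced at least one training step apart
    return max(1, num_training_steps // num_evals)

def eval_steps(num_training_steps: int, num_evals: int) -> list[int]:
    # training steps (1-indexed) at which the in-loop eval fires,
    # mirroring the (training_step - 1) % local_eval_freq == 0 check above
    freq = compute_local_eval_freq(num_training_steps, num_evals)
    return [step for step in range(1, num_training_steps + 1) if (step - 1) % freq == 0]

# e.g. 100 training steps with num_evals=10 -> local_eval_freq=10,
# so in-loop evals run at steps 1, 11, 21, ..., 91
print(eval_steps(100, 10))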
18 changes: 10 additions & 8 deletions open_instruct/grpo_vllm_thread_ray_gtrl.py
@@ -198,12 +198,12 @@ class Args:
num_training_steps: Optional[int] = None
"""The number of training_steps to train"""
num_evals: int = 4
"""The number of evaluations to run throughout training"""
eval_freq: Optional[int] = None
"""The frequency of evaluation steps"""
"""this sets how many in-loop evals we do during training. in-loop evals reuse the generation/reward verifier setup."""
local_eval_freq: Optional[int] = None
"""this controls the number of in-loop evals, which reuses the generation/reward verifier setup. don't set this directly, but set via num_evals."""
local_dataloader_batch_size: Optional[int] = None
"""The batch size per GPU for the dataloader"""
-save_freq: int = -1
+save_freq: int = 200
"""How many train steps to save the model"""

# online settings
@@ -928,7 +928,7 @@ def vllm_generate(
num_training_steps: int,
sample_evaluation_prompt_token_ids: Optional[List[int]],
evaluation_Q: Queue,
-eval_freq: int,
+local_eval_freq: int,
resume_training_step: int,
):
def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingParams):
@@ -962,7 +962,7 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
response_ids_Q.put(response_ids)

# Evaluate the model
-if sample_evaluation_prompt_token_ids is not None and (training_step - 1) % eval_freq == 0:
+if sample_evaluation_prompt_token_ids is not None and (training_step - 1) % local_eval_freq == 0:
response_ids = generate_with_engines(
sample_evaluation_prompt_token_ids, evaluation_generation_config
)
@@ ... @@
args.num_training_steps,
sample_evaluation_prompt_token_ids,
evaluation_Q,
-args.eval_freq,
+args.local_eval_freq,
resume_training_step,
),
)
@@ -1596,7 +1596,9 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
args.mini_batch_size = int(args.local_mini_batch_size * args.world_size)
args.num_mini_batches = exact_div((args.rollout_batch_size * args.number_samples_per_prompt), args.mini_batch_size)
args.num_training_steps = args.total_episodes // (args.rollout_batch_size * args.number_samples_per_prompt)
-args.eval_freq = max(1, args.num_training_steps // args.num_evals)
+if args.local_eval_freq is not None:
+    raise ValueError("local_eval_freq should not be set manually; it will be computed automatically")
+args.local_eval_freq = max(1, args.num_training_steps // args.num_evals)
# PPO logic: do checks and set up dataloader batch size
if args.whiten_rewards:
assert args.local_mini_batch_size >= 8, (
20 changes: 11 additions & 9 deletions open_instruct/ppo_fast.py
@@ -199,10 +199,10 @@ class Args:
num_training_steps: Optional[int] = None
"""RUNTIME VALUE: The number of training_steps to train"""
num_evals: int = 10
"""The number of evaluations to run throughout training"""
eval_freq: Optional[int] = None
"""RUNTIME VALUE: The frequency of evaluation steps"""
save_freq: int = -1
"""this sets how many in-loop evals we do during training. in-loop evals reuse the generation/reward verifier setup."""
local_eval_freq: Optional[int] = None
"""this controls the number of in-loop evals, which reuses the generation/reward verifier setup. don't set this directly, but set via num_evals."""
save_freq: int = 200
"""How many train steps to save the model"""

# Generation
@@ -1201,7 +1201,7 @@ def vllm_generate_thread(
num_training_steps: int,
eval_prompt_token_ids: Optional[List[int]],
evaluation_inference_results_Q: Queue,
-eval_freq: int,
+local_eval_freq: int,
resume_training_step: int = 1,
):
def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingParams):
@@ -1262,7 +1262,7 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
inference_results_Q.put((response_ids, finish_reasons, masks, info))

# Evaluate the model
-if eval_prompt_token_ids is not None and (training_step - 1) % eval_freq == 0:
+if eval_prompt_token_ids is not None and (training_step - 1) % local_eval_freq == 0:
response_ids, finish_reasons, masks, info = generate_with_engines(
eval_prompt_token_ids, eval_generation_config
)
@@ -1486,7 +1486,9 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn:
args.num_training_steps = args.total_episodes // (
args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout
)
-args.eval_freq = max(1, args.num_training_steps // args.num_evals)
+if args.local_eval_freq is not None:
+    raise ValueError("local_eval_freq should not be set manually; it will be computed automatically")
+args.local_eval_freq = max(1, args.num_training_steps // args.num_evals)
args.try_launch_beaker_eval_jobs_on_weka = args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job()
if args.push_to_hub:
if args.hf_repo_id is None: # auto-generate one
@@ -1686,7 +1688,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn:
args.num_training_steps,
eval_prompt_token_ids,
evaluation_inference_results_Q,
-args.eval_freq,
+args.local_eval_freq,
resume_training_step,
),
)
@@ -1830,7 +1832,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn:
try:
# timeout 0.01 if this is the last training step or we're not evaluating
# otherwise, wait to get the last evaluation generations (long timeout just in case)
-timeout = 0.01 if (training_step < args.num_training_steps or args.eval_freq < 0) else 100
+timeout = 0.01 if (training_step < args.num_training_steps or args.local_eval_freq < 0) else 100
eval_responses, eval_finish_reasons, masks, eval_infos = evaluation_inference_results_Q.get(
timeout=timeout
)
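On the consumer side of these eval generations, here is a sketch of the queue pickup pattern used in maybe_evaluate (in both ppo_fast.py and grpo_fast.py): during training the queue is polled almost non-blockingly, and only on the final step, when in-loop evals are enabled, does it wait for the last evaluation generations. The assumption that the surrounding try/except catches queue.Empty is mine; that code is outside this diff.

from queue import Empty, Queue

def try_get_eval_results(q: Queue, training_step: int, num_training_steps: int, local_eval_freq: int):
    # poll almost non-blockingly mid-training; block up to 100s only on the
    # last training step (and only if in-loop evals are enabled)
    timeout = 0.01 if (training_step < num_training_steps or local_eval_freq < 0) else 100
    try:
        return q.get(timeout=timeout)
    except Empty:
        return None  # no eval results ready this step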
18 changes: 10 additions & 8 deletions open_instruct/ppo_vllm_thread_ray_gtrl.py
@@ -195,12 +195,12 @@ class Args:
num_training_steps: Optional[int] = None
"""The number of training_steps to train"""
num_evals: int = 4
"""The number of evaluations to run throughout training"""
eval_freq: Optional[int] = None
"""The frequency of evaluation steps"""
"""this sets how many in-loop evals we do during training. in-loop evals reuse the generation/reward verifier setup."""
local_eval_freq: Optional[int] = None
"""this controls the number of in-loop evals, which reuses the generation/reward verifier setup. don't set this directly, but set via num_evals."""
local_dataloader_batch_size: Optional[int] = None
"""The batch size per GPU for the dataloader"""
-save_freq: int = -1
+save_freq: int = 200
"""How many train steps to save the model"""

# online settings
@@ -971,7 +971,7 @@ def vllm_generate(
num_training_steps: int,
sample_evaluation_prompt_token_ids: Optional[List[int]],
evaluation_Q: Queue,
-eval_freq: int,
+local_eval_freq: int,
resume_training_step: int,
):
def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingParams):
@@ -1005,7 +1005,7 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
response_ids_Q.put(response_ids)

# Evaluate the model
-if sample_evaluation_prompt_token_ids is not None and (training_step - 1) % eval_freq == 0:
+if sample_evaluation_prompt_token_ids is not None and (training_step - 1) % local_eval_freq == 0:
response_ids = generate_with_engines(
sample_evaluation_prompt_token_ids, evaluation_generation_config
)
@@ ... @@
args.num_training_steps,
sample_evaluation_prompt_token_ids,
evaluation_Q,
-args.eval_freq,
+args.local_eval_freq,
resume_training_step,
),
)
@@ -1671,7 +1671,9 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
args.mini_batch_size = int(args.local_mini_batch_size * args.world_size)
args.num_mini_batches = exact_div((args.rollout_batch_size * args.number_samples_per_prompt), args.mini_batch_size)
args.num_training_steps = args.total_episodes // (args.rollout_batch_size * args.number_samples_per_prompt)
-args.eval_freq = max(1, args.num_training_steps // args.num_evals)
+if args.local_eval_freq is not None:
+    raise ValueError("local_eval_freq should not be set manually; it will be computed automatically")
+args.local_eval_freq = max(1, args.num_training_steps // args.num_evals)
# PPO logic: do checks and set up dataloader batch size
if args.whiten_rewards:
assert args.local_mini_batch_size >= 8, (
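Finally, a quick illustration of the new guard added to the runtime setup in all four trainers: local_eval_freq now fails fast if a user passes it explicitly, and num_evals stays the knob that controls how many in-loop evals happen. The ArgsSketch dataclass below is a simplified stand-in for the real Args classes.

from dataclasses import dataclass
from typing import Optional

@dataclass
class ArgsSketch:
    num_evals: int = 10
    local_eval_freq: Optional[int] = None
    num_training_steps: Optional[int] = None

def setup_runtime_variables_sketch(args: ArgsSketch, num_training_steps: int) -> ArgsSketch:
    # mirrors the guard and derivation added in the hunks above
    args.num_training_steps = num_training_steps
    if args.local_eval_freq is not None:
        raise ValueError("local_eval_freq should not be set manually; it will be computed automatically")
    args.local_eval_freq = max(1, args.num_training_steps // args.num_evals)
    return args

setup_runtime_variables_sketch(ArgsSketch(num_evals=4), num_training_steps=200)  # ok: local_eval_freq == 50
# setup_runtime_variables_sketch(ArgsSketch(local_eval_freq=5), num_training_steps=200)  # raises ValueError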