diff --git a/open_instruct/ground_truth_utils.py b/open_instruct/ground_truth_utils.py index 8b24efbbc8..90dfe3efd2 100644 --- a/open_instruct/ground_truth_utils.py +++ b/open_instruct/ground_truth_utils.py @@ -77,6 +77,7 @@ class LMJudgeVerifierConfig(VerifierConfig): class CodeVerifierConfig(VerifierConfig): code_api_url: str code_max_execution_time: float + code_pass_rate_reward_threshold: float @dataclass @@ -762,6 +763,7 @@ class CodeVerifier(VerifierFunction): def __init__(self, verifier_config: CodeVerifierConfig) -> None: super().__init__("code", verifier_config=verifier_config, weight=1.0) + self.pass_rate_reward_threshold = verifier_config.code_pass_rate_reward_threshold def extract_python_code(self, model_output: str) -> str: """Extract the last code block between ``` markers from the model output.""" @@ -830,7 +832,8 @@ def make_request(): result = await asyncio.to_thread(make_request) passes = result["results"] pass_rate = sum(passes) / len(passes) if passes else 0.0 - return VerificationResult(score=pass_rate) + score = 0.0 if pass_rate < self.pass_rate_reward_threshold else pass_rate + return VerificationResult(score=score) except Exception as e: logger.warning(f"Error verifying code sample: {e}") return VerificationResult(score=0.0) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 3c89bfab10..fe02a5bb29 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -300,6 +300,8 @@ class Args: """the api url to use for the code verifier""" code_max_execution_time: float = 1.0 """the max execution time to use for the code verifier""" + code_pass_rate_reward_threshold: float = 0.0 + """the pass rate reward threshold for the code verifier. If pass rate is less than this threshold, reward is 0.0, otherwise reward is pass rate""" # -- non stop penalty non_stop_penalty: bool = False diff --git a/open_instruct/grpo_vllm_thread_ray_gtrl.py b/open_instruct/grpo_vllm_thread_ray_gtrl.py index bde722ea96..be8121baca 100644 --- a/open_instruct/grpo_vllm_thread_ray_gtrl.py +++ b/open_instruct/grpo_vllm_thread_ray_gtrl.py @@ -350,6 +350,8 @@ class Args: """the api url to use for the code verifier""" code_max_execution_time: float = 1.0 """the max execution time to use for the code verifier""" + code_pass_rate_reward_threshold: float = 0.0 + """the pass rate reward threshold for the code verifier. If pass rate is less than this threshold, reward is 0.0, otherwise reward is pass rate""" def __post_init__(self): assert self.number_samples_per_prompt > 0, "Number of samples per prompt must be greater than 0!" diff --git a/open_instruct/ppo_fast.py b/open_instruct/ppo_fast.py index ef7443cd45..3ecb262e44 100644 --- a/open_instruct/ppo_fast.py +++ b/open_instruct/ppo_fast.py @@ -291,6 +291,8 @@ class Args: """the api url to use for the code verifier""" code_max_execution_time: float = 1.0 """the max execution time to use for the code verifier""" + code_pass_rate_reward_threshold: float = 0.0 + """the pass rate reward threshold for the code verifier. If pass rate is less than this threshold, reward is 0.0, otherwise reward is pass rate""" # -- non stop penalty non_stop_penalty: bool = False diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl.py b/open_instruct/ppo_vllm_thread_ray_gtrl.py index 283ab84d18..5e2dc91ac3 100644 --- a/open_instruct/ppo_vllm_thread_ray_gtrl.py +++ b/open_instruct/ppo_vllm_thread_ray_gtrl.py @@ -357,6 +357,8 @@ class Args: """the api url to use for the code verifier""" code_max_execution_time: float = 1.0 """the max execution time to use for the code verifier""" + code_pass_rate_reward_threshold: float = 0.0 + """the pass rate reward threshold for the code verifier. If pass rate is less than this threshold, reward is 0.0, otherwise reward is pass rate""" def process_dataset_mixer(value) -> Tuple[Optional[dict], Optional[str]]: