Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion open_instruct/ground_truth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class LMJudgeVerifierConfig(VerifierConfig):
class CodeVerifierConfig(VerifierConfig):
code_api_url: str
code_max_execution_time: float
code_pass_rate_reward_threshold: float


@dataclass
Expand Down Expand Up @@ -762,6 +763,7 @@ class CodeVerifier(VerifierFunction):

def __init__(self, verifier_config: CodeVerifierConfig) -> None:
super().__init__("code", verifier_config=verifier_config, weight=1.0)
self.pass_rate_reward_threshold = verifier_config.code_pass_rate_reward_threshold

def extract_python_code(self, model_output: str) -> str:
"""Extract the last code block between ``` markers from the model output."""
Expand Down Expand Up @@ -830,7 +832,8 @@ def make_request():
result = await asyncio.to_thread(make_request)
passes = result["results"]
pass_rate = sum(passes) / len(passes) if passes else 0.0
return VerificationResult(score=pass_rate)
score = 0.0 if pass_rate < self.pass_rate_reward_threshold else pass_rate
return VerificationResult(score=score)
except Exception as e:
logger.warning(f"Error verifying code sample: {e}")
return VerificationResult(score=0.0)
Expand Down
2 changes: 2 additions & 0 deletions open_instruct/grpo_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ class Args:
"""the api url to use for the code verifier"""
code_max_execution_time: float = 1.0
"""the max execution time to use for the code verifier"""
code_pass_rate_reward_threshold: float = 0.0
"""the pass rate reward threshold for the code verifier. If pass rate is less than this threshold, reward is 0.0, otherwise reward is pass rate"""

# -- non stop penalty
non_stop_penalty: bool = False
Expand Down
2 changes: 2 additions & 0 deletions open_instruct/grpo_vllm_thread_ray_gtrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,8 @@ class Args:
"""the api url to use for the code verifier"""
code_max_execution_time: float = 1.0
"""the max execution time to use for the code verifier"""
code_pass_rate_reward_threshold: float = 0.0
"""the pass rate reward threshold for the code verifier. If pass rate is less than this threshold, reward is 0.0, otherwise reward is pass rate"""

def __post_init__(self):
assert self.number_samples_per_prompt > 0, "Number of samples per prompt must be greater than 0!"
Expand Down
2 changes: 2 additions & 0 deletions open_instruct/ppo_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ class Args:
"""the api url to use for the code verifier"""
code_max_execution_time: float = 1.0
"""the max execution time to use for the code verifier"""
code_pass_rate_reward_threshold: float = 0.0
"""the pass rate reward threshold for the code verifier. If pass rate is less than this threshold, reward is 0.0, otherwise reward is pass rate"""

# -- non stop penalty
non_stop_penalty: bool = False
Expand Down
2 changes: 2 additions & 0 deletions open_instruct/ppo_vllm_thread_ray_gtrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ class Args:
"""the api url to use for the code verifier"""
code_max_execution_time: float = 1.0
"""the max execution time to use for the code verifier"""
code_pass_rate_reward_threshold: float = 0.0
"""the pass rate reward threshold for the code verifier. If pass rate is less than this threshold, reward is 0.0, otherwise reward is pass rate"""


def process_dataset_mixer(value) -> Tuple[Optional[dict], Optional[str]]:
Expand Down