diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v2.yaml similarity index 100% rename from examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.yaml rename to examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v2.yaml diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml similarity index 100% rename from examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.yaml rename to examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.yaml diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v2.yaml similarity index 100% rename from examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.yaml rename to examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v2.yaml diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v2.yaml similarity index 100% rename from examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.yaml rename to examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v2.yaml diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v2.yaml similarity index 100% rename from examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.yaml rename to examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v2.yaml diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v2.yaml similarity index 100% rename from examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.yaml rename to examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v2.yaml diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v2.yaml similarity index 100% rename from examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml rename to examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v2.yaml diff --git a/nemo_rl/environments/math_environment.py b/nemo_rl/environments/math_environment.py index 8da0528652..fd968298b0 100644 --- a/nemo_rl/environments/math_environment.py +++ b/nemo_rl/environments/math_environment.py @@ -15,7 +15,8 @@ import ray import torch -from math_verify import parse, verify +from math_verify.metric import math_metric +from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES @@ -53,9 +54,23 @@ def verify( results = [] for response, ground_truth in zip(pred_responses, ground_truths): try: - gold = parse(ground_truth) - pred = parse(response[-100:]) # avoid looking at the whole string - results.append(float(verify(gold, pred))) + # Use Latex and plain math extraction from predictions + # https://github.com/huggingface/Math-Verify?tab=readme-ov-file#extraction-targets + verify_func = math_metric( + gold_extraction_target=(LatexExtractionConfig(),), + pred_extraction_target=( + ExprExtractionConfig(), + LatexExtractionConfig(), + ), + ) + + ground_truth_parsable = "\\boxed{" + ground_truth + "}" + try: + ret_score, _ = verify_func([ground_truth_parsable], [response]) + except Exception: + ret_score = 0.0 + + results.append(float(ret_score)) except Exception: results.append(0) return results diff --git a/tests/test_suites/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.sh b/tests/test_suites/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v2.sh similarity index 100% rename from tests/test_suites/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.sh rename to tests/test_suites/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v2.sh diff --git a/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.sh b/tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.sh similarity index 100% rename from tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.sh rename to tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.sh diff --git a/tests/test_suites/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.sh b/tests/test_suites/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v2.sh similarity index 100% rename from tests/test_suites/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.sh rename to tests/test_suites/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v2.sh diff --git a/tests/test_suites/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.sh b/tests/test_suites/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v2.sh similarity index 100% rename from tests/test_suites/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.sh rename to tests/test_suites/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v2.sh diff --git a/tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.sh b/tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v2.sh similarity index 100% rename from tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.sh rename to tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v2.sh diff --git a/tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.sh b/tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v2.sh similarity index 100% rename from tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.sh rename to tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v2.sh diff --git a/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh b/tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v2.sh similarity index 100% rename from tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh rename to tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v2.sh diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 4c609d5bff..b80a7ad545 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -3,15 +3,15 @@ ######## # Short 1N/1B runs (go past 200 steps - usually divergence happens by now) -- going to 4 nodes doesn't help that much -tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh -tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.sh +tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v2.sh +tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.sh # FSDP1 vs Dtensor (Qwen/Qwen2.5-7B-Instruct) -tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.sh -tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.sh +tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v2.sh +tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v2.sh # Functional 32b run -tests/test_suites/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.sh +tests/test_suites/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v2.sh ####### # SFT # diff --git a/tests/test_suites/release.txt b/tests/test_suites/release.txt index 69735cb0cb..42e9c49d00 100644 --- a/tests/test_suites/release.txt +++ b/tests/test_suites/release.txt @@ -3,10 +3,10 @@ ######## # Long 8b run -tests/test_suites/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.sh +tests/test_suites/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v2.sh # Long 32b run -tests/test_suites/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.sh +tests/test_suites/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v2.sh ####### # SFT #