diff --git a/examples/configs/recipes/llm/grpo-acereason-math-7b-16K.yaml b/examples/configs/recipes/llm/grpo-acereason-math-7b-16K.yaml new file mode 100644 index 0000000000..b603ceab32 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-acereason-math-7b-16K.yaml @@ -0,0 +1,21 @@ +defaults: + - ../../grpo_math_1B.yaml + - grpo-acereason-math-7b-8K.yaml +policy: + max_total_sequence_length: 16384 + dtensor_cfg: + activation_checkpointing: true + context_parallel_size: 2 + dynamic_batching: + logprob_mb_tokens: 32768 + train_mb_tokens: 16384 + sequence_packing: + enabled: false + logprob_mb_tokens: 32768 + train_mb_tokens: 16384 + generation: + max_new_tokens: 16384 + vllm_cfg: + max_model_len: 16384 +data: + max_input_seq_length: 16384 diff --git a/examples/configs/recipes/llm/grpo-acereason-math-7b-24K.yaml b/examples/configs/recipes/llm/grpo-acereason-math-7b-24K.yaml new file mode 100644 index 0000000000..7162571eb1 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-acereason-math-7b-24K.yaml @@ -0,0 +1,23 @@ +defaults: + - ../../grpo_math_1B.yaml + - grpo-acereason-math-7b-16K.yaml +policy: + max_total_sequence_length: 24576 + dtensor_cfg: + activation_checkpointing: true + context_parallel_size: 8 + dynamic_batching: + logprob_mb_tokens: 49152 + train_mb_tokens: 24576 + sequence_packing: + enabled: false + logprob_mb_tokens: 49152 + train_mb_tokens: 24576 + optimizer: + kwargs: + lr: 5.0e-07 + generation: + max_new_tokens: 24576 +data: + max_input_seq_length: 24576 + diff --git a/examples/configs/recipes/llm/grpo-acereason-math-7b-32K.yaml b/examples/configs/recipes/llm/grpo-acereason-math-7b-32K.yaml new file mode 100644 index 0000000000..e0b739cbaf --- /dev/null +++ b/examples/configs/recipes/llm/grpo-acereason-math-7b-32K.yaml @@ -0,0 +1,26 @@ +defaults: + - ../../grpo_math_1B.yaml + - grpo-acereason-math-7b-16K.yaml +policy: + max_total_sequence_length: 32768 + logprob_batch_size: 2 + dtensor_cfg: + activation_checkpointing: true + context_parallel_size: 8 + dynamic_batching: + logprob_mb_tokens: 65536 + train_mb_tokens: 32768 + sequence_packing: + enabled: false + logprob_mb_tokens: 65536 + train_mb_tokens: 32768 + optimizer: + kwargs: + lr: 5.0e-07 + generation: + max_new_tokens: 32768 + vllm_cfg: + max_model_len: 32768 +data: + max_input_seq_length: 32768 + diff --git a/examples/configs/recipes/llm/grpo-acereason-math-7b-8K.yaml b/examples/configs/recipes/llm/grpo-acereason-math-7b-8K.yaml new file mode 100644 index 0000000000..625649b6df --- /dev/null +++ b/examples/configs/recipes/llm/grpo-acereason-math-7b-8K.yaml @@ -0,0 +1,85 @@ +defaults: ../../grpo_math_1B.yaml +grpo: + max_num_epochs: 30 + num_prompts_per_step: 128 + use_leave_one_out_baseline: false + val_period: 0 +loss_fn: + ratio_clip_c: 3 + reference_policy_kl_penalty: 0.0 +checkpointing: + keep_top_k: 10 + model_save_format: null +policy: + activation_checkpointing_enabled: false + dtensor_cfg: + activation_checkpointing: true + context_parallel_size: 2 + dynamic_batching: + logprob_mb_tokens: 16384 + train_mb_tokens: 8192 + fsdp_offload_enabled: false + generation: + colocated: + resources: + gpus_per_node: 8 + max_new_tokens: 8192 + model_name: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + pad_token_id: 151643 + stop_token_ids: + - 151643 + vllm_cfg: + enable_expert_parallel: false + enforce_eager: true + load_format: dummy + max_model_len: 8192 + precision: bfloat16 + skip_tokenizer_init: true + tensor_parallel_size: 4 + logprob_batch_size: 2 + lr: 1.0e-06 + make_sequence_length_divisible_by: 4 + max_total_sequence_length: 8192 + min_lr: 1.0e-06 + model_name: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + optimizer: + kwargs: + lr: 1.0e-06 + pipeline_model_parallel_size: 1 + refit_buffer_size_gb: 4 + scheduler: + - kwargs: + end_factor: 1.0 + start_factor: 1.0 + total_iters: 1 + name: torch.optim.lr_scheduler.LinearLR + - kwargs: + T_max: 1000000 + eta_min: 1.0e-06 + name: torch.optim.lr_scheduler.CosineAnnealingLR + - milestones: + - 0 + sequence_packing: + enabled: false + logprob_mb_tokens: 16384 + train_mb_tokens: 8192 + tensor_model_parallel_size: 1 + tokenizer: + name: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B + train_global_batch_size: 2048 + train_micro_batch_size: 1 + weight_decay: 0.01 +data: + dataset_name: nvidia/AceReason-Math + max_input_seq_length: 8192 + prompt_file: examples/prompts/acemath_qwen_cot.txt + shuffle: false + num_workers: 16 +env: + math: + env_cls: nemo_skills.training.nemo_rl.environments.math_environment.MathEnvironment + num_workers: 16 +logger: + monitor_gpus: false +cluster: + gpus_per_node: 8 diff --git a/examples/prompts/acemath_qwen_cot.txt b/examples/prompts/acemath_qwen_cot.txt new file mode 100644 index 0000000000..8992e3eb31 --- /dev/null +++ b/examples/prompts/acemath_qwen_cot.txt @@ -0,0 +1,3 @@ +Solve the following math problem. Make sure to put the answer (and only answer) inside \boxed{{}}. + +{} \ No newline at end of file diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py index 8e75a99a0c..2895022343 100644 --- a/nemo_rl/data/datasets/response_datasets/__init__.py +++ b/nemo_rl/data/datasets/response_datasets/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Any +from nemo_rl.data.datasets.response_datasets.acereason_math import AceReasonMathDataset from nemo_rl.data.datasets.response_datasets.clevr import CLEVRCoGenTDataset from nemo_rl.data.datasets.response_datasets.dapo_math import DAPOMath17KDataset from nemo_rl.data.datasets.response_datasets.deepscaler import DeepScalerDataset @@ -79,6 +80,9 @@ def load_response_dataset(data_config, seed: int = 42): "Loading BytedTsinghua-SIA/DAPO-Math-17k for training and AIME 2024 for validation" ) base_dataset: Any = DAPOMath17KDataset(seed=seed) + elif dataset_name == "nvidia/AceReason-Math": + print("Loading nvidia/AceReason-Math for training and validation") + base_dataset: Any = AceReasonMathDataset(seed=seed) # for vlm rl training elif dataset_name == "clevr-cogent": base_dataset: Any = CLEVRCoGenTDataset( @@ -124,6 +128,7 @@ def load_response_dataset(data_config, seed: int = 42): __all__ = [ + "AceReasonMathDataset", "CLEVRCoGenTDataset", "DeepScalerDataset", "DAPOMath17KDataset", diff --git a/nemo_rl/data/datasets/response_datasets/acereason_math.py b/nemo_rl/data/datasets/response_datasets/acereason_math.py new file mode 100644 index 0000000000..46efaeec36 --- /dev/null +++ b/nemo_rl/data/datasets/response_datasets/acereason_math.py @@ -0,0 +1,105 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any + +from datasets import Dataset, load_dataset + +from nemo_rl.data.interfaces import TaskDataSpec + + +def format_acereason_math( + data: dict[str, str | float | int], +) -> dict[str, list[Any] | str]: + """Format AceReason-Math data to the expected message format.""" + return { + "messages": [ + { + "role": "user", + "content": data["problem"], + }, + { + "role": "assistant", + "content": data["answer"], + }, + ], + # For v0.1 release, nemo rl datasets require a task_name key such that user can map a task processor per unique task. + "task_name": "math", + } + + +def extract_dataset(split_name: str, data_split: Any) -> Any: + """Extract dataset split and add task_name field for GRPO compatibility.""" + if data_split is None: + return None + + # Add task_name field to each sample for GRPO compatibility + def add_task_name(example: dict) -> dict: + example["task_name"] = "math" + return example + + return data_split.map(add_task_name) + + +def prepare_acereason_math_dataset(seed: int = 42) -> dict[str, Dataset | None]: + """Load and prepare the AceReason-Math dataset for GRPO training.""" + # Load the AceReason-Math dataset for training + train_ds = load_dataset("nvidia/AceReason-Math", split="train") + + # Load AIME 2024 dataset for validation (following pattern of other math datasets) + val_ds = load_dataset("HuggingFaceH4/aime_2024", split="train") + + # Shuffle the training dataset with the specified seed + train_ds = train_ds.shuffle(seed=seed) + + # Format the examples, removing original columns + train_formatted = train_ds.map( + format_acereason_math, remove_columns=train_ds.column_names + ) + val_formatted = val_ds.map( + format_acereason_math, remove_columns=val_ds.column_names + ) + + formatted_ds_dict = { + "train": extract_dataset("train", train_formatted), + "validation": extract_dataset("validation", val_formatted), + } + + return prepare_math_dataset(formatted_ds_dict) + + +def prepare_math_dataset(formatted_ds_dict: dict[str, Any]) -> dict[str, Any]: + """Prepare math dataset with proper formatting for GRPO.""" + prepared_ds = {} + for split, dataset in formatted_ds_dict.items(): + if dataset is not None: + prepared_ds[split] = dataset + else: + prepared_ds[split] = None + return prepared_ds + + +class AceReasonMathDataset: + def __init__(self, seed: int = 42) -> None: + """Initialize the AceReason-Math dataset with train/validation split. + + Args: + seed: Random seed for reproducible splitting + """ + self.formatted_ds = prepare_acereason_math_dataset(seed=seed) + + self.task_spec = TaskDataSpec( + task_name="AceReason-Math", + ) diff --git a/tests/test_suites/llm/grpo-acereason-math-7b-16K.sh b/tests/test_suites/llm/grpo-acereason-math-7b-16K.sh new file mode 100644 index 0000000000..515447f57a --- /dev/null +++ b/tests/test_suites/llm/grpo-acereason-math-7b-16K.sh @@ -0,0 +1,68 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=100 +MAX_STEPS=1000 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=240 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo_math.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/token_mult_prob_error"]) < 1.05' \ + "data['train/token_mult_prob_error']['$MAX_STEPS'] < 1.05" +fi + +# Convert 16k checkpoint +uv run examples/converters/convert_dcp_to_hf.py \ + --config=$CKPT_DIR/step_${MAX_STEPS}/config.yaml \ + --dcp-ckpt-path=$CKPT_DIR/step_${MAX_STEPS}/policy/weights \ + --hf-ckpt-path=$CKPT_DIR/grpo-acereason-math-7b-16k-${MAX_STEPS}-hf + +# Run eval on AceReason-Math dataset +uv run examples/run_eval.py \ + generation.model_name=$CKPT_DIR/grpo-acereason-math-7b-16k-${MAX_STEPS}-hf \ + data.prompt_file=examples/prompts/acemath_qwen_cot.txt \ + generation.vllm_cfg.max_model_len=16384 \ + generation.vllm_cfg.enforce_eager=True \ + generation.temperature=1.0 \ + eval.num_tests_per_prompt=16 \ + 2>&1 | tee ${RUN_LOG}.acereason-eval + +cat ${RUN_LOG}.acereason-eval | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"score": \1}/' > ${RUN_LOG}-16k-metric.json + +# Set baseline score for AceReason-Math evaluation (adjust based on expected performance) +uv run tests/check_metrics.py ${RUN_LOG}-16k-metric.json \ + 'data["score"] >= 0.30' # Baseline score to be adjusted based on actual performance + +# Performance tracking comments +# ======================================================== +# deepseek-ai/DeepSeek-R1-Distill-Qwen-7B baseline performance +# ======================================================== +# This section will be updated with baseline performance metrics +# after initial runs to establish proper thresholds diff --git a/tests/test_suites/llm/grpo-acereason-math-7b-24K.sh b/tests/test_suites/llm/grpo-acereason-math-7b-24K.sh new file mode 100644 index 0000000000..145d796574 --- /dev/null +++ b/tests/test_suites/llm/grpo-acereason-math-7b-24K.sh @@ -0,0 +1,76 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=10 +MAX_STEPS=100 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=240 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Use checkpoint created from the 16K checkpoint in grpo-acereason-math-7b-16K.sh +if [[ -z "$NRL_ACEREASON_16K_CKPT" ]]; then + echo "Need to set NRL_ACEREASON_16K_CKPT to the path to the trained 16K checkpoint" + exit 1 +fi + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo_math.py \ + --config $CONFIG_PATH \ + policy.model_name=$NRL_ACEREASON_16K_CKPT \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/token_mult_prob_error"]) < 1.05' \ + "data['train/token_mult_prob_error']['$MAX_STEPS'] < 1.05" +fi + +# Convert 24k checkpoint +uv run examples/converters/convert_dcp_to_hf.py \ + --config=$CKPT_DIR/step_${MAX_STEPS}/config.yaml \ + --dcp-ckpt-path=$CKPT_DIR/step_${MAX_STEPS}/policy/weights \ + --hf-ckpt-path=$CKPT_DIR/grpo-acereason-math-7b-24k-${MAX_STEPS}-hf + +# Run eval on AceReason-Math dataset +uv run examples/run_eval.py \ + generation.model_name=$CKPT_DIR/grpo-acereason-math-7b-24k-${MAX_STEPS}-hf \ + data.prompt_file=examples/prompts/acemath_qwen_cot.txt \ + generation.vllm_cfg.max_model_len=24576 \ + generation.vllm_cfg.enforce_eager=True \ + generation.temperature=1.0 \ + eval.num_tests_per_prompt=16 \ + 2>&1 | tee ${RUN_LOG}.acereason-24k + +cat ${RUN_LOG}.acereason-24k | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"score": \1}/' > ${RUN_LOG}-24k-metric.json + +# Set baseline score for AceReason-Math evaluation +uv run tests/check_metrics.py ${RUN_LOG}-24k-metric.json \ + 'data["score"] >= 0.30' # Baseline score to be adjusted based on actual performance + +# Performance tracking comments +# ======================================================== +# deepseek-ai/DeepSeek-R1-Distill-Qwen-7B baseline performance +# ======================================================== +# This section will be updated with baseline performance metrics +# after initial runs to establish proper thresholds + diff --git a/tests/test_suites/llm/grpo-acereason-math-7b-32K.sh b/tests/test_suites/llm/grpo-acereason-math-7b-32K.sh new file mode 100644 index 0000000000..4f4ef995f6 --- /dev/null +++ b/tests/test_suites/llm/grpo-acereason-math-7b-32K.sh @@ -0,0 +1,76 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=10 +MAX_STEPS=100 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=240 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Use checkpoint created from the 24K checkpoint in grpo-acereason-math-7b-24K.sh +if [[ -z "$NRL_ACEREASON_24K_CKPT" ]]; then + echo "Need to set NRL_ACEREASON_24K_CKPT to the path to the trained 24K checkpoint" + exit 1 +fi + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo_math.py \ + --config $CONFIG_PATH \ + policy.model_name=$NRL_ACEREASON_24K_CKPT \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/token_mult_prob_error"]) < 1.05' \ + "data['train/token_mult_prob_error']['$MAX_STEPS'] < 1.05" +fi + +# Convert 32k checkpoint +uv run examples/converters/convert_dcp_to_hf.py \ + --config=$CKPT_DIR/step_${MAX_STEPS}/config.yaml \ + --dcp-ckpt-path=$CKPT_DIR/step_${MAX_STEPS}/policy/weights \ + --hf-ckpt-path=$CKPT_DIR/grpo-acereason-math-7b-32k-${MAX_STEPS}-hf + +# Run eval on AceReason-Math dataset +uv run examples/run_eval.py \ + generation.model_name=$CKPT_DIR/grpo-acereason-math-7b-32k-${MAX_STEPS}-hf \ + data.prompt_file=examples/prompts/acemath_qwen_cot.txt \ + generation.vllm_cfg.max_model_len=32768 \ + generation.vllm_cfg.enforce_eager=True \ + generation.temperature=1.0 \ + eval.num_tests_per_prompt=16 \ + 2>&1 | tee ${RUN_LOG}.acereason-32k + +cat ${RUN_LOG}.acereason-32k | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"score": \1}/' > ${RUN_LOG}-32k-metric.json + +# Set baseline score for AceReason-Math evaluation +uv run tests/check_metrics.py ${RUN_LOG}-32k-metric.json \ + 'data["score"] >= 0.30' # Baseline score to be adjusted based on actual performance + +# Performance tracking comments +# ======================================================== +# deepseek-ai/DeepSeek-R1-Distill-Qwen-7B baseline performance +# ======================================================== +# This section will be updated with baseline performance metrics +# after initial runs to establish proper thresholds + diff --git a/tests/test_suites/llm/grpo-acereason-math-7b-8K.sh b/tests/test_suites/llm/grpo-acereason-math-7b-8K.sh new file mode 100644 index 0000000000..743f1f5af6 --- /dev/null +++ b/tests/test_suites/llm/grpo-acereason-math-7b-8K.sh @@ -0,0 +1,69 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=100 +MAX_STEPS=1000 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=240 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo_math.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'mean(data["train/token_mult_prob_error"]) < 1.05' \ + "data['train/token_mult_prob_error']['$MAX_STEPS'] < 1.05" +fi + +# Convert 8k checkpoint +uv run examples/converters/convert_dcp_to_hf.py \ + --config=$CKPT_DIR/step_${MAX_STEPS}/config.yaml \ + --dcp-ckpt-path=$CKPT_DIR/step_${MAX_STEPS}/policy/weights \ + --hf-ckpt-path=$CKPT_DIR/grpo-acereason-math-7b-8k-${MAX_STEPS}-hf + +# Run eval on AceReason-Math dataset +uv run examples/run_eval.py \ + generation.model_name=$CKPT_DIR/grpo-acereason-math-7b-8k-${MAX_STEPS}-hf \ + data.prompt_file=examples/prompts/acemath_qwen_cot.txt \ + generation.vllm_cfg.max_model_len=8192 \ + generation.vllm_cfg.enforce_eager=True \ + generation.temperature=1.0 \ + eval.num_tests_per_prompt=16 \ + 2>&1 | tee ${RUN_LOG}.acereason-8k + +cat ${RUN_LOG}.acereason-8k | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"score": \1}/' > ${RUN_LOG}-8k-metric.json + +# Set baseline score for AceReason-Math evaluation +uv run tests/check_metrics.py ${RUN_LOG}-8k-metric.json \ + 'data["score"] >= 0.30' # Baseline score to be adjusted based on actual performance + +# Performance tracking comments +# ======================================================== +# deepseek-ai/DeepSeek-R1-Distill-Qwen-7B baseline performance +# ======================================================== +# This section will be updated with baseline performance metrics +# after initial runs to establish proper thresholds +