diff --git a/tests/test_suites/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.sh b/tests/test_suites/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.sh index aea8c91747..94781e4931 100755 --- a/tests/test_suites/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.sh +++ b/tests/test_suites/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.sh @@ -4,8 +4,8 @@ source $SCRIPT_DIR/common.env # ===== BEGIN CONFIG ===== NUM_NODES=1 -STEPS_PER_RUN=500 -MAX_STEPS=500 +STEPS_PER_RUN=400 +MAX_STEPS=400 NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up NUM_MINUTES=120 # ===== END CONFIG ===== @@ -35,5 +35,5 @@ uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then uv run tests/check_metrics.py $JSON_METRICS \ 'mean(data["train/token_mult_prob_error"]) < 1.1' \ - 'data["train/token_mult_prob_error"]["500"] < 1.1' + "data[\"train/token_mult_prob_error\"][\"${MAX_STEPS}\"] < 1.1" fi