Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions resources_servers/arc_agi/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# ARC-AGI resources server

Launch a local vLLM server:
```bash
vllm serve Qwen/Qwen3-30B-A3B \
--dtype auto \
--tensor-parallel-size 8 \
--gpu-memory-utilization 0.9 \
--enable-auto-tool-choice --tool-call-parser hermes \
--host 0.0.0.0 \
--port 10240
```

Start ARC-AGI environment:
```bash
ng_run "+config_paths=[resources_servers/arc_agi/configs/arc_agi.yaml,responses_api_models/vllm_model/configs/vllm_model.yaml]"
```

or ARC-AGI-2 environment:
```bash
ng_run "+config_paths=[resources_servers/arc_agi/configs/arc_agi_2.yaml,responses_api_models/vllm_model/configs/vllm_model.yaml]"
```


Collect rollouts:

ARC-AGI-1 example rollouts:
```bash
ng_collect_rollouts +agent_name=arc_agi_simple_agent +input_jsonl_fpath=resources_servers/arc_agi/data/example_1.jsonl +output_jsonl_fpath=resources_servers/arc_agi/data/example_1_rollouts.jsonl +limit=5 +num_repeats=null +num_samples_in_parallel=null
```

ARC-AGI-2 example rollouts:
```bash
ng_collect_rollouts +agent_name=arc_agi_2_simple_agent +input_jsonl_fpath=resources_servers/arc_agi/data/example_2.jsonl +output_jsonl_fpath=resources_servers/arc_agi/data/example_2_rollouts.jsonl +limit=5 +num_repeats=null +num_samples_in_parallel=null
```

ARC-AGI-1 train set rollouts (400 problems):
```bash
ng_collect_rollouts +agent_name=arc_agi_simple_agent +input_jsonl_fpath=resources_servers/arc_agi/data/arc_agi_1_training.jsonl +output_jsonl_fpath=resources_servers/arc_agi/data/arc_agi_1_training_rollouts.jsonl +limit=null +num_repeats=null +num_samples_in_parallel=null
```

ARC-AGI-1 eval set rollouts (400 problems):
```bash
ng_collect_rollouts +agent_name=arc_agi_simple_agent +input_jsonl_fpath=resources_servers/arc_agi/data/arc_agi_1_evaluation.jsonl +output_jsonl_fpath=resources_servers/arc_agi/data/arc_agi_1_evaluation_rollouts.jsonl +limit=null +num_repeats=null +num_samples_in_parallel=null
```

ARC-AGI-2 train set rollouts (1000 problems):
```bash
ng_collect_rollouts +agent_name=arc_agi_2_simple_agent +input_jsonl_fpath=resources_servers/arc_agi/data/arc_agi_2_training.jsonl +output_jsonl_fpath=resources_servers/arc_agi/data/arc_agi_2_training_rollouts.jsonl +limit=null +num_repeats=null +num_samples_in_parallel=null
```

ARC-AGI-2 eval set rollouts (120 problems):
```bash
ng_collect_rollouts +agent_name=arc_agi_2_simple_agent +input_jsonl_fpath=resources_servers/arc_agi/data/arc_agi_2_evaluation.jsonl +output_jsonl_fpath=resources_servers/arc_agi/data/arc_agi_2_evaluation_rollouts.jsonl +limit=null +num_repeats=null +num_samples_in_parallel=null
```

run tests:
```bash
ng_test +entrypoint=resources_servers/arc_agi
```
115 changes: 115 additions & 0 deletions resources_servers/arc_agi/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import re
from typing import List, Optional

from fastapi import FastAPI

from nemo_gym.base_resources_server import (
BaseResourcesServerConfig,
BaseRunRequest,
BaseVerifyRequest,
BaseVerifyResponse,
SimpleResourcesServer,
)


class ARCAGIResourcesServerConfig(BaseResourcesServerConfig):
    """Server configuration; no ARC-specific options beyond the base config."""

    pass


class ARCAGIRunRequest(BaseRunRequest):
    """Run request carrying one ARC task: training pairs, a test input, and its target."""

    # Demonstration pairs; each dict has "input"/"output" grids (see create_dataset.py).
    train: List[dict] = []
    # The grid the model must transform.
    test_input: List[List[int]] = []
    # Ground-truth output grid for the test input.
    expected_output: List[List[int]] = []
    # Original ARC task identifier (the task file stem), if known.
    task_id: Optional[str] = None


class ARCAGIVerifyRequest(ARCAGIRunRequest, BaseVerifyRequest):
    """Verify request: the ARC task fields plus the model response to score."""

    pass


class ARCAGIVerifyResponse(BaseVerifyResponse):
    """Verification result for one ARC task."""

    # Ground-truth grid echoed back from the request.
    expected_output: List[List[int]]
    # Grid parsed from the assistant's reply, or None when none could be extracted.
    predicted_output: Optional[List[List[int]]] = None
    # True when a well-formed grid was found in the reply, regardless of correctness.
    extraction_successful: bool = False


def _extract_assistant_text(body: BaseVerifyRequest) -> str:
    """Collect all assistant-message text from the model response.

    Walks the response output items, keeps only assistant "message" items,
    gathers their text (string content or the ``text`` of each content part),
    and returns the pieces newline-joined and stripped.
    """
    collected = []
    for item in body.response.output:
        is_assistant_message = (
            getattr(item, "type", None) == "message" and getattr(item, "role", None) == "assistant"
        )
        if not is_assistant_message:
            continue
        content = getattr(item, "content", None)
        if isinstance(content, str):
            collected.append(content)
        elif isinstance(content, list):
            collected.extend(
                part_text
                for part_text in (getattr(part, "text", None) for part in content)
                if isinstance(part_text, str)
            )
    return "\n".join(collected).strip()


def _parse_grid(text: str) -> Optional[List[List[int]]]:
"""expects format: \\boxed{[[1,2,3],[4,5,6]]}"""
boxed_pattern = r"\\boxed\{(\[\s*\[[\d\s,\[\]]+\]\s*\])\}"
boxed_matches = re.findall(boxed_pattern, text, re.DOTALL)

if not boxed_matches:
boxed_matches = re.findall(r"\[\s*\[[\d\s,\[\]]+\]\s*\]", text, re.DOTALL)

for match in boxed_matches:
try:
cleaned = re.sub(r"\s+", "", match)
grid = json.loads(cleaned)

if (
isinstance(grid, list)
and all(isinstance(row, list) and all(isinstance(cell, int) for cell in row) for row in grid)
and len(grid) > 0
and len(grid[0]) > 0
):
return grid
except (json.JSONDecodeError, IndexError, TypeError):
continue

return None


class ARCAGIResourcesServer(SimpleResourcesServer):
    """Resources server that scores ARC-AGI grid predictions by exact match."""

    config: ARCAGIResourcesServerConfig

    def setup_webserver(self) -> FastAPI:
        # No extra routes beyond what the base server registers.
        return super().setup_webserver()

    async def verify(self, body: ARCAGIVerifyRequest) -> ARCAGIVerifyResponse:
        """Parse the assistant's predicted grid and compare it to the target.

        Reward is 1.0 only when a grid was extracted and it equals
        ``expected_output`` exactly; otherwise 0.0.
        """
        prediction = _parse_grid(_extract_assistant_text(body))
        found = prediction is not None
        exact_match = found and prediction == body.expected_output
        return ARCAGIVerifyResponse(
            **body.model_dump(),
            reward=1.0 if exact_match else 0.0,
            predicted_output=prediction,
            extraction_successful=found,
        )


if __name__ == "__main__":
    # Entry point: launch the FastAPI webserver for this resources server.
    ARCAGIResourcesServer.run_webserver()
37 changes: 37 additions & 0 deletions resources_servers/arc_agi/configs/arc_agi.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
arc_agi_resources_server:
resources_servers:
arc_agi:
entrypoint: app.py
domain: knowledge
verified: false
arc_agi_simple_agent:
responses_api_agents:
simple_agent:
entrypoint: app.py
resources_server:
type: resources_servers
name: arc_agi_resources_server
model_server:
type: responses_api_models
name: policy_model
datasets:
- name: example
type: example
      jsonl_fpath: resources_servers/arc_agi/data/example_1.jsonl
- name: training_1
type: validation
jsonl_fpath: resources_servers/arc_agi/data/arc_agi_1_training.jsonl
gitlab_identifier:
dataset_name: arc_agi
version: 0.0.1
artifact_fpath: arc_agi_1_training.jsonl
license: Apache 2.0
- name: evaluation_1
type: validation
jsonl_fpath: resources_servers/arc_agi/data/arc_agi_1_evaluation.jsonl
gitlab_identifier:
dataset_name: arc_agi
version: 0.0.1
artifact_fpath: arc_agi_1_evaluation.jsonl
license: Apache 2.0

28 changes: 28 additions & 0 deletions resources_servers/arc_agi/configs/arc_agi_2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
arc_agi_2:
resources_servers:
arc_agi:
entrypoint: app.py
domain: knowledge
verified: false
arc_agi_2_simple_agent:
responses_api_agents:
simple_agent:
entrypoint: app.py
host: 127.0.0.1
port: 15215
resources_server:
type: resources_servers
name: arc_agi_2
model_server:
type: responses_api_models
name: policy_model
datasets:
- name: example_2
type: example
jsonl_fpath: resources_servers/arc_agi/data/example_2.jsonl
- name: training_2
type: validation
jsonl_fpath: resources_servers/arc_agi/data/arc_agi_2_training.jsonl
- name: evaluation_2
type: validation
jsonl_fpath: resources_servers/arc_agi/data/arc_agi_2_evaluation.jsonl
128 changes: 128 additions & 0 deletions resources_servers/arc_agi/create_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
from pathlib import Path


def format_grid(grid):
    """Render a 2D grid as text: one line per row, cells space-separated."""
    lines = []
    for row in grid:
        lines.append(" ".join(str(cell) for cell in row))
    return "\n".join(lines)


def create_arc_prompt(task_data, task_id, version=1):
    """Build the user prompt for one ARC task.

    Shows every training input/output pair, then the first test input, and
    asks for the answer grid inside \\boxed{...}.
    """

    def grid_text(grid):
        # Same rendering as format_grid: rows of space-separated cells.
        return "\n".join(" ".join(map(str, row)) for row in grid)

    suffix = "" if version == 1 else f"-{version}"
    pieces = [
        f"You are solving ARC-AGI{suffix} task {task_id}.\n\n",
        "Here are the training examples that demonstrate the pattern:\n\n",
    ]
    for example_num, pair in enumerate(task_data["train"], start=1):
        pieces.append(f"Example {example_num}:\n")
        pieces.append("Input:\n")
        pieces.append(grid_text(pair["input"]))
        pieces.append("\n\nOutput:\n")
        pieces.append(grid_text(pair["output"]))
        pieces.append("\n\n")
    pieces.append("Now solve this test case following the same pattern:\n")
    pieces.append("Test Input:\n")
    pieces.append(grid_text(task_data["test"][0]["input"]))
    pieces.append(
        "\n\nProvide your solution as a 2D array inside \\boxed{} in this exact format: \\boxed{[[row1],[row2],...]}"
    )
    pieces.append("\nFor example: \\boxed{[[1,2,3],[4,5,6],[7,8,9]]}")
    return "".join(pieces)


def _iter_task_entries(task_dir, version):
    """Yield one dataset entry per task JSON file in *task_dir*, sorted by task id."""
    for task_file in sorted(task_dir.glob("*.json")):
        task_id = task_file.stem
        with open(task_file, encoding="utf-8") as f:
            task_data = json.load(f)
        yield {
            "responses_create_params": {
                "input": [{"role": "user", "content": create_arc_prompt(task_data, task_id, version)}]
            },
            "train": task_data["train"],
            # Only the first test case of each task is used.
            "test_input": task_data["test"][0]["input"],
            "expected_output": task_data["test"][0]["output"],
            "task_id": task_id,
        }


def _write_jsonl(fpath, entries):
    """Write *entries* to *fpath* as JSON Lines and return them as a list."""
    entries = list(entries)
    with open(fpath, "w", encoding="utf-8") as f:
        for entry in entries:
            f.write(json.dumps(entry) + "\n")
    return entries


def create_dataset(version=1):
    """Convert ARC-AGI task JSON files into training/evaluation/example JSONL datasets.

    Reads tasks from ../../ARC-AGI (version 1) or ../../ARC-AGI-2 (version 2),
    writes data/arc_agi_{version}_training.jsonl and
    data/arc_agi_{version}_evaluation.jsonl, plus a 5-task
    data/example_{version}.jsonl sampled from the evaluation split.
    """
    data_base = f"../../ARC-AGI{'-' + str(version) if version != 1 else ''}"
    Path("data").mkdir(exist_ok=True)

    # The training and evaluation splits are processed identically; the
    # evaluation split additionally seeds the small example dataset.
    for split in ("training", "evaluation"):
        split_dir = Path(f"{data_base}/data/{split}")
        print(f"Processing {len(list(split_dir.glob('*.json')))} {split} tasks...")

        output_file = Path(f"data/arc_agi_{version}_{split}.jsonl")
        entries = _write_jsonl(output_file, _iter_task_entries(split_dir, version))
        print(f"Created {split} dataset with {len(entries)} tasks at {output_file}")

        if split == "evaluation":
            example_file = Path(f"data/example_{version}.jsonl")
            _write_jsonl(example_file, entries[:5])
            print(f"Created example dataset with 5 tasks at {example_file}")


if __name__ == "__main__":
    # CLI entry point: choose which ARC-AGI release to convert to JSONL.
    parser = argparse.ArgumentParser(description="Create ARC-AGI dataset")
    parser.add_argument("--version", type=int, default=1, choices=[1, 2], help="ARC-AGI version (1 or 2)")
    args = parser.parse_args()

    create_dataset(version=args.version)
Loading