Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 200 additions & 0 deletions resources_servers/swerl_gen/README.md

Large diffs are not rendered by default.

178 changes: 178 additions & 0 deletions resources_servers/swerl_gen/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import json
import logging
import time
from asyncio import Semaphore
from typing import Any, Optional

from fastapi import FastAPI

from nemo_gym.base_resources_server import (
BaseResourcesServerConfig,
BaseRunRequest,
BaseVerifyRequest,
BaseVerifyResponse,
SimpleResourcesServer,
)
from resources_servers.swerl_gen.eval.process_patch import (
extract_pred_patch,
extract_pred_patch_relaxed_formatting,
extract_repro_test,
)
from resources_servers.swerl_gen.eval.singularity_utils import (
compute_score,
)


class SWEGenResourcesServerConfig(BaseResourcesServerConfig):
    """Settings read by ``SWEGenResourcesServer`` on top of the base server config."""

    # Initial value of the semaphore that bounds concurrent scoring calls.
    num_processes: int = 1
    # Timeout forwarded to ``compute_score`` (presumably seconds -- TODO confirm
    # against ``singularity_utils``).
    sandbox_timeout: int = 600
    # Debug flag forwarded to ``compute_score``.
    debug: bool = False
    # When True, eval-mode patches are extracted with
    # ``extract_pred_patch_relaxed_formatting`` instead of ``extract_pred_patch``.
    relaxed_formatting: bool = False


class SWEGenRunRequest(BaseRunRequest):
    """Run request describing one SWE task instance and how it should be scored."""

    # Task instance. Dictionary keys: instance_id, repo, setup_script, test_script,
    # regression_script, PASS_TO_PASS, FAIL_TO_PASS, patch.
    instance: dict[str, Any]
    dataset_name: Optional[str] = None
    dataset_split: Optional[str] = None
    # Dictionary keys: relevant_file_contents, remove_repo_name, image.
    metadata: dict[str, Any] = {}
    partial_similarity: Optional[bool] = None
    # Scoring mode: "eval" or "repro-gen".
    mode: str = "eval"


class SWEGenVerifyRequest(SWEGenRunRequest, BaseVerifyRequest):
    """Verify request: the ``SWEGenRunRequest`` fields combined with ``BaseVerifyRequest``."""


class SWEGenVerifyResponse(BaseVerifyResponse):
    """Verify response extended with optional details of the sandbox evaluation.

    Every extra field defaults to ``None`` and is only filled in by the code
    paths of ``SWEGenResourcesServer.verify`` that have a value for it.
    """

    # Second element of the pair returned by ``compute_score``.
    verification_result: Optional[dict[str, Any]] = None
    # Elapsed time, in seconds, measured around the ``compute_score`` call.
    verification_time: Optional[float] = None
    # Patch that was passed to scoring: the extracted model patch in "eval" mode,
    # ``instance["patch"]`` in "repro-gen" mode.
    model_patch: Optional[str] = None
    # Extracted reproduction-test payload ("repro-gen" mode only).
    repro_test_info_base64: Optional[str] = None
    # Assistant text that was extracted from the model response.
    model_output: Optional[str] = None


def _extract_last_assistant_text(body: BaseVerifyRequest) -> str:
    """Return the text of all assistant messages in the response, newline-joined.

    Despite the function name, every output item with ``type == "message"`` and
    ``role == "assistant"`` contributes, in order -- not only the last one.

    Args:
        body: Verify request whose ``response.output`` items are scanned. Items
            and content parts are read with ``getattr``, so objects lacking the
            expected attributes are skipped rather than raising.

    Returns:
        The collected text joined with ``"\\n"`` and stripped of surrounding
        whitespace; ``""`` when there is no assistant text or ``output`` is
        ``None``/empty.
    """
    texts: list[str] = []
    # ``or ()`` keeps a missing (None) output from raising on iteration.
    for item in body.response.output or ():
        is_assistant_message = (
            getattr(item, "type", None) == "message" and getattr(item, "role", None) == "assistant"
        )
        if not is_assistant_message:
            continue

        content = getattr(item, "content", None)
        if isinstance(content, str):
            texts.append(content)
        elif isinstance(content, list):
            # Only parts that expose a string ``text`` attribute carry model text.
            part_texts = (getattr(part, "text", None) for part in content)
            texts.extend(text for text in part_texts if isinstance(text, str))
    return "\n".join(texts).strip()


class SWEGenResourcesServer(SimpleResourcesServer):
    """Resources server that scores SWE-style model outputs via ``compute_score``.

    Two modes are supported, selected by ``body.mode``:

    * ``"eval"``: a patch is extracted from the model output and scored.
    * ``"repro-gen"``: a reproduction test is extracted from the model output
      and scored together with the patch stored in ``body.instance["patch"]``.
    """

    config: SWEGenResourcesServerConfig

    def setup_webserver(self) -> FastAPI:
        """Return the FastAPI app built by the base class, without additions."""
        return super().setup_webserver()

    def model_post_init(self, context):
        """Create the semaphore that bounds the number of concurrent scoring calls."""
        # Cooperate with any post-init logic defined further up the class hierarchy.
        super().model_post_init(context)
        self._semaphore: Semaphore = Semaphore(value=self.config.num_processes)

    @staticmethod
    def _zero_reward(body: SWEGenVerifyRequest, model_output: Optional[str] = None) -> SWEGenVerifyResponse:
        """Build a zero-reward response, attaching ``model_output`` only when given."""
        extra_fields = {} if model_output is None else {"model_output": model_output}
        return SWEGenVerifyResponse(**body.model_dump(), reward=0.0, **extra_fields)

    def _extract(self, body: SWEGenVerifyRequest, predict_str: str) -> Optional[dict[str, Any]]:
        """Run the extractor for ``body.mode``; return ``None`` when extraction fails.

        Any exception raised while extracting (including missing instance or
        metadata keys and invalid JSON) is treated as a failed extraction, so the
        caller can fall back to a zero reward. The failure is logged instead of
        being dropped silently, because dataset problems would otherwise be
        indistinguishable from malformed model output.
        """
        try:
            if body.mode == "repro-gen":
                return extract_repro_test(predict_str, body.instance["instance_id"])

            extract_patch = (
                extract_pred_patch_relaxed_formatting if self.config.relaxed_formatting else extract_pred_patch
            )
            return extract_patch(
                json.loads(body.metadata["relevant_file_contents"]),
                predict_str,
                body.metadata["remove_repo_name"],
            )
        except Exception as exc:
            # Best-effort by design: keep the zero-reward fallback, but leave a trace.
            # The full traceback is only emitted when the server runs in debug mode.
            logging.getLogger(__name__).warning(
                "Extraction failed (mode=%s, instance_id=%s): %r",
                body.mode,
                body.instance.get("instance_id"),
                exc,
                exc_info=self.config.debug,
            )
            return None

    async def verify(self, body: SWEGenVerifyRequest) -> SWEGenVerifyResponse:
        """Score the model output contained in ``body``.

        Args:
            body: Verify request holding the model response, the task instance,
                its metadata and the scoring mode.

        Returns:
            A response echoing the request fields plus the reward. The reward is
            ``0.0`` when the model produced no assistant text or nothing could be
            extracted from it; otherwise it is the value returned by
            ``compute_score``, together with the verification details.

        Raises:
            ValueError: If the model produced text and ``body.mode`` is neither
                ``"eval"`` nor ``"repro-gen"``.
            KeyError: If a required key (``instance["patch"]`` in repro-gen mode,
                ``metadata["image"]``) is missing after a successful extraction.
        """
        # Full model output text (including <think> and <solution> blocks).
        # The helper already strips it, so an emptiness check is sufficient.
        predict_str = _extract_last_assistant_text(body)
        if not predict_str:
            return self._zero_reward(body)

        if body.mode not in ("eval", "repro-gen"):
            raise ValueError(f"Invalid mode: {body.mode}")

        # Extract the predicted patch or reproduction test info from the model output.
        extracted_data = self._extract(body, predict_str)
        if extracted_data is None:
            return self._zero_reward(body, predict_str)

        if body.mode == "repro-gen":
            # The generated test is scored against the patch stored with the instance.
            patch_str = body.instance["patch"]
            repro_test_info_base64 = extracted_data["repro_test_info_base64"]
        else:
            patch_str = extracted_data["model_patch"]
            repro_test_info_base64 = None

        extra_info = {
            "instance_info": body.instance,
            "image": body.metadata["image"],
        }
        extra_info_base64 = base64.b64encode(json.dumps(extra_info).encode()).decode()

        async with self._semaphore:
            # Monotonic clock: the elapsed time cannot be skewed by wall-clock changes.
            start_time = time.perf_counter()
            reward, verification_result = await compute_score.remote(
                extra_info_base64,
                patch_str,
                repro_test_info_base64,
                body.mode,
                self.config.sandbox_timeout,
                self.config.debug,
            )
            verification_time = time.perf_counter() - start_time

        return SWEGenVerifyResponse(
            **body.model_dump(),
            reward=float(reward),
            verification_result=verification_result,
            verification_time=verification_time,
            model_patch=patch_str,
            repro_test_info_base64=repro_test_info_base64,
            model_output=predict_str,
        )


if __name__ == "__main__":
    # Launch the web server only when this module is executed as a script.
    SWEGenResourcesServer.run_webserver()
46 changes: 46 additions & 0 deletions resources_servers/swerl_gen/configs/swerl_gen.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
swerl_gen_resources_server:
resources_servers:
swerl_gen:
entrypoint: app.py
domain: coding
verified: false
description: Running sandboxed evaluation for SWE-style tasks (either patch generation or reproduction test generation)
value: Improve SWE capabilities useful for benchmarks like SWE-bench
env: singularity
num_processes: 2048
sandbox_timeout: 900
debug: false
relaxed_formatting: false
swerl_gen_simple_agent:
responses_api_agents:
simple_agent:
entrypoint: app.py
resources_server:
type: resources_servers
name: swerl_gen_resources_server
model_server:
type: responses_api_models
name: policy_model
datasets:
- name: train
type: train
jsonl_fpath: resources_servers/swerl_gen/data/train_swebenchverified_n32768.jsonl
num_repeats: 1
gitlab_identifier:
dataset_name: swerl_gen
version: 0.0.1
artifact_fpath: train_swebenchverified_n32768.jsonl
license: Apache 2.0
- name: validation
type: validation
jsonl_fpath: resources_servers/swerl_gen/data/validation_gym690_curriculum2.jsonl
num_repeats: 1
gitlab_identifier:
dataset_name: swerl_gen
version: 0.0.1
artifact_fpath: validation_gym690_curriculum2.jsonl
license: Apache 2.0
- name: example
type: example
jsonl_fpath: resources_servers/swerl_gen/data/example.jsonl
num_repeats: 1
5 changes: 5 additions & 0 deletions resources_servers/swerl_gen/data/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
*train.jsonl
*validation.jsonl
*train_prepare.jsonl
*validation_prepare.jsonl
*example_prepare.jsonl
5 changes: 5 additions & 0 deletions resources_servers/swerl_gen/data/example.jsonl

Large diffs are not rendered by default.

47 changes: 47 additions & 0 deletions resources_servers/swerl_gen/data/example_metrics.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
{
"Number of examples": 5,
"Number of tools": {
"Total # non-null values": 0,
"Average": 0.0,
"Min": 0.0,
"Max": 0.0,
"Median": 0.0,
"Standard deviation": 0.0
},
"Json-dumped number of words (proxy for token count)": {
"Total # non-null values": 5,
"Average": 752.0,
"Min": 532.0,
"Max": 935.0,
"Median": 764.0,
"Standard deviation": 149.52
},
"Number of turns": {
"Total # non-null values": 5,
"Average": 1.0,
"Min": 1.0,
"Max": 1.0,
"Median": 1.0,
"Standard deviation": 0.0
},
"Temperature": {
"Total # non-null values": 0,
"Average": 0.0,
"Min": 0.0,
"Max": 0.0,
"Median": 0.0,
"Standard deviation": 0.0
},
"dataset_name": {
"unique_count": 1,
"total_count": 5
},
"dataset_split": {
"unique_count": 1,
"total_count": 5
},
"mode": {
"unique_count": 2,
"total_count": 5
}
}
5 changes: 5 additions & 0 deletions resources_servers/swerl_gen/data/example_rollouts.jsonl

Large diffs are not rendered by default.

Loading