diff --git a/conf/actor/web.yaml b/conf/actor/web.yaml
index d2617e95..6925873e 100644
--- a/conf/actor/web.yaml
+++ b/conf/actor/web.yaml
@@ -3,9 +3,12 @@ llm_max_rollouts: 128
rollout_workers: 1
rollout_policy: pipelinerl.domains.deep_research.tapeagents_rollouts.generate_rollout
-environment:
- _target_: tapeagents.mcp.MCPEnvironment
- config_path: conf/mcp/web.json
+environments:
+ - key: mcp
+ mode: embedded
+ _target_: tapeagents.mcp.MCPEnvironment
+ config_path: conf/mcp/web.json
+environment_key: mcp
llm:
_target_: tapeagents.llms.LiteLLM
@@ -105,4 +108,4 @@ only_tasks: #[] # list of (level, task_num)
- [1, 4]
- [1, 5]
- [1, 6]
-- [1, 7]
\ No newline at end of file
+- [1, 7]
diff --git a/conf/base.yaml b/conf/base.yaml
index e3122f5a..31c0b93b 100644
--- a/conf/base.yaml
+++ b/conf/base.yaml
@@ -2,6 +2,7 @@ defaults:
- finetune: base
- rewards: pure_success
- streams: files
+ - domain_mix: null
- _self_
seed: 42
@@ -18,6 +19,7 @@ actor:
result_queue_size: 64
throughput_window_size: 50
shared_memory_entry_size: 10000000
+ domain_mix: null
environment: null
preprocess:
input: actor
@@ -135,4 +137,3 @@ wandb:
wandb_dir: null
# Comma-separated list of keywords to tag the run.
tags: []
-
diff --git a/conf/coding.yaml b/conf/coding.yaml
new file mode 100644
index 00000000..7788bf7b
--- /dev/null
+++ b/conf/coding.yaml
@@ -0,0 +1,55 @@
+defaults:
+ - base
+ - _self_
+
+actor:
+ rollout_policy: pipelinerl.domains.coding.generate_coding_rollout
+ system_prompt: |-
+ You are an expert Python programmer. When providing code solutions, format your final code inside markdown code blocks using triple backticks with the python language identifier, like this:
+ ```python
+ # your code here
+ ```
+ Provide complete, working implementations that pass all test cases.
+ task_template: |-
+ {task}
+ task_prompt: ""
+ ensure_boxed_answers: false
+
+ coding_time_limit_s: 15.0
+ coding_per_test_timeout_s: 10.0
+ coding_memory_limit_bytes: 1073741824
+ coding_compile_timeout_s: 10.0
+ coding_sandbox_url: ${oc.env:CODING_SANDBOX_URL, "http://sandbox:8080/run_code"}
+
+dataset_loader: pipelinerl.domains.coding.dataset.load_problems
+dataset_loader_params:
+ dataset_id: ServiceNow-AI/mixed-training-text-datasets
+ dataset_config: 80k-if-math-coding-fncalling-stem
+ split_ratios:
+ train: 0.9
+ validation: 0.05
+ test: 0.05
+ allowed_call_types:
+ - assert
+ - std
+ max_examples_per_split: 2048
+ trust_remote_code: true
+ huggingface_token: ${oc.env:CODING_HF_TOKEN, null}
+
+train_dataset_names:
+ - coding@train
+
+test_dataset_names:
+ - coding@validation
+
+environments:
+ - key: coding
+ mode: remote
+ _target_: pipelinerl.domains.coding.CodingSandboxEnvironment
+ sandbox_url: ${actor.coding_sandbox_url}
+ compile_timeout_s: ${actor.coding_compile_timeout_s}
+ run_timeout_s: ${actor.coding_per_test_timeout_s}
+ request_timeout_s: ${actor.coding_time_limit_s}
+ memory_limit_bytes: ${actor.coding_memory_limit_bytes}
+
+environment_key: coding
diff --git a/conf/debug/multi_domain.yaml b/conf/debug/multi_domain.yaml
new file mode 100644
index 00000000..03b4c56e
--- /dev/null
+++ b/conf/debug/multi_domain.yaml
@@ -0,0 +1,30 @@
+defaults:
+ - base
+ - domain_rollouts: base
+ - override rewards: success_and_format
+ - _self_
+
+actor:
+ rollout_policy: pipelinerl.domains.dispatcher.generate_multidomain_rollout
+ llm_max_rollouts: 2
+ rollout_workers: 1
+ domain_rollouts:
+ math: ${domain_rollouts.math}
+ guessing: ${domain_rollouts.guessing}
+ coding: ${domain_rollouts.coding}
+
+dataset_loader: pipelinerl.domains.multidomain.load_problems
+train_dataset_names:
+ - math_debug
+ - guessing_debug
+ - coding_debug
+test_dataset_names:
+ - math_debug
+ - coding_debug
+
+environment: null
+environment_key: null
+
+world:
+ env_replicas_per_actor: 0
+ environment_mode: embedded
diff --git a/conf/domain_mix/README.md b/conf/domain_mix/README.md
new file mode 100644
index 00000000..2143b288
--- /dev/null
+++ b/conf/domain_mix/README.md
@@ -0,0 +1,12 @@
+# Domain mix presets
+
+Hydra group `domain_mix` stores reusable presets for `actor.domain_mix`.
+
+Usage examples:
+
+```
+python main.py --config-name multi_domain/base +domain_mix=math_coding_70_30
+python main.py --config-name multi_domain/base +domain_mix=balanced
+```
+
+Override or extend these presets by creating new files under `conf/domain_mix/`.
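+
+A preset is a flat mapping from domain name to relative sampling weight, placed
+into the `actor.domain_mix` package via the `# @package` directive; weights are
+normalized at runtime, so they need not sum to 1. For example, a hypothetical
+`math_heavy.yaml` preset could look like this:
+
+```yaml
+# @package actor.domain_mix
+
+math: 0.8
+coding: 0.2
+```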
diff --git a/conf/domain_mix/balanced.yaml b/conf/domain_mix/balanced.yaml
new file mode 100644
index 00000000..7b2a4529
--- /dev/null
+++ b/conf/domain_mix/balanced.yaml
@@ -0,0 +1,9 @@
+# @package actor.domain_mix
+
+math: 1.0
+guessing: 1.0
+counting: 1.0
+chartqa: 1.0
+miniwob: 1.0
+coding: 1.0
+fn_calling: 1.0
diff --git a/conf/domain_mix/coding_heavy.yaml b/conf/domain_mix/coding_heavy.yaml
new file mode 100644
index 00000000..ff102abd
--- /dev/null
+++ b/conf/domain_mix/coding_heavy.yaml
@@ -0,0 +1,4 @@
+# @package actor.domain_mix
+
+math: 0.3
+coding: 0.7
diff --git a/conf/domain_mix/main_mix.yaml b/conf/domain_mix/main_mix.yaml
new file mode 100644
index 00000000..06038d0b
--- /dev/null
+++ b/conf/domain_mix/main_mix.yaml
@@ -0,0 +1,5 @@
+# @package actor.domain_mix
+
+math: 0.4
+coding: 0.3
+fn_calling: 0.3
\ No newline at end of file
diff --git a/conf/domain_mix/math_coding_70_30.yaml b/conf/domain_mix/math_coding_70_30.yaml
new file mode 100644
index 00000000..7ea64468
--- /dev/null
+++ b/conf/domain_mix/math_coding_70_30.yaml
@@ -0,0 +1,4 @@
+# @package actor.domain_mix
+
+math: 0.7
+coding: 0.3
diff --git a/conf/domain_rollouts/base.yaml b/conf/domain_rollouts/base.yaml
new file mode 100644
index 00000000..e744acba
--- /dev/null
+++ b/conf/domain_rollouts/base.yaml
@@ -0,0 +1,8 @@
+# Mapping between domain identifiers and rollout callables.
+math: pipelinerl.domains.math.generate_math_rollout
+guessing: pipelinerl.domains.guessing.generate_guessing_rollout
+counting: pipelinerl.domains.counting.generate_counting_rollout
+miniwob: pipelinerl.domains.miniwob.rollouts.generate_miniwob_rollout
+chartqa: pipelinerl.domains.chartqa.generate_chartqa_rollout
+coding: pipelinerl.domains.coding.generate_coding_rollout
+fn_calling: pipelinerl.domains.fn_calling.generate_fn_calling_rollout
diff --git a/conf/fn_calling.yaml b/conf/fn_calling.yaml
new file mode 100644
index 00000000..6eeb898c
--- /dev/null
+++ b/conf/fn_calling.yaml
@@ -0,0 +1,36 @@
+defaults:
+ - base
+ - _self_
+
+actor:
+ rollout_policy: pipelinerl.domains.fn_calling.generate_fn_calling_rollout
+ system_prompt: ""
+ task_template: "{task}"
+ task_prompt: ""
+ ensure_boxed_answers: false
+
+dataset_loader: pipelinerl.domains.fn_calling.dataset.load_problems
+dataset_loader_params:
+ dataset_id: ServiceNow-AI/mixed-training-text-datasets
+ dataset_config: 80k-if-math-coding-fncalling-stem
+ split_ratios:
+ train: 0.9
+ validation: 0.05
+ test: 0.05
+ allowed_call_types: []
+ max_examples_per_split: 2048
+ trust_remote_code: true
+ huggingface_token: ${oc.env:CODING_HF_TOKEN, null}
+
+train_dataset_names:
+ - fn_calling@train
+
+test_dataset_names:
+ - fn_calling@validation
+
+environments:
+ - key: fn_calling
+ mode: remote
+ _target_: pipelinerl.domains.fn_calling.AgenticToolsEnvironment
+
+environment_key: fn_calling
diff --git a/conf/math.yaml b/conf/math.yaml
index 069aa96b..d9a07c02 100644
--- a/conf/math.yaml
+++ b/conf/math.yaml
@@ -5,10 +5,13 @@ defaults:
actor:
rollout_policy: pipelinerl.domains.math.generate_math_rollout
system_prompt: Please reason step by step, and put your final answer within \boxed{}.
- task_template: |-
- {task}
-environment:
- _target_: pipelinerl.domains.math.MathEnvironment
+ task_template: "{task}"
+ task_prompt: ""
+environments:
+ - key: math
+ mode: remote
+ _target_: pipelinerl.domains.math.MathEnvironment
+environment_key: math
dataset_loader: pipelinerl.domains.math.load_datasets
train_dataset_names:
- open_reasoner_zero_57k
@@ -16,4 +19,4 @@ train_dataset_names:
test_dataset_names:
- aime_2024
- amc_2023
- - math_500
\ No newline at end of file
+ - math_500
diff --git a/conf/math_code.yaml b/conf/math_code.yaml
new file mode 100644
index 00000000..7a23a7e9
--- /dev/null
+++ b/conf/math_code.yaml
@@ -0,0 +1,69 @@
+defaults:
+ - base
+ - /domain_rollouts@domain_rollouts: base
+ - domain_mix: math_coding_70_30
+ - _self_
+
+actor:
+ rollout_policy: pipelinerl.domains.dispatcher.generate_multidomain_rollout
+ system_prompt: ""
+ task_template: |-
+ {task}
+ task_prompt: ""
+ ensure_boxed_answers: false
+ domain_rollouts:
+ math: ${domain_rollouts.math}
+ coding: ${domain_rollouts.coding}
+ coding_time_limit_s: 15.0
+ coding_per_test_timeout_s: 10.0
+ coding_memory_limit_bytes: 1073741824
+ coding_compile_timeout_s: 10.0
+ coding_sandbox_url: ${oc.env:CODING_SANDBOX_URL, "http://sandbox:8080/run_code"}
+
+dataset_loader: pipelinerl.domains.multidomain.loader.load_datasets
+dataset_loader_params:
+ per_domain_params:
+ coding:
+ dataset_id: ServiceNow-AI/mixed-training-text-datasets
+ dataset_config: 80k-if-math-coding-fncalling-stem
+ split_ratios:
+ train: 0.9
+ validation: 0.05
+ test: 0.05
+ allowed_call_types:
+ - assert
+ - std
+ max_examples_per_split: 2048
+ trust_remote_code: true
+ huggingface_token: ${oc.env:CODING_HF_TOKEN, null}
+
+train_dataset_names:
+ - math::open_reasoner_zero_57k
+ - math::open_reasoner_zero_extended_72k
+ - coding::coding@train
+
+test_dataset_names:
+ - math::aime_2024
+ - math::amc_2023
+ - math::math_500
+ - coding::coding@validation
+
+environments:
+ - key: math
+ mode: remote
+ replicas_per_actor: ${world.env_replicas_per_actor}
+ _target_: pipelinerl.domains.math.MathEnvironment
+ - key: coding
+ mode: remote
+ replicas_per_actor: ${world.env_replicas_per_actor}
+ _target_: pipelinerl.domains.coding.CodingSandboxEnvironment
+ sandbox_url: ${actor.coding_sandbox_url}
+ compile_timeout_s: ${actor.coding_compile_timeout_s}
+ run_timeout_s: ${actor.coding_per_test_timeout_s}
+ request_timeout_s: ${actor.coding_time_limit_s}
+ memory_limit_bytes: ${actor.coding_memory_limit_bytes}
+
+environment_key: null
+
+world:
+ env_replicas_per_actor: 1
diff --git a/conf/multi_domain/base.yaml b/conf/multi_domain/base.yaml
new file mode 100644
index 00000000..cda10faa
--- /dev/null
+++ b/conf/multi_domain/base.yaml
@@ -0,0 +1,92 @@
+# @package _global_
+defaults:
+ - /domain_rollouts@domain_rollouts: base
+ - domain_mix: null
+
+actor:
+ rollout_policy: pipelinerl.domains.dispatcher.generate_multidomain_rollout
+ system_prompt: ""
+ task_template: |-
+ {task}
+ task_prompt: ""
+ ensure_boxed_answers: false
+ domain_mix: null
+ # Domain-specific system prompts (used when actor.system_prompt is empty)
+ domain_system_prompts:
+ coding: |-
+ You are an expert Python programmer. When providing code solutions, format your final code inside markdown code blocks using triple backticks with the python language identifier, like this:
+ ```python
+ # your code here
+ ```
+ Provide complete, working implementations that pass all test cases.
+ fn_calling: ""
+ math: ""
+ domain_rollouts:
+ math: ${domain_rollouts.math}
+ guessing: ${domain_rollouts.guessing}
+ counting: ${domain_rollouts.counting}
+ chartqa: ${domain_rollouts.chartqa}
+ miniwob: ${domain_rollouts.miniwob}
+ coding: ${domain_rollouts.coding}
+ fn_calling: ${domain_rollouts.fn_calling}
+ coding_time_limit_s: 15.0
+ coding_per_test_timeout_s: 10.0
+ coding_memory_limit_bytes: 1073741824
+ coding_compile_timeout_s: 10.0
+ coding_sandbox_url: ${oc.env:CODING_SANDBOX_URL, "http://sandbox:8080/run_code"}
+
+dataset_loader: pipelinerl.domains.multidomain.loader.load_datasets
+dataset_loader_params:
+ per_domain_params:
+ coding:
+ dataset_id: ServiceNow-AI/mixed-training-text-datasets
+ dataset_config: 80k-if-math-coding-fncalling-stem
+ split_ratios:
+ train: 0.9
+ validation: 0.05
+ test: 0.05
+ allowed_call_types:
+ - assert
+ - std
+ max_examples_per_split: 2048
+ trust_remote_code: true
+ huggingface_token: ${oc.env:CODING_HF_TOKEN, null}
+ fn_calling:
+ dataset_id: ServiceNow-AI/mixed-training-text-datasets
+ dataset_config: 80k-if-math-coding-fncalling-stem
+ split_ratios:
+ train: 0.9
+ validation: 0.05
+ test: 0.05
+ allowed_call_types: []
+ max_examples_per_split: 2048
+ trust_remote_code: true
+ huggingface_token: ${oc.env:CODING_HF_TOKEN, null}
+
+train_dataset_names: []
+test_dataset_names: []
+
+environments:
+ - key: math
+ mode: remote
+ replicas_per_actor: ${world.env_replicas_per_actor}
+ _target_: pipelinerl.domains.math.MathEnvironment
+ - key: coding
+ mode: remote
+ replicas_per_actor: ${world.env_replicas_per_actor}
+ _target_: pipelinerl.domains.coding.CodingSandboxEnvironment
+ sandbox_url: ${actor.coding_sandbox_url}
+ compile_timeout_s: ${actor.coding_compile_timeout_s}
+ run_timeout_s: ${actor.coding_per_test_timeout_s}
+ request_timeout_s: ${actor.coding_time_limit_s}
+ memory_limit_bytes: ${actor.coding_memory_limit_bytes}
+ - key: fn_calling
+ mode: remote
+ replicas_per_actor: ${world.env_replicas_per_actor}
+ _target_: pipelinerl.domains.fn_calling.AgenticToolsEnvironment
+ max_workers: 4
+
+environment_key: null
+
+world:
+ env_replicas_per_actor: 1
diff --git a/conf/multi_domain/main_mix.yaml b/conf/multi_domain/main_mix.yaml
new file mode 100644
index 00000000..cae97fd6
--- /dev/null
+++ b/conf/multi_domain/main_mix.yaml
@@ -0,0 +1,10 @@
+defaults:
+ - base
+ - domain_mix: main_mix
+ - _self_
+
+actor:
+ domain_rollouts:
+ math: ${domain_rollouts.math}
+ coding: ${domain_rollouts.coding}
+ fn_calling: ${domain_rollouts.fn_calling}
diff --git a/conf/test.yaml b/conf/test.yaml
index b86dfa5d..de9399e5 100644
--- a/conf/test.yaml
+++ b/conf/test.yaml
@@ -3,7 +3,7 @@ defaults:
finetune:
seq_length: 4000
gradient_accumulation_passes: 6
- max_train_steps: 1
+ max_train_steps: 100
train_batch_size: 4
attempts: 4
llm:
diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py
index 1c238ff9..60e878c2 100644
--- a/pipelinerl/actor.py
+++ b/pipelinerl/actor.py
@@ -15,12 +15,13 @@
import aiohttp
import hydra
import uvloop
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel, Field
import wandb
-from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb
+from pipelinerl.domain_sampling import DomainWeightedSampler
from pipelinerl.finetune_loop import calculate_train_steps
+from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb
from pipelinerl.llm import TrainableLLM
from pipelinerl.rollouts import BaseMetrics, RolloutResult
from pipelinerl.shared_memory_array import SharedMemoryQueue
@@ -36,6 +37,7 @@
from .utils import (
always_or_never_success_stats,
calculate_stats,
+ resolve_environment_key,
setup_logging,
wait_for_environments,
wait_for_inference_servers,
@@ -173,7 +175,23 @@ async def rollout_and_maybe_produce_result(
llm = llms[llm_index]
model_version = trainer_state.propagated_weight_version
assert model_version is not None
+ domain_value: str | None = None
+ if isinstance(problem, dict):
+ raw_domain = problem.get("domain")
+ if raw_domain:
+ domain_value = str(raw_domain)
+ elif isinstance(problem, tuple) and len(problem) >= 2 and isinstance(problem[1], dict):
+ raw_domain = problem[1].get("domain")
+ if raw_domain:
+ domain_value = str(raw_domain)
+
+ if not domain_value:
+ resolved = resolve_environment_key(cfg)
+ domain_value = str(resolved) if resolved else None
+
rollout_result = await rollout_policy(cfg, llm, problem, session)
+ if domain_value and not rollout_result.domain:
+ rollout_result.domain = domain_value
rollout_result.model_version = model_version
# Make a group id that will be different from groups made by another rollout maker
full_group_id = f"{scheduler_name}_{group_id}"
@@ -348,6 +366,7 @@ def init_stats(self):
self.latency_list = []
self.model_versions_list = []
self.sliding_stats = defaultdict(list)
+ self.domain_counts = defaultdict(int)
def compute_domain_agnostic_metrics(self, result: RolloutResult) -> Dict[str, float]:
metrics = {}
@@ -367,6 +386,16 @@ def update_stats(self, rollout_results: List[RolloutResult]):
group_id = result.group_id
self.latency_list.append(result.latency)
self.model_versions_list.append(result.model_version)
+ domain_key: str | None = None
+ if getattr(result, "domain", None):
+ domain_key = str(result.domain)
+ elif isinstance(dataset_name, str):
+ domain_key = dataset_name.split("::", 1)[0]
+ elif dataset_name is not None:
+ domain_key = str(dataset_name)
+
+ if domain_key:
+ self.domain_counts[domain_key] += len(result.training_texts)
domain_agnostic_metrics = self.compute_domain_agnostic_metrics(result)
all_metrics = result.metrics.model_dump() | domain_agnostic_metrics
for k, v in all_metrics.items():
@@ -403,8 +432,15 @@ def run(self, dataset: list[tuple[str, dict]]):
# If training, we expect to sample infinitely
# for train sample, sample random batches infinitely
# for test samples, loop through the dataset once
+ domain_sampler = None
if self.is_training:
problem_iter = random_iter(dataset)
+ domain_mix_cfg = getattr(self.cfg.actor, "domain_mix", None)
+ if domain_mix_cfg:
+ mix_weights = OmegaConf.to_container(domain_mix_cfg, resolve=True)
+ if not isinstance(mix_weights, dict):
+ raise ValueError("actor.domain_mix must be a mapping from domain to weight")
+ domain_sampler = DomainWeightedSampler(dataset, mix_weights)
else:
problem_iter = sequential_iter(dataset)
assert self.trainer_state.propagated_weight_version is not None
@@ -466,7 +502,10 @@ def run(self, dataset: list[tuple[str, dict]]):
if not blocked_by_lag and not self.problem_queue.full():
try:
try:
- problem = next(problem_iter)
+ if domain_sampler is not None:
+ problem = domain_sampler.sample()
+ else:
+ problem = next(problem_iter)
self.problem_queue.put(problem, block=False)
submitted_groups += 1
except queue.Full:
@@ -491,6 +530,12 @@ def run(self, dataset: list[tuple[str, dict]]):
assert isinstance(rollout_results[0], RolloutResult)
group_samples = sum(len(r.training_texts) for r in rollout_results)
+ # Track completions per domain for adaptive sampling
+ if domain_sampler is not None:
+ for r in rollout_results:
+ if r.domain:
+ domain_sampler.record_completion(r.domain)
+
published_samples += group_samples
samples_in_queue = self.result_queue.qsize() * attempts
all_text_dumps = []
@@ -570,6 +615,25 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict):
)
stats |= loop_stats
+
+ total_domain_samples = sum(self.domain_counts.values())
+ if total_domain_samples:
+ for domain, count in sorted(self.domain_counts.items()):
+ stats[f"{split_name}domain_mix_count/{domain}"] = count
+ stats[f"{split_name}domain_mix_actual/{domain}"] = count / total_domain_samples
+
+ domain_mix_cfg = getattr(self.cfg.actor, "domain_mix", None)
+ if domain_mix_cfg:
+ mix_weights = OmegaConf.to_container(domain_mix_cfg, resolve=True)
+ if isinstance(mix_weights, dict):
+ target_total = sum(float(v) for v in mix_weights.values() if float(v) > 0)
+ if target_total > 0:
+ for domain, weight in mix_weights.items():
+ stats[f"{split_name}domain_mix_target/{domain}"] = float(weight) / target_total
+ else:
+ for domain in mix_weights:
+ stats[f"{split_name}domain_mix_target/{domain}"] = 0.0
+
for k, v in self.sliding_stats.items():
stats[k] = sum(v) / len(v) if v else 0
if self.cfg.wandb.use_wandb:
diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py
index 4e78ebf9..33391ed3 100644
--- a/pipelinerl/async_llm.py
+++ b/pipelinerl/async_llm.py
@@ -15,6 +15,18 @@
logger = logging.getLogger(__name__)
+def _to_object(obj):
+ """Recursively convert OmegaConf objects to plain Python containers."""
+ if isinstance(obj, (DictConfig, ListConfig)):
+ return OmegaConf.to_container(obj, resolve=True)
+ elif isinstance(obj, dict):
+ return {k: _to_object(v) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ return [_to_object(v) for v in obj]
+ else:
+ return obj
+
+
def extract_images_from_messages(messages: list[dict]) -> list[Image.Image]:
"""Extract PIL Images from multimodal messages."""
@@ -44,18 +56,6 @@ def extract_images_from_messages(messages: list[dict]) -> list[Image.Image]:
return images
-def _to_plain_obj(value):
- """convert OmegaConf containers into Python types"""
-
- if isinstance(value, (DictConfig, ListConfig)):
- return OmegaConf.to_container(value, resolve=True)
- if isinstance(value, dict):
- return {key: _to_plain_obj(val) for key, val in value.items()}
- if isinstance(value, (list, tuple)):
- return [_to_plain_obj(item) for item in value]
- return value
-
-
async def llm_async_generate(
llm: TrainableLLM, prompt: Prompt, session: aiohttp.ClientSession
) -> LLMCall:
@@ -88,7 +88,9 @@ async def llm_async_generate(
logger.debug(f"POST request to {llm.base_url}/v1/chat/completions")
- payload = _to_plain_obj({**data, **extra_parameters})
+ payload = {**data, **extra_parameters}
+ # TODO: upgrade omegaconf and use OmegaConf.to_object for recursive conversion
+ payload = _to_object(payload)
async with session.post(
url=f"{llm.base_url}/v1/chat/completions",
json=payload,
@@ -101,6 +103,7 @@ async def llm_async_generate(
response.raise_for_status()
data = await response.json()
+ finish_reason: str | None = None
try:
content = data["choices"][0]["message"]["content"]
if not content:
diff --git a/pipelinerl/domain_sampling.py b/pipelinerl/domain_sampling.py
new file mode 100644
index 00000000..497bb0f8
--- /dev/null
+++ b/pipelinerl/domain_sampling.py
@@ -0,0 +1,170 @@
+from __future__ import annotations
+
+import logging
+import random
+from collections import defaultdict
+from typing import Mapping
+
+logger = logging.getLogger(__name__)
+
+# Minimum completions before dynamic adjustment kicks in
+_MIN_COMPLETIONS_FOR_ADJUSTMENT = 50
+# Clamp adjustment factors to avoid extreme swings
+_MIN_ADJUSTMENT = 0.1
+_MAX_ADJUSTMENT = 10.0
+
+
+class DomainWeightedSampler:
+ """Randomly samples problems according to per-domain weights.
+
+ Supports dynamic weight adjustment based on completion tracking to maintain
+ target domain ratios in the output stream despite varying processing speeds.
+ """
+
+ def __init__(
+ self,
+ samples: list[dict],
+ weights: Mapping[str, float],
+ rng: random.Random | None = None,
+ adaptive: bool = True,
+ ):
+ if not weights:
+ raise ValueError("domain_mix cannot be empty when provided")
+ self.random = rng or random
+ self.adaptive = adaptive
+ samples_by_domain: dict[str, list[dict]] = defaultdict(list)
+ for sample in samples:
+ domain = sample.get("domain")
+ if not domain:
+ raise ValueError("Each sample must include a 'domain' field for domain_mix to work")
+ samples_by_domain[str(domain)].append(sample)
+
+ provided_domains = {str(domain) for domain in weights}
+ cleaned_weights: dict[str, float] = {}
+ for domain, value in weights.items():
+ val = float(value)
+ if val < 0:
+ raise ValueError(f"domain_mix weight for '{domain}' must be non-negative")
+ if val == 0:
+ continue
+ cleaned_weights[str(domain)] = val
+
+ if not cleaned_weights:
+ raise ValueError("domain_mix must include at least one positive weight")
+
+    # Zero weights are accepted (those domains are simply never sampled), but every
+    # domain present in the dataset must still be declared in domain_mix.
+ missing = set(samples_by_domain) - provided_domains
+ if missing:
+ missing_list = ", ".join(sorted(missing))
+ raise ValueError(
+ "domain_mix is missing weights for dataset domains: " + missing_list
+ )
+
+ unused = provided_domains - set(samples_by_domain)
+ if unused:
+ unused_list = ", ".join(sorted(unused))
+ raise ValueError(
+ "domain_mix specifies domains not present in dataset: " + unused_list
+ )
+
+ self.samples_by_domain = samples_by_domain
+ self.domains: list[str] = []
+ self.base_weights: dict[str, float] = {}
+ self.thresholds: list[float] = []
+ total = 0.0
+ for domain, weight in cleaned_weights.items():
+ total += weight
+ self.domains.append(domain)
+ self.base_weights[domain] = weight
+ self.thresholds.append(total)
+ if total <= 0:
+ raise ValueError("Sum of domain_mix weights must be positive")
+ self.total_weight = total
+
+ # Target ratios (normalized weights)
+ self.target_ratios = {d: w / total for d, w in self.base_weights.items()}
+
+ # Completion tracking for adaptive sampling
+ self.completion_counts: dict[str, int] = {d: 0 for d in self.domains}
+ self.total_completions = 0
+ self._last_log_completions = 0
+
+ def record_completion(self, domain: str) -> None:
+ """Record that a sample from the given domain has completed processing.
+
+ This enables adaptive weight adjustment to maintain target domain ratios
+ in the output stream despite varying processing speeds per domain.
+ """
+ if domain in self.completion_counts:
+ self.completion_counts[domain] += 1
+ self.total_completions += 1
+
+ # Log periodically
+ if self.total_completions - self._last_log_completions >= 500:
+ self._log_domain_stats()
+ self._last_log_completions = self.total_completions
+
+ def _log_domain_stats(self) -> None:
+ """Log current domain distribution vs targets."""
+ if self.total_completions == 0:
+ return
+ parts = []
+ for domain in self.domains:
+ actual = self.completion_counts[domain] / self.total_completions
+ target = self.target_ratios[domain]
+ parts.append(f"{domain}={actual:.1%}(target={target:.1%})")
+ logger.info(f"Domain completion stats ({self.total_completions} total): {', '.join(parts)}")
+
+ def _pick_domain_static(self) -> str:
+ """Pick domain using static weights (original behavior)."""
+ r = self.random.random() * self.total_weight
+ for domain, threshold in zip(self.domains, self.thresholds):
+ if r < threshold:
+ return domain
+ return self.domains[-1]
+
+ def _pick_domain_adaptive(self) -> str:
+ """Pick domain using dynamically adjusted weights based on completion ratios."""
+ # Calculate current completion ratios
+ current_ratios = {
+ d: self.completion_counts[d] / self.total_completions
+ for d in self.domains
+ }
+
+ # Calculate adjusted weights: boost under-represented, reduce over-represented
+ adjusted_weights: dict[str, float] = {}
+ for domain in self.domains:
+ target = self.target_ratios[domain]
+ current = current_ratios[domain]
+
+ if current > 0:
+ # adjustment = target / current
+ # If current=46%, target=30% → adjustment=0.65 (sample less)
+ # If current=18%, target=30% → adjustment=1.67 (sample more)
+ adjustment = target / current
+ adjustment = max(_MIN_ADJUSTMENT, min(_MAX_ADJUSTMENT, adjustment))
+ else:
+ # No completions yet for this domain, boost sampling
+ adjustment = _MAX_ADJUSTMENT
+
+ adjusted_weights[domain] = self.base_weights[domain] * adjustment
+
+ # Sample based on adjusted weights
+ total = sum(adjusted_weights.values())
+ r = self.random.random() * total
+ cumsum = 0.0
+ for domain in self.domains:
+ cumsum += adjusted_weights[domain]
+ if r < cumsum:
+ return domain
+ return self.domains[-1]
+
+ def _pick_domain(self) -> str:
+ """Pick a domain for the next sample."""
+ if not self.adaptive or self.total_completions < _MIN_COMPLETIONS_FOR_ADJUSTMENT:
+ return self._pick_domain_static()
+ return self._pick_domain_adaptive()
+
+ def sample(self) -> dict:
+ domain = self._pick_domain()
+ return self.random.choice(self.samples_by_domain[domain])
diff --git a/pipelinerl/domains/__init__.py b/pipelinerl/domains/__init__.py
index e69de29b..f73dd529 100644
--- a/pipelinerl/domains/__init__.py
+++ b/pipelinerl/domains/__init__.py
@@ -0,0 +1,8 @@
+"""Domain utilities and rollouts."""
+
+from .dispatcher import generate_multidomain_rollout, register_domain_rollout
+
+__all__ = [
+ "generate_multidomain_rollout",
+ "register_domain_rollout",
+]
diff --git a/pipelinerl/domains/chartqa/load_datasets.py b/pipelinerl/domains/chartqa/load_datasets.py
index 49dc92f7..8265086a 100644
--- a/pipelinerl/domains/chartqa/load_datasets.py
+++ b/pipelinerl/domains/chartqa/load_datasets.py
@@ -5,6 +5,8 @@
logger = logging.getLogger(__name__)
+DOMAIN = "chartqa"
+
def process_chartqa(dataset, dataset_name: str):
"""Process ChartQA dataset into standardized format."""
@@ -31,6 +33,7 @@ def add_ids(dataset: list[dict]):
"""Add sequential IDs to dataset items."""
for i, entry in enumerate(dataset):
entry["id"] = i
+ entry.setdefault("domain", DOMAIN)
return dataset
diff --git a/pipelinerl/domains/coding/__init__.py b/pipelinerl/domains/coding/__init__.py
new file mode 100644
index 00000000..f04984b7
--- /dev/null
+++ b/pipelinerl/domains/coding/__init__.py
@@ -0,0 +1,17 @@
+"""Coding domain rollouts and dataset utilities."""
+
+from .rollouts import generate_coding_rollout
+from .verifier_api import (
+ CodingSandboxEnvironment,
+ evaluate_coding_prediction,
+ verify_coding_solution_rpc,
+)
+from .dataset import load_problems
+
+__all__ = [
+ "CodingSandboxEnvironment",
+ "evaluate_coding_prediction",
+ "generate_coding_rollout",
+ "load_problems",
+ "verify_coding_solution_rpc",
+]
diff --git a/pipelinerl/domains/coding/dataset.py b/pipelinerl/domains/coding/dataset.py
new file mode 100644
index 00000000..e8b1dbbc
--- /dev/null
+++ b/pipelinerl/domains/coding/dataset.py
@@ -0,0 +1,383 @@
+from __future__ import annotations
+
+import json
+import logging
+import math
+import random
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, Iterable, List, Sequence
+
+import pyarrow as pa
+import pyarrow.parquet as pq
+from datasets import Dataset, DownloadMode, load_dataset
+from datasets.exceptions import DatasetGenerationError
+from huggingface_hub import snapshot_download
+from omegaconf import DictConfig, OmegaConf
+
+logger = logging.getLogger(__name__)
+
+DOMAIN_NAME = "coding"
+BASE_DATASET_NAME = "mixed-training-text-datasets"
+DATASET_ALIASES = frozenset({DOMAIN_NAME, "code"})
+SUPPORTED_DATASET_NAMES = frozenset({BASE_DATASET_NAME})
+DEFAULT_DATASET_ID = "ServiceNow-AI/mixed-training-text-datasets"
+DEFAULT_DATASET_CONFIG = "80k-if-math-coding-fncalling-stem"
+DEFAULT_SPLIT_ORDER = ("train", "validation", "test")
+DEFAULT_SPLIT_RATIOS = tuple((name, ratio) for name, ratio in zip(DEFAULT_SPLIT_ORDER, (0.9, 0.05, 0.05)))
+DEFAULT_CALL_TYPES = ("assert", "std")
+
+
+@dataclass(frozen=True)
+class DatasetSpec:
+ name: str
+ split: str = "train"
+
+
+@dataclass(frozen=True)
+class ResolvedDatasetEntry:
+ spec: DatasetSpec
+ label: str
+
+
+@dataclass(frozen=True)
+class DatasetResolution:
+ requested: list[str]
+ entries: list[ResolvedDatasetEntry]
+
+
+@dataclass
+class DatasetOptions:
+ dataset_id: str = DEFAULT_DATASET_ID
+ dataset_config: str | None = DEFAULT_DATASET_CONFIG
+ split_ratios: Sequence[tuple[str, float]] = DEFAULT_SPLIT_RATIOS
+ trust_remote_code: bool = True
+ max_examples_per_split: int | None = None
+ allowed_call_types: Sequence[str] = DEFAULT_CALL_TYPES
+ huggingface_token: str | None = None
+ ability_filter: str = "code"
+
+
+CacheKey = tuple[str, str | None, bool, str | None]
+_DATASET_CACHE: dict[CacheKey, Dataset] = {}
+
+
+def _normalize_loader_options(loader_kwargs: Dict[str, Any]) -> DatasetOptions:
+ def _to_native(value: Any) -> Any:
+ if isinstance(value, DictConfig):
+ return OmegaConf.to_container(value, resolve=True)
+ return value
+
+ options = DatasetOptions()
+ if loader_kwargs:
+ raw_ratios = _to_native(loader_kwargs.get("split_ratios"))
+ if isinstance(raw_ratios, dict):
+ options.split_ratios = tuple(raw_ratios.items())
+ dataset_config = loader_kwargs.get("dataset_config")
+ if dataset_config:
+ options.dataset_config = str(dataset_config)
+ dataset_id = loader_kwargs.get("dataset_id")
+ if dataset_id:
+ options.dataset_id = str(dataset_id)
+ if "max_examples_per_split" in loader_kwargs:
+ value = loader_kwargs["max_examples_per_split"]
+ options.max_examples_per_split = int(value) if value is not None else None
+ if "trust_remote_code" in loader_kwargs:
+ options.trust_remote_code = bool(loader_kwargs["trust_remote_code"])
+ call_types = _to_native(loader_kwargs.get("allowed_call_types"))
+ if isinstance(call_types, Iterable) and not isinstance(call_types, (str, bytes)):
+ options.allowed_call_types = tuple(str(item) for item in call_types)
+ token = loader_kwargs.get("huggingface_token") or loader_kwargs.get("hf_token")
+ if token:
+ options.huggingface_token = str(token)
+ ability = loader_kwargs.get("ability_filter") or loader_kwargs.get("ability")
+ if ability:
+ options.ability_filter = str(ability)
+ return options
+
+
+def _ability_matches(value: Any, ability: str | None) -> bool:
+ """Return True when the sample's ability field contains the requested ability.
+
+ The upstream dataset sometimes stores ability as a single string ("code") and
+ sometimes as a list (e.g., ["agentic_fn_calling"]). Filtering must accept both
+ shapes or we drop all fn_calling samples.
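+
+    For example (illustrative inputs)::
+
+        _ability_matches("code", "code")                                # True
+        _ability_matches(["agentic_fn_calling"], "agentic_fn_calling")  # True
+        _ability_matches(None, "code")                                  # False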
+ """
+
+ if ability is None:
+ return True
+ if value is None:
+ return False
+ if isinstance(value, str):
+ return value == ability
+ if isinstance(value, (list, tuple, set)):
+ return ability in value
+ return False
+
+
+def parse_dataset_name(entry: str) -> DatasetSpec:
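+    """Parse a dataset entry of the form 'name@split' (e.g. 'coding@train'); the split defaults to 'train'."""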
+ text = entry.strip()
+ if "@" not in text:
+ return DatasetSpec(name=text)
+ name, split = text.split("@", 1)
+ return DatasetSpec(name=name.strip(), split=split.strip() or "train")
+
+
+def _normalize_dataset_names(dataset_names: List[str] | str | None) -> list[str]:
+ if dataset_names is None:
+ return []
+ if isinstance(dataset_names, str):
+ return [dataset_names]
+ return [str(entry) for entry in dataset_names]
+
+
+def _resolve_dataset_requests(dataset_names: List[str] | str | None) -> DatasetResolution:
+ requested = _normalize_dataset_names(dataset_names)
+ entries: list[ResolvedDatasetEntry] = []
+ for entry in requested:
+ spec = parse_dataset_name(entry)
+ base_name = BASE_DATASET_NAME if spec.name in DATASET_ALIASES else spec.name
+ if base_name not in SUPPORTED_DATASET_NAMES:
+ raise ValueError(f"Unsupported coding dataset '{spec.name}'")
+ resolved_spec = DatasetSpec(name=base_name, split=spec.split)
+ entries.append(ResolvedDatasetEntry(spec=resolved_spec, label=f"{spec.name}@{spec.split}"))
+ return DatasetResolution(requested=requested, entries=entries)
+
+
+def _load_dataset(options: DatasetOptions) -> Dataset:
+ ability = options.ability_filter
+ cache_key: CacheKey = (
+ options.dataset_id,
+ options.dataset_config,
+ options.trust_remote_code,
+ ability,
+ )
+ if cache_key in _DATASET_CACHE:
+ return _DATASET_CACHE[cache_key]
+
+ def _materialize_dataset(**extra_kwargs: Any) -> Dataset:
+ return load_dataset(
+ options.dataset_id,
+ options.dataset_config,
+ split="train",
+ trust_remote_code=options.trust_remote_code,
+ token=options.huggingface_token,
+ **extra_kwargs,
+ )
+
+ try:
+ ds = _materialize_dataset()
+ except DatasetGenerationError as exc:
+ logger.warning(
+ "load_dataset failed for %s (%s): %s. Forcing re-download.",
+ options.dataset_id,
+ options.dataset_config,
+ exc,
+ )
+ try:
+ ds = _materialize_dataset(download_mode=DownloadMode.FORCE_REDOWNLOAD)
+ except DatasetGenerationError as redownload_exc:
+ logger.warning(
+ "Forced re-download also failed for %s (%s): %s. Falling back to streaming mode.",
+ options.dataset_id,
+ options.dataset_config,
+ redownload_exc,
+ )
+ stream = _materialize_dataset(streaming=True)
+ try:
+ ds = Dataset.from_list(list(stream))
+ except OSError as stream_exc:
+ logger.warning(
+ "Streaming fallback also failed for %s (%s): %s. Downloading snapshot locally.",
+ options.dataset_id,
+ options.dataset_config,
+ stream_exc,
+ )
+ ds = _load_snapshot(options)
+
+ if ability:
+ ds = ds.filter(lambda sample: _ability_matches(sample.get("ability"), ability))
+ _DATASET_CACHE[cache_key] = ds
+ logger.info(
+ "Loaded %s (%s) with %d coding samples",
+ options.dataset_id,
+ options.dataset_config,
+ len(ds),
+ )
+ return ds
+
+
+def _load_snapshot(options: DatasetOptions) -> Dataset:
+ if not options.dataset_config:
+ raise RuntimeError("Snapshot fallback requires a dataset_config but none was provided.")
+
+ snapshot_dir = snapshot_download(
+ repo_id=options.dataset_id,
+ repo_type="dataset",
+ token=options.huggingface_token,
+ allow_patterns=(f"{options.dataset_config}/*", "dataset_infos.json"),
+ )
+ config_dir = Path(snapshot_dir) / options.dataset_config
+ parquet_files = sorted(config_dir.glob("train-*.parquet"))
+ if not parquet_files:
+ raise RuntimeError(
+ f"Snapshot for {options.dataset_id} ({options.dataset_config}) contained no train parquet shards"
+ )
+
+ tables: list[pa.Table] = []
+ for shard in parquet_files:
+ try:
+ tables.append(pq.read_table(shard))
+ except Exception as exc: # pragma: no cover - defensive fallback
+ logger.warning("Skipping corrupted parquet shard %s: %s", shard.name, exc)
+ if not tables:
+ raise RuntimeError(
+ f"All locally downloaded shards failed to load for {options.dataset_id} ({options.dataset_config})."
+ )
+
+ table = pa.concat_tables(tables)
+ logger.info(
+ f"Loaded {table.num_rows} rows via snapshot fallback ({len(tables)} shards)",
+ )
+ return Dataset.from_table(table)
+
+
+def _normalized_split_sequence(ratios: Sequence[tuple[str, float]]) -> list[tuple[str, float]]:
+ if not ratios:
+ return [("train", 1.0)]
+ values = [(name, float(portion)) for name, portion in ratios if float(portion) > 0]
+ total = sum(portion for _, portion in values)
+ if math.isclose(total, 0.0):
+ return [("train", 1.0)]
+ return [(name, portion / total) for name, portion in values]
+
+
+def _slice_indices(count: int, split: str, ratios: Sequence[tuple[str, float]], seed: int | None) -> list[int]:
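+    """Deterministically shuffle all indices (seeded) and return the contiguous slice assigned to `split`."""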
+ normalized = _normalized_split_sequence(ratios)
+ if split not in {name for name, _ in normalized}:
+ raise ValueError(f"Split '{split}' is not defined in split_ratios {normalized}")
+ indices = list(range(count))
+ rng = random.Random(seed or 0)
+ rng.shuffle(indices)
+ cumulative = 0.0
+ start = 0
+ selected: dict[str, list[int]] = {}
+ for idx, (name, portion) in enumerate(normalized):
+ cumulative += portion
+ end = count if idx == len(normalized) - 1 else min(count, int(round(cumulative * count)))
+ selected[name] = indices[start:end]
+ start = end
+ return selected.get(split, [])
+
+
+def _decode_extra_info(raw_extra: Any) -> dict[str, Any]:
+ if isinstance(raw_extra, dict):
+ return raw_extra
+ if isinstance(raw_extra, str) and raw_extra.strip():
+ try:
+ return json.loads(raw_extra)
+ except json.JSONDecodeError:
+ logger.debug("Failed to decode extra_info: %s", raw_extra[:128])
+ return {}
+
+
+def _build_record(sample: dict, dataset_label: str, allowed_call_types: Sequence[str]) -> dict | None:
+ reward_model = sample.get("reward_model") or {}
+ reward_raw = reward_model.get("ground_truth")
+ if reward_raw is None:
+ return None
+ try:
+ reward_context = json.loads(reward_raw)
+ except (TypeError, json.JSONDecodeError):
+ return None
+ if allowed_call_types and reward_context.get("call_type") not in set(allowed_call_types):
+ return None
+
+ prompt_messages = sample.get("prompt") or []
+ if not prompt_messages:
+ return None
+ task = prompt_messages[0].get("content")
+ if not task:
+ return None
+
+ extra_info = _decode_extra_info(sample.get("extra_info"))
+ record = {
+ "dataset": dataset_label,
+ "task": task,
+ "reward_context": reward_context,
+ "extra_info": extra_info,
+ }
+
+ # For fn_calling samples, keep the full prompt so multi-turn tool-calling
+ # instructions are available downstream.
+ if sample.get("ability") == "agentic_fn_calling":
+ if len(prompt_messages) == 1 and isinstance(prompt_messages[0], list):
+ record["prompt"] = prompt_messages[0]
+ else:
+ record["prompt"] = prompt_messages
+
+ return record
+
+
+def _load_split(
+ spec: DatasetSpec,
+ *,
+ options: DatasetOptions,
+ seed: int | None,
+ dataset_label: str | None = None,
+) -> list[dict]:
+ dataset = _load_dataset(options)
+ indices = _slice_indices(len(dataset), spec.split, options.split_ratios, seed)
+ if not indices:
+ logger.warning("Requested split '%s' produced zero samples", spec.split)
+ return []
+ subset = dataset.select(indices)
+ samples: list[dict] = []
+ label = dataset_label or f"{spec.name}@{spec.split}"
+ for sample in subset:
+ record = _build_record(sample, label, options.allowed_call_types)
+ if record is None:
+ continue
+ samples.append(record)
+ if options.max_examples_per_split and len(samples) >= options.max_examples_per_split:
+ break
+ logger.info(
+ "Loaded %d samples for %s",
+ len(samples),
+ label,
+ )
+ return samples
+
+
+def _attach_ids(samples: list[dict]) -> list[dict]:
+ for idx, sample in enumerate(samples):
+ sample["id"] = idx
+ return samples
+
+
+def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None, **loader_kwargs: Any) -> List[Dict]:
+ resolution = _resolve_dataset_requests(dataset_names)
+ if not resolution.entries:
+ if dataset_names:
+ logger.warning("No coding dataset entries were resolved for %s", dataset_names)
+ return []
+
+ options = _normalize_loader_options(loader_kwargs)
+ aggregated: list[dict] = []
+ for entry in resolution.entries:
+ aggregated.extend(
+ _load_split(entry.spec, options=options, seed=seed, dataset_label=entry.label),
+ )
+
+ if not aggregated:
+ logger.warning("No coding datasets were loaded for entries %s", resolution.requested)
+
+ return _attach_ids(aggregated)
+
+
+def load_problems(dataset_names: List[str] | str | None, **loader_kwargs: Any) -> List[Dict]:
+ """Hydra entrypoint that mirrors the math domain loader style."""
+
+ seed = loader_kwargs.pop("seed", None)
+ return load_datasets(dataset_names, seed=seed, **loader_kwargs)
\ No newline at end of file
diff --git a/pipelinerl/domains/coding/rollouts.py b/pipelinerl/domains/coding/rollouts.py
new file mode 100644
index 00000000..e30060ef
--- /dev/null
+++ b/pipelinerl/domains/coding/rollouts.py
@@ -0,0 +1,218 @@
+"""Rollout generation for the coding domain using the sandbox verifier."""
+
+from __future__ import annotations
+
+import json
+import random
+import time
+from typing import Any, Literal
+
+import aiohttp
+from omegaconf import DictConfig
+
+from pipelinerl.llm import Prompt, TrainableLLM
+from pipelinerl.async_llm import llm_async_generate, make_training_text
+from pipelinerl.rollouts import BaseMetrics, RolloutResult
+from pipelinerl.utils import get_environment_jobs, resolve_environment_key
+from pipelinerl.domains.math.rollouts import RewardTable, length_penalty
+
+from .verifier_api import verify_coding_solution_rpc
+
+
+class CodingMetrics(BaseMetrics):
+ compile_error: bool = False
+ runtime_error: bool = False
+ timeout_error: bool = False
+ passed: int = 0
+ total: int = 0
+
+
+def _format_task(problem: dict[str, Any]) -> str:
+ if "task" in problem and problem["task"]:
+ return str(problem["task"])
+ if "question" in problem and problem["question"]:
+ return str(problem["question"])
+ extra_info = problem.get("extra_info") or {}
+ if isinstance(extra_info, dict) and extra_info.get("question"):
+ return str(extra_info["question"])
+ return str(problem)
+
+
+def _determine_answer_status(verification: dict[str, Any]) -> Literal["correct", "wrong", "no_answer", "unparsable"]:
+ if verification.get("empty_response"):
+ return "no_answer"
+ if verification.get("compile_error") or verification.get("timeout_error"):
+ return "unparsable"
+ total = int(verification.get("total") or 0)
+ passed = int(verification.get("passed") or 0)
+ if total > 0 and passed == total:
+ return "correct"
+ return "wrong"
+
+
+def _compute_reward(
+ cfg: DictConfig,
+ rewards: RewardTable,
+ *,
+ answer_status: Literal["correct", "wrong", "no_answer", "unparsable"],
+ finished: bool,
+ output_tokens: int,
+ max_tokens: int | None,
+) -> tuple[float, float]:
+ match (answer_status, finished):
+ case ("wrong", False):
+ reward = rewards.wrong_answer_not_finished
+ case ("wrong", True):
+ reward = rewards.wrong_answer_finished
+ case ("no_answer", False):
+ reward = rewards.no_answer_not_finished
+ case ("no_answer", True):
+ reward = rewards.no_answer_finished
+ case ("unparsable", False):
+ reward = rewards.unparsable_not_finished
+ case ("unparsable", True):
+ reward = rewards.unparsable_finished
+ case ("correct", False):
+ reward = rewards.correct_answer_not_finished
+ case ("correct", True):
+ reward = rewards.correct_answer_finished
+ case _:
+ raise ValueError(f"Invalid answer_status/finished combination: {answer_status}/{finished}")
+
+ reward *= cfg.actor.discount_factor ** output_tokens
+ overlong_penalty = 0.0
+ if rewards.buffer_tokens and max_tokens is not None:
+ overlong_penalty = length_penalty(max_tokens, output_tokens, rewards.buffer_tokens)
+ reward += overlong_penalty
+ return reward, overlong_penalty
+
+
+async def _run_verification(
+ cfg: DictConfig,
+ *,
+ session: aiohttp.ClientSession,
+ prediction: str | None,
+ reward_context: dict[str, Any] | str | None,
+ extra_info: dict[str, Any] | str | None,
+) -> dict[str, Any]:
+ env_key = resolve_environment_key(cfg, default="coding")
+ env_jobs = get_environment_jobs(cfg, env_key)
+ if not env_jobs:
+ raise RuntimeError("No coding environment servers registered")
+ env_job = random.choice(env_jobs)
+ if env_job.hostname is None or env_job.port is None:
+ raise RuntimeError("Coding environment job is missing host/port information")
+ return await verify_coding_solution_rpc(
+ session,
+ host=env_job.hostname,
+ port=env_job.port,
+ prediction=prediction,
+ reward_context=reward_context,
+ extra_info=extra_info,
+ )
+
+
+def _coerce_dict(data: dict[str, Any] | str | None) -> dict[str, Any]:
+ if data is None:
+ return {}
+ if isinstance(data, dict):
+ return data
+ try:
+ return json.loads(data)
+ except Exception:
+ return {}
+
+
+def _get_system_prompt(cfg: DictConfig) -> str:
+ """Get the system prompt, preferring domain-specific prompt if global is empty."""
+ if cfg.actor.system_prompt:
+ return cfg.actor.system_prompt
+ # Fall back to domain-specific system prompt
+ domain_prompts = getattr(cfg.actor, "domain_system_prompts", None)
+ if domain_prompts:
+ return domain_prompts.get("coding", "") or ""
+ return ""
+
+
+async def generate_coding_rollout(
+ cfg: DictConfig,
+ llm: TrainableLLM,
+ problem: dict[str, Any],
+ session: aiohttp.ClientSession,
+) -> RolloutResult:
+ messages = []
+ system_prompt = _get_system_prompt(cfg)
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+ user_task = cfg.actor.task_template.format(task=_format_task(problem))
+ messages.append({"role": "user", "content": user_task})
+ prompt = Prompt(messages=messages)
+
+ start_time = time.time()
+ llm_call = await llm_async_generate(llm, prompt, session)
+ latency = time.time() - start_time
+ assert llm_call.output.content is not None
+ trace = make_training_text(llm, llm_call)
+
+ reward_context = _coerce_dict(problem.get("reward_context"))
+ extra_info = _coerce_dict(problem.get("extra_info"))
+ verification = await _run_verification(
+ cfg,
+ session=session,
+ prediction=llm_call.output.content,
+ reward_context=reward_context,
+ extra_info=extra_info,
+ )
+
+ rewards_cfg = RewardTable(**dict(cfg.rewards))
+ answer_status = _determine_answer_status(verification)
+ try:
+ max_tokens = llm.parameters["max_tokens"]
+ except (KeyError, TypeError):
+ max_tokens = None
+ reward, _ = _compute_reward(
+ cfg,
+ rewards_cfg,
+ answer_status=answer_status,
+ finished=trace.finished,
+ output_tokens=llm_call.output_length_tokens,
+ max_tokens=max_tokens,
+ )
+ trace.reward = reward
+
+ coding_metadata = {
+ "passed": verification.get("passed"),
+ "total": verification.get("total"),
+ "compile_error": verification.get("compile_error"),
+ "runtime_error": verification.get("runtime_error"),
+ "timeout_error": verification.get("timeout_error"),
+ "empty_response": verification.get("empty_response"),
+ "call_type": verification.get("call_type"),
+ "fn_name": verification.get("fn_name"),
+ "error": verification.get("error"),
+ "tests": (verification.get("tests") or [])[:5],
+ }
+ trace.metadata.setdefault("coding", {}).update(coding_metadata)
+
+ metrics = CodingMetrics(
+ reward=reward,
+ success=(verification.get("passed") == verification.get("total") and verification.get("total", 0) > 0),
+ no_error=not (
+ verification.get("compile_error")
+ or verification.get("runtime_error")
+ or verification.get("timeout_error")
+ ),
+ no_answer=bool(verification.get("empty_response")),
+ compile_error=bool(verification.get("compile_error")),
+ runtime_error=bool(verification.get("runtime_error")),
+ timeout_error=bool(verification.get("timeout_error")),
+ passed=int(verification.get("passed") or 0),
+ total=int(verification.get("total") or 0),
+ )
+
+ return RolloutResult(
+ training_texts=[trace],
+ metrics=metrics,
+ latency=latency,
+ dataset_name=problem.get("dataset"),
+ )
diff --git a/pipelinerl/domains/coding/test_sandbox.py b/pipelinerl/domains/coding/test_sandbox.py
new file mode 100644
index 00000000..2b7ef1d8
--- /dev/null
+++ b/pipelinerl/domains/coding/test_sandbox.py
@@ -0,0 +1,158 @@
+import requests
+import json
+
+fn_name = "add_one"
+generation = """
+def add_one(x):
+ return x + 1
+
+if __name__ == '__main__':
+ # Get input from stdin
+ import sys
+ input_value = int(sys.stdin.read().strip())
+ result = add_one(input_value)
+ print(result)
+"""
+
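+# NOTE: the assert-style example below redefines fn_name/generation; only this
+# second version is embedded in wrapper_code and sent to the sandbox.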
+fn_name = "add_one"
+generation = """
+def add_one(x):
+ return x + 1
+
+assert add_one(1) == 2
+assert add_one(5) == 8
+assert add_one(-3) == -2
+"""
+
+wrapper_code = f"""
+import traceback
+from string import *
+from re import *
+from datetime import *
+from collections import *
+from heapq import *
+from bisect import *
+from copy import *
+from math import *
+from random import *
+from statistics import *
+from itertools import *
+from functools import *
+from operator import *
+from io import *
+from sys import *
+from json import *
+from builtins import *
+from typing import *
+import string
+import re
+import datetime
+import collections
+import heapq
+import bisect
+import copy
+import math
+import random
+import statistics
+import itertools
+import functools
+import operator
+import io
+import sys
+import json
+
+# === User's Original Code START ===
+{generation}
+# === User's Original Code END ===
+
+_SANDBOX_FN_NAME = "{fn_name}"
+
+def _execute_user_function():
+ # --- Input Parsing ---
+ _raw_input_str = sys.stdin.read()
+ _args = []
+ if _raw_input_str.strip(): # If there's input
+ try:
+ _args = [json.loads(line) for line in _raw_input_str.split('\\n')]
+ except json.JSONDecodeError as _je:
+ sys.stderr.write(f"WrapperError: Invalid JSON input for '{{_SANDBOX_FN_NAME}}': {{_je}}\\nInput was: "
+ f"{{_raw_input_str[:200]}}\\n")
+ return None, True # result, error_occurred
+
+ # --- Function Location and Execution ---
+ try:
+ _target_callable = None
+ # Try global scope first
+ if _SANDBOX_FN_NAME in globals():
+ _target_callable = globals()[_SANDBOX_FN_NAME]
+ # Else, if 'Solution' class exists, try to get its method
+ elif 'Solution' in globals():
+ _Solution_class = globals()['Solution']
+ # Attempt to instantiate and get method.
+ # Errors (e.g., Solution not a class, instantiation fails, method missing)
+ # will be caught by the broad except block below.
+ _solution_instance = _Solution_class()
+ _target_callable = getattr(_solution_instance, _SANDBOX_FN_NAME)
+
+ if not _target_callable:
+ sys.stderr.write(f"WrapperError: Function or method '{{_SANDBOX_FN_NAME}}' not found.\\n")
+ return None, True # result, error_occurred
+
+ _fn_result = _target_callable(*_args)
+ return _fn_result, False # result, no_error
+ except Exception: # Catches errors from Solution instantiation, getattr, or function call
+ sys.stderr.write(f"Error during setup or execution of '{{_SANDBOX_FN_NAME}}':\\n{{traceback.format_exc()}}\\n")
+ return None, True # result, error_occurred
+
+if __name__ == '__main__':
+ _result, _error_occurred = _execute_user_function()
+
+ if not _error_occurred:
+ # Serialize result to stdout
+ if isinstance(_result, (dict, list, tuple)) or _result is None or isinstance(_result, bool):
+ print(json.dumps(_result))
+ elif isinstance(_result, (int, float, str)):
+ print(str(_result)) # Ensure string conversion for print
+ else:
+ # For other types, default to string representation.
+ print(str(_result))
+ # Optional: To explicitly exit with an error code if the sandbox relies on it
+ # else:
+ # sys.exit(1)
+"""
+current_generation_code = wrapper_code
+
+# current_generation_code = generation
+compile_timeout = 10
+run_timeout = 10
+request_timeout = 10
+code = current_generation_code
+stdin = "15"
+memory_limit_mb = 1024
+sandbox_fusion_url = "http://dns-4943305a-c17f-44b1-b767-9536529eb8bc-sandbox:8080/run_code"
+language = "python"
+
+
+payload = json.dumps(
+ {
+ "compile_timeout": compile_timeout,
+ "run_timeout": run_timeout,
+ "code": code,
+ "stdin": stdin,
+ "memory_limit_MB": memory_limit_mb,
+ "language": language, # Use the passed language parameter
+ "files": {},
+ "fetch_files": [],
+ }
+)
+
+headers = {"Content-Type": "application/json", "Accept": "application/json"}
+
+response = requests.post(
+ sandbox_fusion_url,
+ headers=headers,
+ data=payload,
+ timeout=request_timeout, # Use the calculated timeout
+)
+
+print(response.json())
\ No newline at end of file
diff --git a/pipelinerl/domains/coding/verifier_api.py b/pipelinerl/domains/coding/verifier_api.py
new file mode 100644
index 00000000..490b0a2e
--- /dev/null
+++ b/pipelinerl/domains/coding/verifier_api.py
@@ -0,0 +1,452 @@
+"""Sandbox-backed verification utilities for the coding domain."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+import math
+import os
+import re
+import time
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass, field
+from typing import Any
+
+import requests
+import aiohttp
+import uvicorn
+from fastapi import FastAPI
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel, Field
+
+logger = logging.getLogger(__name__)
+
+_CODE_FENCE_RE = re.compile(r"```(?:python|py)?\s*([\s\S]*?)```", re.IGNORECASE)
+_DEFAULT_SANDBOX_URL = os.environ.get("CODING_SANDBOX_URL", "http://sandbox:8080/run_code")
+
+
+class CodingVerificationRequest(BaseModel):
+ """Payload accepted by the coding verifier environment."""
+
+ prediction: str | None = None
+ reward_context: dict[str, Any] | str | None = Field(default_factory=dict)
+ extra_info: dict[str, Any] | str | None = None
+
+
+@dataclass
+class CodingTestResult:
+ index: int
+ kind: str
+ status: str
+ input: str | None = None
+ expected: str | None = None
+ assertion: str | None = None
+ stdout: str = ""
+ stderr: str = ""
+ elapsed: float | None = None
+
+ def to_dict(self) -> dict[str, Any]:
+ return {
+ "index": self.index,
+ "kind": self.kind,
+ "status": self.status,
+ "input": self.input,
+ "expected": self.expected,
+ "assertion": self.assertion,
+ "stdout": self.stdout,
+ "stderr": self.stderr,
+ "elapsed": self.elapsed,
+ }
+
+
+@dataclass
+class CodingVerificationSummary:
+ passed: int = 0
+ total: int = 0
+ compile_error: bool = False
+ runtime_error: bool = False
+ timeout_error: bool = False
+ empty_response: bool = False
+ error: str | None = None
+ call_type: str | None = None
+ fn_name: str | None = None
+ tests: list[CodingTestResult] = field(default_factory=list)
+
+ def to_payload(self) -> dict[str, Any]:
+ return {
+ "passed": self.passed,
+ "total": self.total,
+ "compile_error": self.compile_error,
+ "runtime_error": self.runtime_error,
+ "timeout_error": self.timeout_error,
+ "empty_response": self.empty_response,
+ "error": self.error,
+ "call_type": self.call_type,
+ "fn_name": self.fn_name,
+ "tests": [test.to_dict() for test in self.tests],
+ }
+
+
+def _extract_code(prediction: str | None) -> str:
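+    """Return the code inside the last fenced block (```python, ```py, or bare ```),
+    falling back to the stripped prediction when no fence is found."""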
+ if not prediction:
+ return ""
+ match = _CODE_FENCE_RE.findall(prediction)
+ if match:
+ return match[-1].strip()
+ return prediction.strip()
+
+
+def _ensure_dict(data: dict[str, Any] | str | None) -> dict[str, Any]:
+ if data is None:
+ return {}
+ if isinstance(data, dict):
+ return data
+ try:
+ return json.loads(data)
+ except Exception as exc: # pragma: no cover - defensive
+ logger.warning("Failed to parse reward context: %s", exc)
+ return {}
+
+
+def _normalize_output(text: str | None) -> str:
+ if text is None:
+ return ""
+ return text.strip()
+
+
+def _compose_script(user_code: str, extra_snippet: str | None) -> str:
+ if not extra_snippet:
+ return user_code
+ return f"{user_code.rstrip()}\n\n{extra_snippet.strip()}\n"
+
+
+def _convert_bytes_to_mb(value: int) -> int:
+ if value <= 0:
+ return 0
+ return max(16, int(math.ceil(value / (1024 * 1024))))
+
+
+def _has_compile_error(response: dict[str, Any]) -> bool:
+ compile_result = response.get("compile_result") or {}
+ if not compile_result:
+ return False
+ stderr = str(compile_result.get("stderr", ""))
+ return_code = compile_result.get("return_code")
+ if return_code is not None and return_code != 0:
+ return True
+ return bool(stderr.strip())
+
+
+def _is_timeout(response: dict[str, Any]) -> bool:
+ run_result = response.get("run_result") or {}
+ status = (run_result.get("status") or response.get("status") or "").lower()
+ return status == "timeout"
+
+
+def _has_runtime_error(response: dict[str, Any]) -> bool:
+ if _has_compile_error(response) or _is_timeout(response):
+ return False
+ run_result = response.get("run_result") or {}
+ return_code = run_result.get("return_code")
+ if return_code is None:
+ return False
+ if return_code != 0:
+ return True
+ stderr = str(run_result.get("stderr", ""))
+ return "Traceback" in stderr
+
+
+def _get_run_field(response: dict[str, Any], field_name: str) -> str:
+ run_result = response.get("run_result") or {}
+ value = run_result.get(field_name)
+ if value is None:
+ return ""
+ return str(value)
+
+
+def _get_status_text(response: dict[str, Any]) -> str:
+ run_result = response.get("run_result") or {}
+ status = run_result.get("status")
+ if status:
+ return str(status)
+ overall = response.get("status")
+ return str(overall or "")
+
+
+def _post_to_sandbox(
+ *,
+ code: str,
+ stdin: str,
+ sandbox_url: str,
+ compile_timeout_s: float,
+ run_timeout_s: float,
+ request_timeout_s: float,
+ memory_limit_mb: int,
+ language: str,
+) -> tuple[dict[str, Any], float]:
+ payload = {
+ "compile_timeout": compile_timeout_s,
+ "run_timeout": run_timeout_s,
+ "code": code,
+ "stdin": stdin,
+ "memory_limit_MB": memory_limit_mb,
+ "language": language,
+ "files": {},
+ "fetch_files": [],
+ }
+ headers = {"Content-Type": "application/json", "Accept": "application/json"}
+ start = time.perf_counter()
+ try:
+ response = requests.post(
+ sandbox_url,
+ headers=headers,
+ data=json.dumps(payload),
+ timeout=request_timeout_s,
+ )
+ response.raise_for_status()
+ data = response.json()
+ except requests.Timeout as exc:
+ data = {
+ "status": "Timeout",
+ "message": str(exc),
+ "compile_result": None,
+ "run_result": {"status": "Timeout", "stdout": "", "stderr": str(exc)},
+ }
+ except requests.RequestException as exc: # pragma: no cover - network failures
+ data = {
+ "status": "NetworkError",
+ "message": str(exc),
+ "compile_result": None,
+ "run_result": {"status": "Error", "stdout": "", "stderr": str(exc)},
+ }
+ elapsed = time.perf_counter() - start
+ return data, elapsed
+
+
+def evaluate_coding_prediction(
+ prediction: str | None,
+ reward_context: dict[str, Any] | str | None,
+ *,
+ extra_info: dict[str, Any] | str | None = None,
+ sandbox_url: str | None = None,
+ compile_timeout_s: float = 5.0,
+ run_timeout_s: float = 5.0,
+ request_timeout_s: float = 15.0,
+ memory_limit_mb: int = 512,
+ language: str = "python",
+) -> CodingVerificationSummary:
+ """Run generated code inside the sandbox and collect pass/fail statistics."""
+
+ sandbox_target = sandbox_url or _DEFAULT_SANDBOX_URL
+ context = _ensure_dict(reward_context)
+ summary = CodingVerificationSummary(
+ call_type=context.get("call_type"),
+ fn_name=context.get("fn_name"),
+ )
+
+ candidate_code = _extract_code(prediction)
+ if not candidate_code:
+ summary.empty_response = True
+ summary.error = "empty_prediction"
+ return summary
+
+ call_type = (context.get("call_type") or "assert").lower().strip()
+ tests: list[dict[str, Any]] = []
+ if call_type == "assert":
+ for idx, assertion in enumerate(context.get("assert_case", []) or []):
+ if not assertion:
+ continue
+ tests.append({"kind": "assert", "assertion": assertion, "index": idx})
+ elif call_type == "std":
+ inputs = context.get("inputs") or []
+ outputs = context.get("outputs") or []
+ total = min(len(inputs), len(outputs))
+ for idx in range(total):
+ tests.append(
+ {
+ "kind": "std",
+ "input": inputs[idx],
+ "expected": outputs[idx],
+ "index": idx,
+ }
+ )
+ else:
+ summary.error = f"unsupported_call_type:{call_type}"
+ return summary
+
+ if not tests:
+ summary.error = "no_tests"
+ return summary
+
+ for raw_test in tests:
+ idx = summary.total
+ summary.total += 1
+ if raw_test["kind"] == "assert":
+ script = _compose_script(candidate_code, raw_test["assertion"])
+ stdin = ""
+ else:
+ script = candidate_code
+ stdin = raw_test.get("input", "") or ""
+
+ response, elapsed = _post_to_sandbox(
+ code=script,
+ stdin=stdin,
+ sandbox_url=sandbox_target,
+ compile_timeout_s=compile_timeout_s,
+ run_timeout_s=run_timeout_s,
+ request_timeout_s=request_timeout_s,
+ memory_limit_mb=memory_limit_mb,
+ language=language,
+ )
+
+ if _has_compile_error(response):
+ status = "compile_error"
+ summary.compile_error = True
+ summary.error = "compile_error"
+ elif _is_timeout(response):
+ status = "timeout"
+ summary.timeout_error = True
+ summary.error = "timeout"
+ elif _has_runtime_error(response):
+ status = "runtime_error"
+ summary.runtime_error = True
+ summary.error = "runtime_error"
+ else:
+ run_status_text = _get_status_text(response).lower()
+ if raw_test["kind"] == "std":
+ produced = _normalize_output(_get_run_field(response, "stdout"))
+ expected = _normalize_output(raw_test.get("expected"))
+ status = "passed" if produced == expected else "failed"
+ else:
+ succeeded = run_status_text in ("finished", "success", "")
+ status = "passed" if succeeded else "failed"
+ if not succeeded:
+ summary.runtime_error = True
+ summary.error = summary.error or f"run_status:{run_status_text or 'unknown'}"
+
+ stdout = _get_run_field(response, "stdout")
+ stderr = _get_run_field(response, "stderr")
+ test_result = CodingTestResult(
+ index=idx,
+ kind=raw_test["kind"],
+ status=status,
+ input=raw_test.get("input"),
+ expected=raw_test.get("expected"),
+ assertion=raw_test.get("assertion"),
+ stdout=stdout,
+ stderr=stderr,
+ elapsed=elapsed,
+ )
+ summary.tests.append(test_result)
+ if status == "passed":
+ summary.passed += 1
+ if status == "compile_error":
+ break
+
+ return summary
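+
+# Minimal usage sketch (assumptions: a sandbox executor is reachable at the
+# default CODING_SANDBOX_URL; reward_context carries the fields read above,
+# i.e. call_type / assert_case for assert-style problems):
+#
+#     summary = evaluate_coding_prediction(
+#         prediction="```python\ndef add(a, b):\n    return a + b\n```",
+#         reward_context={"call_type": "assert",
+#                         "assert_case": ["assert add(1, 2) == 3"]},
+#     )
+#     print(summary.passed, summary.total)  # 1 1 once the sandbox accepts the run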
+
+
+def _rpc_failure_summary(reason: str, *, status: int | None = None, body: str | None = None) -> dict[str, Any]:
+ details = reason
+ if status is not None:
+ details = f"{reason}:{status}"
+ if body:
+ details = f"{details}:{body[:256]}"
+ return {
+ "passed": 0,
+ "total": 0,
+ "compile_error": False,
+ "runtime_error": True,
+ "timeout_error": False,
+ "empty_response": False,
+ "error": f"verifier_rpc_error:{details}",
+ "call_type": None,
+ "fn_name": None,
+ "tests": [],
+ }
+
+
+async def verify_coding_solution_rpc(
+ session: aiohttp.ClientSession,
+ host: str,
+ port: int,
+ *,
+ prediction: str | None,
+ reward_context: dict[str, Any] | str | None,
+ extra_info: dict[str, Any] | str | None,
+) -> dict[str, Any]:
+ """Call a remote coding verifier via HTTP RPC."""
+
+ payload = {
+ "prediction": prediction,
+ "reward_context": reward_context,
+ "extra_info": extra_info,
+ }
+ url = f"http://{host}:{port}/verify_solution"
+ try:
+ async with session.post(url, json=payload) as response:
+ if response.status != 200:
+ text = await response.text()
+ logger.warning(
+ "Coding verifier RPC failed with %s: %s", response.status, text[:256]
+ )
+ return _rpc_failure_summary("http_status", status=response.status, body=text)
+ return await response.json()
+ except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
+ logger.warning("Coding verifier RPC request error: %s", exc)
+ return _rpc_failure_summary("client_error", body=str(exc))
+
+
+class CodingSandboxEnvironment:
+ """Environment server that proxies requests to the shared sandbox executor."""
+
+ def __init__(
+ self,
+ *,
+ sandbox_url: str | None = None,
+ compile_timeout_s: float = 5.0,
+ run_timeout_s: float = 5.0,
+ request_timeout_s: float = 15.0,
+ memory_limit_bytes: int = 512 * 1024 * 1024,
+ language: str = "python",
+ max_workers: int = 4,
+ ) -> None:
+ self.sandbox_url = sandbox_url or _DEFAULT_SANDBOX_URL
+ self.compile_timeout_s = compile_timeout_s
+ self.run_timeout_s = run_timeout_s
+ self.request_timeout_s = request_timeout_s
+ self.memory_limit_mb = _convert_bytes_to_mb(memory_limit_bytes)
+ self.language = language
+ self.max_workers = max_workers
+
+ def launch(self, port: int) -> None:
+ app = FastAPI()
+ executor = ThreadPoolExecutor(max_workers=self.max_workers)
+
+ @app.post("/verify_solution")
+ async def verify_endpoint(request: CodingVerificationRequest):
+ loop = asyncio.get_running_loop()
+
+ def _evaluate() -> dict[str, Any]:
+ summary = evaluate_coding_prediction(
+ prediction=request.prediction,
+ reward_context=request.reward_context,
+ extra_info=request.extra_info,
+ sandbox_url=self.sandbox_url,
+ compile_timeout_s=self.compile_timeout_s,
+ run_timeout_s=self.run_timeout_s,
+ request_timeout_s=self.request_timeout_s,
+ memory_limit_mb=self.memory_limit_mb,
+ language=self.language,
+ )
+ return summary.to_payload()
+
+ result = await loop.run_in_executor(executor, _evaluate)
+ return JSONResponse(content=result)
+
+ @app.get("/health")
+ async def health() -> dict[str, str]:
+ return {"status": "ok"}
+
+ uvicorn.run(app, host="0.0.0.0", port=port, timeout_keep_alive=60)
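+
+
+# Standalone launch sketch for local debugging (assumed values: a sandbox
+# executor at http://localhost:8080/run_code and a free port 7777). In the
+# pipeline the environment is launched via the run_environment entrypoint.
+if __name__ == "__main__":
+ CodingSandboxEnvironment(sandbox_url="http://localhost:8080/run_code").launch(port=7777)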
\ No newline at end of file
diff --git a/pipelinerl/domains/counting/counting.py b/pipelinerl/domains/counting/counting.py
index c3afb45f..b05d1746 100644
--- a/pipelinerl/domains/counting/counting.py
+++ b/pipelinerl/domains/counting/counting.py
@@ -11,6 +11,9 @@
from pipelinerl.rollouts import BaseMetrics, RolloutResult
+DOMAIN = "counting"
+
+
async def generate_counting_rollout(
cfg: DictConfig,
llm: TrainableLLM,
@@ -81,6 +84,7 @@ def load_problems(dataset_names: list[str]):
if not isinstance(problem, dict) or "letter" not in problem or "word" not in problem or "count" not in problem:
raise ValueError(f"Problem {problem} in dataset {name} is invalid.")
problem["dataset"] = name
+ problem.setdefault("domain", DOMAIN)
problems.append(problem)
return problems
diff --git a/pipelinerl/domains/dispatcher.py b/pipelinerl/domains/dispatcher.py
new file mode 100644
index 00000000..c226247c
--- /dev/null
+++ b/pipelinerl/domains/dispatcher.py
@@ -0,0 +1,103 @@
+import importlib
+import inspect
+from functools import lru_cache
+from typing import Awaitable, Callable, Iterable, Mapping
+
+import aiohttp
+from omegaconf import DictConfig
+
+from pipelinerl.llm import TrainableLLM
+from pipelinerl.rollouts import RolloutResult
+from pipelinerl.utils import resolve_environment_key
+
+RolloutCallable = Callable[[DictConfig, TrainableLLM, dict, aiohttp.ClientSession], Awaitable[RolloutResult] | RolloutResult]
+
+_RUNTIME_DOMAIN_ROLLOUTS: dict[str, str] = {}
+
+
+def _iter_domain_overrides(raw_overrides: Mapping | Iterable | None):
+ if raw_overrides is None:
+ return
+
+ if isinstance(raw_overrides, Mapping):
+ for domain, path in raw_overrides.items():
+ yield str(domain), str(path)
+ return
+
+ # Support Hydra list syntax: [{domain: math, rollout: path}, ...] or ["math:path"]
+ for item in raw_overrides:
+ if isinstance(item, Mapping):
+ if "domain" in item and "rollout" in item:
+ yield str(item["domain"]), str(item["rollout"])
+ else:
+ for domain, path in item.items():
+ yield str(domain), str(path)
+ else:
+ text = str(item)
+ if ":" in text:
+ domain, path = text.split(":", 1)
+ yield domain.strip(), path.strip()
+
+
+def _build_domain_mapping(cfg: DictConfig) -> dict[str, str]:
+ overrides = getattr(cfg.actor, "domain_rollouts", None)
+ mapping: dict[str, str] = {}
+ if overrides:
+ for domain, path in _iter_domain_overrides(overrides):
+ mapping[domain] = path
+
+ if _RUNTIME_DOMAIN_ROLLOUTS:
+ mapping.update(_RUNTIME_DOMAIN_ROLLOUTS)
+
+ if not mapping:
+ raise ValueError("`actor.domain_rollouts` produced an empty mapping and no runtime registrations were found")
+ return mapping
+
+
+@lru_cache(maxsize=None)
+def _import_callable(path: str) -> RolloutCallable:
+ module_name, func_name = path.rsplit(".", 1)
+ module = importlib.import_module(module_name)
+ return getattr(module, func_name)
+
+
+def _get_rollout_callable(cfg: DictConfig, domain: str) -> RolloutCallable:
+ mapping = _build_domain_mapping(cfg)
+ if domain not in mapping:
+ raise ValueError(f"No rollout policy registered for domain '{domain}'")
+ return _import_callable(mapping[domain])
+
+
+async def generate_multidomain_rollout(
+ cfg: DictConfig,
+ llm: TrainableLLM,
+ problem: dict,
+ session: aiohttp.ClientSession,
+) -> RolloutResult:
+ domain = problem.get("domain")
+ if not domain:
+ domain = resolve_environment_key(cfg)
+ if not domain:
+ raise ValueError("Problem is missing 'domain' and no default could be resolved from config")
+
+ rollout_fn = _get_rollout_callable(cfg, domain)
+ result = rollout_fn(cfg, llm, problem, session)
+ if inspect.isawaitable(result):
+ result = await result # type: ignore[assignment]
+
+ # Ensure domain is set on the result for completion tracking
+ if result.domain is None:
+ result.domain = domain
+
+ return result # type: ignore[return-value]
+
+
+def register_domain_rollout(domain: str, target: str) -> None:
+ domain_key = str(domain).strip()
+ target_path = str(target).strip()
+ if not domain_key:
+ raise ValueError("Domain key for registration cannot be empty")
+ if not target_path or "." not in target_path:
+ raise ValueError(f"Target '{target}' must be a fully-qualified callable path")
+ _RUNTIME_DOMAIN_ROLLOUTS[domain_key] = target_path
+ _import_callable.cache_clear()
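+
+
+# Runtime registration sketch (hypothetical module path): a plugin can add a
+# domain without touching the Hydra config, e.g.
+#
+#     register_domain_rollout("my_domain", "my_pkg.rollouts.generate_my_rollout")
+#
+# after which generate_multidomain_rollout routes problems whose "domain" field
+# equals "my_domain" to that callable (the import happens lazily on first use).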
diff --git a/pipelinerl/domains/fn_calling/__init__.py b/pipelinerl/domains/fn_calling/__init__.py
new file mode 100644
index 00000000..1a241972
--- /dev/null
+++ b/pipelinerl/domains/fn_calling/__init__.py
@@ -0,0 +1,18 @@
+"""Fn-calling domain toolkit."""
+
+from .dataset import load_datasets, load_problems
+from .rollouts import generate_fn_calling_rollout
+from .verifier_api import (
+ AgenticToolsEnvironment,
+ evaluate_fn_calling_answer,
+ verify_fn_calling_answer_rpc,
+)
+
+__all__ = [
+ "AgenticToolsEnvironment",
+ "evaluate_fn_calling_answer",
+ "generate_fn_calling_rollout",
+ "load_datasets",
+ "load_problems",
+ "verify_fn_calling_answer_rpc",
+]
diff --git a/pipelinerl/domains/fn_calling/dataset.py b/pipelinerl/domains/fn_calling/dataset.py
new file mode 100644
index 00000000..02cd31d0
--- /dev/null
+++ b/pipelinerl/domains/fn_calling/dataset.py
@@ -0,0 +1,88 @@
+"""Dataset loader for the fn_calling domain."""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from typing import Any, Dict, List
+
+from pipelinerl.domains.coding import dataset as mixed_dataset
+
+logger = logging.getLogger(__name__)
+
+DOMAIN_NAME = "fn_calling"
+ALIAS_NAMES = frozenset({DOMAIN_NAME, "fn-calling", "fncalling"})
+BASE_DATASET_NAME = "mixed-training-text-datasets"
+DEFAULT_DATASET_CONFIG = mixed_dataset.DEFAULT_DATASET_CONFIG
+ABILITY_FILTER = "agentic_fn_calling"
+
+
+@dataclass(frozen=True)
+class DatasetResolution:
+ requested: list[str]
+ resolved: list[str]
+ alias_map: dict[str, str]
+
+
+def _resolve_dataset_requests(dataset_names: List[str] | str | None) -> DatasetResolution:
+ if dataset_names is None:
+ requested = [DOMAIN_NAME]
+ elif isinstance(dataset_names, str):
+ requested = [dataset_names]
+ else:
+ requested = [str(entry) for entry in dataset_names]
+
+ resolved: list[str] = []
+ alias_map: dict[str, str] = {}
+ for entry in requested:
+ spec = mixed_dataset.parse_dataset_name(entry)
+ base_name = BASE_DATASET_NAME if spec.name in ALIAS_NAMES else spec.name
+ resolved_entry = f"{base_name}@{spec.split}"
+ resolved.append(resolved_entry)
+ alias_map.setdefault(resolved_entry, f"{spec.name}@{spec.split}")
+ return DatasetResolution(requested=requested, resolved=resolved, alias_map=alias_map)
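+
+# Resolution sketch (assuming parse_dataset_name splits entries on '@' as the
+# code above implies): a request for "fn_calling@train" is forwarded to the
+# mixed dataset loader while the alias map remembers the original label.
+#
+#     _resolve_dataset_requests(["fn_calling@train"]).resolved
+#     # -> ["mixed-training-text-datasets@train"]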
+
+
+def load_datasets(
+ dataset_names: List[str] | str | None,
+ seed: int | None = None,
+ **loader_kwargs: Any,
+) -> List[Dict]:
+ resolution = _resolve_dataset_requests(dataset_names)
+ defaults = {
+ "dataset_id": mixed_dataset.DEFAULT_DATASET_ID,
+ "dataset_config": DEFAULT_DATASET_CONFIG,
+ # fn_calling loads all call_types (no filtering), unlike coding which filters to assert/std
+ "allowed_call_types": (),
+ "ability_filter": ABILITY_FILTER,
+ }
+ options = {**defaults, **loader_kwargs}
+
+ # Validate that allowed_call_types is empty for fn_calling domain
+ final_call_types = options.get("allowed_call_types")
+ if final_call_types:
+ logger.warning(
+ "fn_calling domain received non-empty allowed_call_types=%s. "
+ "This may filter out valid fn_calling samples. "
+ "Consider using allowed_call_types=[] for this domain.",
+ final_call_types,
+ )
+
+ samples = mixed_dataset.load_datasets(resolution.resolved, seed=seed, **options)
+
+ for sample in samples:
+ dataset_label = str(sample.get("dataset", ""))
+ alias_label = resolution.alias_map.get(dataset_label)
+ if alias_label:
+ sample["dataset"] = alias_label
+ elif dataset_label.startswith(BASE_DATASET_NAME):
+ sample["dataset"] = dataset_label.replace(BASE_DATASET_NAME, DOMAIN_NAME, 1)
+ sample["domain"] = DOMAIN_NAME
+ if not samples:
+ logger.warning("fn_calling loader returned zero samples for entries: %s", resolution.requested)
+ return samples
+
+
+def load_problems(dataset_names: List[str] | str | None, **loader_kwargs: Any) -> List[Dict]:
+ seed = loader_kwargs.pop("seed", None)
+ return load_datasets(dataset_names, seed=seed, **loader_kwargs)
\ No newline at end of file
diff --git a/pipelinerl/domains/fn_calling/rollouts.py b/pipelinerl/domains/fn_calling/rollouts.py
new file mode 100644
index 00000000..81f42c16
--- /dev/null
+++ b/pipelinerl/domains/fn_calling/rollouts.py
@@ -0,0 +1,199 @@
+"""Rollout generation for the fn_calling domain."""
+
+from __future__ import annotations
+
+import json
+import random
+import time
+from typing import Any
+
+import aiohttp
+from omegaconf import DictConfig
+
+from pipelinerl.llm import Prompt, TrainableLLM
+from pipelinerl.async_llm import llm_async_generate, make_training_text
+from pipelinerl.domains.math.rollouts import RewardTable, length_penalty
+from pipelinerl.domains.fn_calling.verifier_api import verify_fn_calling_answer_rpc
+from pipelinerl.rollouts import BaseMetrics, RolloutResult
+from pipelinerl.utils import get_environment_jobs, resolve_environment_key
+
+
+class FnCallingMetrics(BaseMetrics):
+ penalty: float
+
+
+def _format_task(problem: dict[str, Any]) -> str:
+ if problem.get("task"):
+ return str(problem["task"])
+ if problem.get("question"):
+ return str(problem["question"])
+ return str(problem)
+
+
+def _normalize_messages(raw_messages: Any) -> list[dict[str, str]]:
+ """Convert dataset prompts into chat messages compatible with vLLM."""
+
+ if not isinstance(raw_messages, list):
+ return []
+ if len(raw_messages) == 1 and isinstance(raw_messages[0], list):
+ raw_messages = raw_messages[0]
+
+ messages: list[dict[str, str]] = []
+ pending_assistant: list[str] = []
+
+ def flush_pending() -> None:
+ nonlocal pending_assistant
+ if not pending_assistant:
+ return
+ content = "\n\n".join(text for text in pending_assistant if text.strip())
+ if content:
+ messages.append({"role": "assistant", "content": content})
+ pending_assistant = []
+
+ for entry in raw_messages:
+ if not isinstance(entry, dict):
+ continue
+ role = entry.get("role")
+ content = entry.get("content")
+ if not isinstance(content, str):
+ continue
+
+ match role:
+ case "system":
+ flush_pending()
+ messages.append({"role": "system", "content": content})
+ case "user":
+ flush_pending()
+ messages.append({"role": "user", "content": content})
+ case "assistant" | "output_text":
+ pending_assistant.append(content)
+ case "thinking" | "plan" | "replan":
+ # Drop intermediate reasoning to keep prompts shorter.
+ continue
+ case "tool":
+ flush_pending()
+ name = entry.get("name")
+ tool_id = entry.get("tool_call_id") or entry.get("id")
+ if not isinstance(name, str) or not name:
+ continue
+ messages.append(
+ {
+ "role": "tool",
+ "name": name,
+ "content": content,
+ **({"tool_call_id": tool_id} if isinstance(tool_id, str) else {}),
+ }
+ )
+ case _:
+ # For function-call style entries nested under assistant.
+ if role == "function" and isinstance(entry.get("name"), str):
+ pending_assistant.append(content)
+ continue
+
+ flush_pending()
+ return messages
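+
+# Example of the normalisation above (illustrative dataset entry): system and
+# user turns pass through, consecutive assistant/output_text chunks are merged
+# with blank lines, and "thinking" entries are dropped.
+#
+#     _normalize_messages([
+#         {"role": "user", "content": "What's the weather in Paris?"},
+#         {"role": "thinking", "content": "I should call the weather tool."},
+#         {"role": "assistant", "content": "Checking."},
+#         {"role": "output_text", "content": "Done."},
+#     ])
+#     # -> [{"role": "user", "content": "What's the weather in Paris?"},
+#     #     {"role": "assistant", "content": "Checking.\n\nDone."}]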
+
+
+def _coerce_dict(data: dict[str, Any] | str | None) -> dict[str, Any]:
+ if data is None:
+ return {}
+ if isinstance(data, dict):
+ return data
+ if isinstance(data, str) and data.strip():
+ try:
+ return json.loads(data)
+ except json.JSONDecodeError:
+ return {}
+ return {}
+
+
+async def generate_fn_calling_rollout(
+ cfg: DictConfig,
+ llm: TrainableLLM,
+ problem: dict[str, Any],
+ session: aiohttp.ClientSession,
+) -> RolloutResult:
+ provided_prompt = problem.get("prompt")
+ messages = _normalize_messages(provided_prompt)
+ if messages:
+ if cfg.actor.system_prompt and messages[0]["role"] != "system":
+ messages.insert(0, {"role": "system", "content": cfg.actor.system_prompt})
+ else:
+ messages = []
+ if cfg.actor.system_prompt:
+ messages.append({"role": "system", "content": cfg.actor.system_prompt})
+ messages.append({"role": "user", "content": cfg.actor.task_template.format(task=_format_task(problem))})
+ prompt = Prompt(messages=messages)
+
+ start_time = time.time()
+ llm_call = await llm_async_generate(llm, prompt, session)
+ latency = time.time() - start_time
+ assert llm_call.output.content is not None
+
+ env_key = resolve_environment_key(cfg, default="fn_calling")
+ env_jobs = get_environment_jobs(cfg, env_key)
+ if not env_jobs:
+ raise RuntimeError("No environment servers available for fn_calling domain")
+ env_job = random.choice(env_jobs)
+ if env_job.hostname is None or env_job.port is None:
+ raise RuntimeError("fn_calling environment job is missing host/port information")
+
+ reward_context = _coerce_dict(problem.get("reward_context"))
+ extra_info = _coerce_dict(problem.get("extra_info"))
+ answer_status = await verify_fn_calling_answer_rpc(
+ session=session,
+ host=env_job.hostname,
+ port=env_job.port,
+ generation=llm_call.output.content,
+ reward_context=reward_context,
+ extra_info=extra_info,
+ )
+
+ rewards = RewardTable(**dict(cfg.rewards))
+ trace = make_training_text(llm, llm_call)
+ match (answer_status, trace.finished):
+ case ("wrong", False):
+ reward = rewards.wrong_answer_not_finished
+ case ("wrong", True):
+ reward = rewards.wrong_answer_finished
+ case ("no_answer", False):
+ reward = rewards.no_answer_not_finished
+ case ("no_answer", True):
+ reward = rewards.no_answer_finished
+ case ("unparsable", False):
+ reward = rewards.unparsable_not_finished
+ case ("unparsable", True):
+ reward = rewards.unparsable_finished
+ case ("correct", False):
+ reward = rewards.correct_answer_not_finished
+ case ("correct", True):
+ reward = rewards.correct_answer_finished
+ case _:
+ raise ValueError(f"Unexpected fn_calling answer status '{answer_status}'")
+
+ reward *= cfg.actor.discount_factor ** llm_call.output_length_tokens
+ overlong_penalty = 0.0
+ try:
+ max_tokens = llm.parameters["max_tokens"]
+ except (KeyError, TypeError):
+ max_tokens = None
+ if rewards.buffer_tokens > 0 and max_tokens is not None:
+ overlong_penalty = length_penalty(max_tokens, llm_call.output_length_tokens, rewards.buffer_tokens)
+ reward += overlong_penalty
+ trace.reward = reward
+ trace.metadata.setdefault("fn_calling", {}).update({"answer_status": answer_status})
+
+ metrics = FnCallingMetrics(
+ reward=reward,
+ success=answer_status == "correct",
+ no_error=answer_status != "unparsable",
+ no_answer=answer_status == "no_answer",
+ penalty=overlong_penalty,
+ )
+
+ return RolloutResult(
+ training_texts=[trace],
+ metrics=metrics,
+ latency=latency,
+ dataset_name=problem.get("dataset"),
+ )
diff --git a/pipelinerl/domains/fn_calling/verifier_api.py b/pipelinerl/domains/fn_calling/verifier_api.py
new file mode 100644
index 00000000..35b5361b
--- /dev/null
+++ b/pipelinerl/domains/fn_calling/verifier_api.py
@@ -0,0 +1,252 @@
+"""Verifier utilities for the fn_calling domain."""
+
+from __future__ import annotations
+
+import asyncio
+import importlib
+import json
+import logging
+import os
+import re
+from collections import Counter
+from concurrent.futures import ProcessPoolExecutor
+from functools import lru_cache
+from typing import Any, Callable, Iterable, Literal
+
+import aiohttp
+import uvicorn
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel, Field
+
+LOGGER = logging.getLogger(__name__)
+
+AnswerStatus = Literal["correct", "wrong", "no_answer", "unparsable"]
+_VALID_STATUSES: set[str] = {"correct", "wrong", "no_answer", "unparsable"}
+_DEFAULT_REWARD_FN_ENV = "FN_CALLING_REWARD_FN"
+_TOOL_BLOCK = re.compile(r"<tool_calls>(.*?)</tool_calls>", re.DOTALL | re.IGNORECASE)
+
+
+def _json_loads(value: str) -> Any:
+ return json.loads(value)
+
+
+def _normalize_args(value: Any) -> Any:
+ if value is None or value == "":
+ return {}
+ if isinstance(value, (dict, list, bool, int, float)):
+ return value
+ if isinstance(value, str):
+ return _json_loads(value)
+ raise ValueError(f"Unsupported argument payload: {type(value)}")
+
+
+def _canonicalize(entries: Iterable[Any]) -> list[tuple[str, str]]:
+ canon: list[tuple[str, str]] = []
+ for entry in entries:
+ if not isinstance(entry, dict):
+ raise ValueError("Tool call entry must be a mapping")
+ fn_block = entry.get("function") if isinstance(entry.get("function"), dict) else None
+ name = entry.get("name") or (fn_block or {}).get("name")
+ if not isinstance(name, str) or not name.strip():
+ raise ValueError("Tool call missing name")
+ args = entry.get("arguments")
+ if args is None and fn_block:
+ args = fn_block.get("arguments")
+ canon_args = json.dumps(_normalize_args(args), sort_keys=True, ensure_ascii=False)
+ canon.append((name.strip(), canon_args))
+ return canon
+
+
+def _parse_prediction(generation: str | None) -> list[tuple[str, str]]:
+ if not generation:
+ return []
+ block = _TOOL_BLOCK.findall(generation)
+ if not block:
+ return []
+ payload = block[-1].strip()
+ if not payload:
+ return []
+ data = _json_loads(payload)
+ if not isinstance(data, list):
+ raise ValueError("tool_calls block must be a list")
+ return _canonicalize(data)
+
+
+def _parse_label(tool_calls: Iterable[Any] | None) -> list[tuple[str, str]]:
+ if not tool_calls:
+ return []
+ return _canonicalize(tool_calls)
+
+
+def _has_irrelevant_flag(reward_context: dict[str, Any], extra_info: dict[str, Any]) -> bool:
+ for source in (reward_context, extra_info):
+ flag = source.get("irrelevant_tool_call")
+ if isinstance(flag, bool):
+ return flag
+ if isinstance(flag, str) and flag.lower() in {"1", "true", "yes"}:
+ return True
+ return "irrelevant_tool_call" in json.dumps(reward_context, ensure_ascii=False)
+
+
+def evaluate_fn_calling_answer(
+ *,
+ generation: str,
+ reward_context: dict[str, Any] | None,
+ extra_info: dict[str, Any] | None = None,
+) -> AnswerStatus:
+ reward_context = reward_context or {}
+ extra_info = extra_info or {}
+
+ try:
+ expected = _parse_label(reward_context.get("tool_calls"))
+ predicted = _parse_prediction(generation)
+ except ValueError as exc:
+ LOGGER.debug("Failed to parse fn_calling sample: %s", exc)
+ return "unparsable"
+
+ if _has_irrelevant_flag(reward_context, extra_info):
+ return "wrong" if predicted else "correct"
+
+ expected_counter = Counter(expected)
+ predicted_counter = Counter(predicted)
+
+ if not predicted_counter and not expected_counter:
+ return "correct"
+ if not predicted_counter and expected_counter:
+ return "no_answer"
+ if predicted_counter and not expected_counter:
+ return "wrong"
+ return "correct" if predicted_counter == expected_counter else "wrong"
+
+
+def _ensure_mapping(data: dict[str, Any] | str | None) -> dict[str, Any]:
+ if data is None:
+ return {}
+ if isinstance(data, dict):
+ return data
+ if isinstance(data, str) and data.strip():
+ try:
+ return json.loads(data)
+ except json.JSONDecodeError:
+ LOGGER.debug("Failed to decode fn_calling payload: %s", data[:128])
+ return {}
+ return {}
+
+
+class FnCallingVerificationRequest(BaseModel):
+ """Payload accepted by the fn_calling verifier service."""
+
+ generation: str
+ reward_context: dict[str, Any] = Field(default_factory=dict)
+ extra_info: dict[str, Any] = Field(default_factory=dict)
+
+
+@lru_cache(maxsize=None)
+def _import_callable(path: str) -> Callable[..., Any]:
+ module_name, attr_name = path.rsplit(".", 1)
+ module = importlib.import_module(module_name)
+ fn = getattr(module, attr_name)
+ if not callable(fn): # pragma: no cover - defensive guard
+ raise TypeError(f"Object at '{path}' is not callable")
+ return fn
+
+
+def _invoke_reward_fn(
+ reward_fn_path: str,
+ generation: str,
+ reward_context: dict[str, Any],
+ extra_info: dict[str, Any],
+) -> AnswerStatus:
+ fn = _import_callable(reward_fn_path)
+ try:
+ result = fn(generation=generation, reward_context=reward_context, extra_info=extra_info)
+ except TypeError:
+ result = fn(generation, reward_context)
+ status = str(result).strip().lower()
+ if status not in _VALID_STATUSES:
+ raise ValueError(f"Reward function returned invalid status '{result}'")
+ return status # type: ignore[return-value]
+
+
+def _execute_reward_job(
+ reward_fn_path: str | None,
+ generation: str,
+ reward_context: dict[str, Any],
+ extra_info: dict[str, Any],
+) -> AnswerStatus:
+ if reward_fn_path:
+ return _invoke_reward_fn(reward_fn_path, generation, reward_context, extra_info)
+ return evaluate_fn_calling_answer(
+ generation=generation,
+ reward_context=reward_context,
+ extra_info=extra_info,
+ )
+
+
+async def verify_fn_calling_answer_rpc(
+ *,
+ session: aiohttp.ClientSession,
+ host: str,
+ port: int,
+ generation: str,
+ reward_context: dict[str, Any] | str | None,
+ extra_info: dict[str, Any] | str | None = None,
+) -> AnswerStatus:
+ payload = {
+ "generation": generation,
+ "reward_context": _ensure_mapping(reward_context),
+ "extra_info": _ensure_mapping(extra_info),
+ }
+ async with session.post(f"http://{host}:{port}/verify_answer", json=payload) as response:
+ body = await response.text()
+ if response.status != 200:
+ LOGGER.error("fn_calling verifier returned %s: %s", response.status, body[:512])
+ raise ValueError("fn_calling verifier request failed")
+ data = json.loads(body)
+ status = str(data.get("answer_status", "")).strip().lower()
+ if status not in _VALID_STATUSES:
+ raise ValueError(f"fn_calling verifier produced invalid status '{status}'")
+ return status # type: ignore[return-value]
+
+
+class AgenticToolsEnvironment:
+ """FastAPI wrapper that exposes a deterministic fn_calling verifier."""
+
+ def __init__(
+ self,
+ *,
+ reward_fn_path: str | None = None,
+ max_workers: int = 4,
+ keepalive_timeout_s: int = 60,
+ ) -> None:
+ self._reward_fn_path = reward_fn_path or os.environ.get(_DEFAULT_REWARD_FN_ENV)
+ self._max_workers = max_workers
+ self._keepalive_timeout_s = keepalive_timeout_s
+
+ def launch(self, port: int) -> None:
+ app = FastAPI()
+
+ with ProcessPoolExecutor(max_workers=self._max_workers) as process_pool:
+ @app.post("/verify_answer")
+ async def verify(request: FnCallingVerificationRequest):
+ loop = asyncio.get_running_loop()
+ try:
+ answer_status = await loop.run_in_executor(
+ process_pool,
+ _execute_reward_job,
+ self._reward_fn_path,
+ request.generation,
+ dict(request.reward_context),
+ dict(request.extra_info),
+ )
+ except Exception as exc: # pragma: no cover - server-side diagnostics
+ LOGGER.exception("fn_calling reward function failed")
+ raise HTTPException(status_code=500, detail=str(exc))
+ return JSONResponse(content={"answer_status": answer_status})
+
+ @app.get("/health")
+ async def health():
+ return {"status": "ok"}
+
+ uvicorn.run(app, host="0.0.0.0", port=port, timeout_keep_alive=self._keepalive_timeout_s)
\ No newline at end of file
diff --git a/pipelinerl/domains/guessing/guessing.py b/pipelinerl/domains/guessing/guessing.py
index a1ede13c..40e64c7e 100644
--- a/pipelinerl/domains/guessing/guessing.py
+++ b/pipelinerl/domains/guessing/guessing.py
@@ -9,6 +9,9 @@
from pipelinerl.rollouts import BaseMetrics, RolloutResult
+DOMAIN = "guessing"
+
+
async def generate_guessing_rollout(
cfg: DictConfig,
llm: TrainableLLM,
@@ -91,10 +94,10 @@ def load_problems(dataset_names: list[str]):
for name in dataset_names:
if name == "train":
problems.extend([
- {"answer": (2 * i * c) % n + 1, "dataset": "train"} for i in range(512)
+ {"answer": (2 * i * c) % n + 1, "dataset": "train", "domain": DOMAIN} for i in range(512)
])
elif name == "test":
problems.extend([
- {"answer": ((2 * i + 1) * c) % n + 1, "dataset": "test"} for i in range(512)
+ {"answer": ((2 * i + 1) * c) % n + 1, "dataset": "test", "domain": DOMAIN} for i in range(512)
])
return problems
diff --git a/pipelinerl/domains/math/load_datasets.py b/pipelinerl/domains/math/load_datasets.py
index 4b44dfb6..fbfaea97 100644
--- a/pipelinerl/domains/math/load_datasets.py
+++ b/pipelinerl/domains/math/load_datasets.py
@@ -2,7 +2,8 @@
import logging
import random
import re
-from typing import Dict, List, Tuple
+from pathlib import Path
+from typing import Dict, Iterable, List, Sequence, Tuple
import datasets
import hydra
@@ -152,6 +153,26 @@ def load_math(split):
return datasets.Dataset.from_list(data)
+def _load_aime_2025_opencompass_dataset(upsample_factor: int = 0) -> list[dict]:
+ configs = ["AIME2025-I", "AIME2025-II"]
+ dataset_name = "aime_2025" + ("" if upsample_factor > 0 else "_original")
+
+ samples: list[dict] = []
+ for config_name in configs:
+ ds = load_dataset("opencompass/AIME2025", config_name, split="test")
+ samples.extend([s for s in process_math(ds, dataset_name) if s is not None])
+
+ original_size = len(samples)
+ if upsample_factor > 0:
+ samples *= upsample_factor
+
+ logger.info(
+ f"Loading aime 2025 (OpenCompass) dataset: {len(samples)} samples"
+ + (f" (upsampled from {original_size})" if upsample_factor > 0 else "")
+ )
+ return add_ids(samples)
+
+
def _load_aime_dataset(year: int, upsample_factor: int = 0) -> list[dict]:
aime_dataset = load_dataset("AI-MO/aimo-validation-aime", split="train", trust_remote_code=True)
aime_dataset = aime_dataset.filter(lambda x: str(year) in x["url"])
@@ -194,18 +215,93 @@ def add_ids(dataset: list[dict]):
return dataset
+def _resolve_custom_path(relative_paths: str | Sequence[str]) -> Path:
+ """
+ Resolve a path for locally generated datasets.
+
+ Hydra jobs may change the working directory, so we check both the current
+ directory and the repository root.
+ """
+ if isinstance(relative_paths, str):
+ relative_paths = [relative_paths]
+
+ resolved = Path(__file__).resolve()
+ base_candidates = [Path.cwd()]
+ # load_datasets.py sits in pipelinerl/domains/math/, so the repository root is
+ # four levels up from this file, i.e. resolved.parents[3].
+ if len(resolved.parents) >= 4:
+ base_candidates.append(resolved.parents[3])
+
+ candidates: List[Path] = []
+ for rel in relative_paths:
+ rel_path = Path(rel)
+ candidates.append(rel_path)
+ for base in base_candidates:
+ if base == Path.cwd():
+ continue
+ candidates.append(base / rel_path)
+
+ for candidate in candidates:
+ if candidate.exists():
+ return candidate
+ raise FileNotFoundError(
+ f"Custom dataset not found. Tried: {[str(path) for path in candidates]}"
+ )
+
+
+def _load_custom_dataset(dataset_name: str) -> list[dict]:
+ """
+ Load a locally generated dataset by name.
+
+ The loader searches under `datasets/custom/` and `datasets/custom_runs/` for either
+ `<dataset_name>` or `<dataset_name>.jsonl`.
+ """
+ candidate_names: List[str] = []
+ if dataset_name.endswith(".jsonl"):
+ candidate_names.append(dataset_name)
+ else:
+ candidate_names.extend([dataset_name, f"{dataset_name}.jsonl"])
+
+ search_paths: List[str] = []
+ for name in candidate_names:
+ search_paths.extend(
+ [
+ f"datasets/custom/{name}",
+ f"datasets/custom_runs/{name}",
+ name,
+ ]
+ )
+
+ dataset_path = _resolve_custom_path(search_paths)
+ with dataset_path.open("r", encoding="utf-8") as handle:
+ samples = [json.loads(line) for line in handle if line.strip()]
+
+ dataset_label = dataset_name[:-6] if dataset_name.endswith(".jsonl") else dataset_name
+
+ for idx, sample in enumerate(samples):
+ sample.setdefault("source_dataset", sample.get("dataset", dataset_label))
+ sample.setdefault("source_id", sample.get("id"))
+ sample["dataset"] = dataset_label
+ sample["id"] = idx
+
+ logger.info(f"Loading custom dataset {dataset_name}: {len(samples)} samples from {dataset_path}")
+ return samples
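+
+# For example, a name like "my_generated_set" is searched (in this order) as
+# datasets/custom/my_generated_set, datasets/custom_runs/my_generated_set and
+# my_generated_set, then the same three paths with a .jsonl suffix, each tried
+# relative to the working directory and the repository root:
+#
+#     _load_custom_dataset("my_generated_set")  # e.g. datasets/custom/my_generated_set.jsonl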
+
+
def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None) -> List[Tuple[str, Dict]]:
if dataset_names is None:
return []
if isinstance(dataset_names, str):
dataset_names = [dataset_names]
+ # Preserve order while de-duplicating
+ dataset_names = list(dict.fromkeys(dataset_names))
datasets = []
+ remaining = set(dataset_names)
if "eurus_train" in dataset_names:
dataset = load_dataset("PRIME-RL/Eurus-2-RL-Data", split="train", trust_remote_code=True)
samples = [s for s in process_eurus(dataset) if s is not None]
logger.info(f"Loading eurus train dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("eurus_train")
# great for debugging since its much smaller than eurus train
if "eurus_validation" in dataset_names:
@@ -213,6 +309,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = [s for s in process_eurus(dataset) if s is not None]
logger.info(f"Loading eurus validation dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("eurus_validation")
if "math_train" in dataset_names:
# math_dataset = load_math("train")
@@ -220,6 +317,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = [s for s in process_math(dataset, "math_train") if s is not None]
logger.info(f"Loading math train dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("math_train")
if "math_simplerl_train" in dataset_names:
# SimpleRL MATH dataset
@@ -234,6 +332,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = [s for s in process_math(dataset, "math_simplerl_train") if s is not None]
logger.info(f"Loading math simplerl train dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("math_simplerl_train")
if "simplerl_math_subset_1000" in dataset_names:
# SimpleRL MATH dataset subset
@@ -252,12 +351,14 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = samples[:1000]
logger.info(f"Loading math simplerl subset test dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("simplerl_math_subset_1000")
if "deepscaler_preview" in dataset_names:
dataset = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train", trust_remote_code=True)
samples = [s for s in process_math(dataset, "deepscaler") if s is not None]
logger.info(f"Loading deepscaler preview train dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("deepscaler_preview")
if "math_test" in dataset_names:
# math_dataset = load_math("test")
@@ -265,36 +366,42 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = [s for s in process_math(dataset, "math_test") if s is not None]
logger.info(f"Loading math test dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("math_test")
if "omni_math_500" in dataset_names:
dataset = load_dataset("reliable-agents/Omni-MATH-500", split="test", trust_remote_code=True)
samples = [s for s in process_math(dataset, "omni_math_500") if s is not None]
logger.info(f"Loading omni math 500 dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("omni_math_500")
if "math_500" in dataset_names:
dataset = load_dataset("HuggingFaceH4/MATH-500", split="test", trust_remote_code=True)
samples = [s for s in process_math(dataset, "math_500") if s is not None]
logger.info(f"Loading math 500 dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("math_500")
if "open_r1_math_220k" in dataset_names:
dataset = load_dataset("open-r1/OpenR1-Math-220k", split="default", trust_remote_code=True)
samples = [s for s in process_math(dataset, "open_r1_math_220k") if s is not None]
logger.info(f"Loading open r1 math 220k dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("open_r1_math_220k")
if "gpqa_main" in dataset_names:
dataset = load_dataset("hendrydong/gpqa_main", split="test", trust_remote_code=True)
samples = [s for s in process_gpqa(dataset, "gpqa_main") if s is not None]
logger.info(f"Loading gpqa main dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("gpqa_main")
if "gpqa_diamond" in dataset_names:
dataset = load_dataset("hendrydong/gpqa_diamond", split="test", trust_remote_code=True)
samples = [s for s in process_gpqa(dataset, "gpqa_diamond") if s is not None]
logger.info(f"Loading gpqa diamond dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("gpqa_diamond")
if "gpqa_diamond" in dataset_names:
pass
@@ -304,49 +411,70 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = [s for s in process_gsm8k(dataset, "gsm8k_train") if s is not None]
logger.info(f"Loading gsm8k train dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("gsm8k_train")
if "gsm8k_test" in dataset_names:
dataset = load_dataset("openai/gsm8k", "main", split="test", trust_remote_code=True)
samples = [s for s in process_gsm8k(dataset, "gsm8k_test") if s is not None]
logger.info(f"Loading gsm8k test dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("gsm8k_test")
if "limo" in dataset_names:
dataset = load_dataset("GAIR/LIMO", split="train", trust_remote_code=True)
samples = [s for s in process_limo(dataset) if s is not None]
logger.info(f"Loading limo dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("limo")
if "aime_2022" in dataset_names:
datasets += _load_aime_dataset(2022, upsample_factor=16)
+ remaining.discard("aime_2022")
if "aime_2022_original" in dataset_names:
datasets += _load_aime_dataset(2022)
+ remaining.discard("aime_2022_original")
if "aime_2023" in dataset_names:
datasets += _load_aime_dataset(2023, upsample_factor=16)
+ remaining.discard("aime_2023")
if "aime_2023_original" in dataset_names:
datasets += _load_aime_dataset(2023)
+ remaining.discard("aime_2023_original")
if "aime_2024" in dataset_names:
datasets += _load_aime_dataset(2024, upsample_factor=16)
+ remaining.discard("aime_2024")
if "aime_2024_original" in dataset_names:
datasets += _load_aime_dataset(2024)
+ remaining.discard("aime_2024_original")
+
+ if "aime_2025" in dataset_names:
+ datasets += _load_aime_2025_opencompass_dataset(upsample_factor=16)
+ remaining.discard("aime_2025")
+
+ if "aime_2025_original" in dataset_names:
+ datasets += _load_aime_2025_opencompass_dataset()
+ remaining.discard("aime_2025_original")
if "amc_2022" in dataset_names:
# TODO: AMC 2022 is 43 problems, is that to be expected?
datasets += _load_amc_dataset(2022, upsample_factor=16)
+ remaining.discard("amc_2022")
if "amc_2022_original" in dataset_names:
datasets += _load_amc_dataset(2022)
+ remaining.discard("amc_2022_original")
if "amc_2023" in dataset_names:
datasets += _load_amc_dataset(2023, upsample_factor=16)
+ remaining.discard("amc_2023")
if "amc_2023_original" in dataset_names:
datasets += _load_amc_dataset(2023)
+ remaining.discard("amc_2023_original")
if "sometimes_success_data" in dataset_names:
PATH = "data/sometimes_success_data/data.jsonl"
@@ -354,6 +482,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = [json.loads(line) for line in f]
logger.info(f"Loading easy data dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("sometimes_success_data")
if "open_reasoner_zero_57k" in dataset_names:
dataset = load_dataset(
@@ -365,6 +494,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = [s for s in process_open_reasoner(dataset, "open_reasoner_zero_57k") if s is not None]
logger.info(f"Loading Open Reasoner Zero dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("open_reasoner_zero_57k")
if "open_reasoner_zero_extended_72k" in dataset_names:
dataset = load_dataset(
@@ -376,6 +506,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = [s for s in process_open_reasoner(dataset, "open_reasoner_zero_extended_72k") if s is not None]
logger.info(f"Loading Open Reasoner Zero extended dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("open_reasoner_zero_extended_72k")
if "open_reasoner_zero_hard_13k" in dataset_names:
dataset = load_dataset(
@@ -387,6 +518,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = [s for s in process_open_reasoner(dataset, "open_reasoner_zero_hard_13k") if s is not None]
logger.info(f"Loading Open Reasoner Zero hard dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard("open_reasoner_zero_hard_13k")
for dataset_name in dataset_names:
test_matched = re.match(r"multiplication_(\d+)_by_(\d+)_(\d+)_test", dataset_name)
@@ -407,6 +539,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
]
logger.info(f"Loading multiplication {num_digits_1}_by_{num_digits_2} dataset: {len(samples)} samples")
datasets += add_ids(samples)
+ remaining.discard(dataset_name)
elif train_matched:
upto_prefix = train_matched.group(1) or ""
num_digits_1 = int(train_matched.group(2))
@@ -428,6 +561,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
f"Loading multiplication {upto_prefix}_{num_digits_1}_by_{num_digits_2} dataset: {len(samples)} samples"
)
datasets += add_ids(samples)
+ remaining.discard(dataset_name)
if "countdown" in dataset_names:
dataset = load_dataset(
@@ -436,6 +570,19 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None
samples = [s for s in process_countdown(dataset) if s is not None]
logger.info(f"Loading countdown dataset: {len(samples)} samples")
datasets += samples
+ remaining.discard("countdown")
+
+ # resolve any remaining names as local custom datasets.
+ unresolved: List[str] = []
+ for dataset_name in list(remaining):
+ try:
+ datasets += _load_custom_dataset(dataset_name)
+ remaining.discard(dataset_name)
+ except FileNotFoundError:
+ unresolved.append(dataset_name)
+
+ if unresolved:
+ raise ValueError(f"Unknown dataset(s): {unresolved}")
if len(datasets) == 0:
raise ValueError("No datasets loaded")
diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py
index 62758c7e..4e27f23c 100644
--- a/pipelinerl/domains/math/rollouts.py
+++ b/pipelinerl/domains/math/rollouts.py
@@ -4,6 +4,8 @@
import aiohttp
from omegaconf import DictConfig
from pydantic import BaseModel
+from pipelinerl.rollouts import RolloutResult, BaseMetrics
+from pipelinerl.utils import get_environment_jobs, resolve_environment_key
from pipelinerl.async_llm import llm_async_generate, make_training_text
from pipelinerl.llm import Prompt, TrainableLLM
@@ -55,9 +57,10 @@ async def generate_math_rollout(
rewards = RewardTable(**dict(cfg.rewards))
discount_factor = cfg.actor.discount_factor
- # math_verify is a fast environment, no support for environment replicas for now
- env_jobs = [Job(**job) for job in cfg.jobs if job["kind"] == "environment"]
- # choose the job randomly
+ env_key = resolve_environment_key(cfg, default="math")
+ env_jobs = get_environment_jobs(cfg, env_key)
+ if not env_jobs:
+ raise RuntimeError("No environment servers available for math domain")
env_job = random.choice(env_jobs)
assert env_job.port is not None
answer_status = await verify_answer_rpc(
diff --git a/pipelinerl/domains/multidomain/__init__.py b/pipelinerl/domains/multidomain/__init__.py
new file mode 100644
index 00000000..10392710
--- /dev/null
+++ b/pipelinerl/domains/multidomain/__init__.py
@@ -0,0 +1,54 @@
+"""Debug utilities for multi-domain experimentation."""
+
+from collections.abc import Iterable
+from typing import Any
+
+DOMAIN = "multi"
+
+
+_MATH_SAMPLE = {
+ "id": 0,
+ "dataset": "math_debug",
+ "task": "Compute 2 + 2.",
+ "answer": "\\boxed{4}",
+ "domain": "math",
+}
+
+_GUESSING_SAMPLE = {
+ "id": 0,
+ "dataset": "guessing_debug",
+ "task": "Hidden number between 1 and 3. Start guessing.",
+ "answer": 2,
+ "domain": "guessing",
+}
+
+_CODING_SAMPLE = {
+ "id": 0,
+ "dataset": "coding_debug",
+ "question": "Implement class Solution with method addTwoNumbers(a: int, b: int) that returns a + b.",
+ "starter_code": "class Solution:\n def addTwoNumbers(self, a: int, b: int) -> int:\n pass",
+ "tests": [
+ {"id": 0, "type": "functional", "input": "1\n2", "output": "3"},
+ {"id": 1, "type": "functional", "input": "5\n7", "output": "12"},
+ ],
+ "entry_point": "addTwoNumbers",
+ "domain": "coding",
+}
+
+
+def load_problems(dataset_names: Iterable[str] | None = None, **_: Any) -> list[dict]:
+ """Return tiny synthetic problems for smoke-testing multi-domain dispatch."""
+ if dataset_names is None:
+ dataset_names = ["math_debug", "guessing_debug"]
+
+ problems: list[dict] = []
+ for name in dataset_names:
+ if name == "math_debug":
+ problems.append(dict(_MATH_SAMPLE))
+ elif name == "guessing_debug":
+ problems.append(dict(_GUESSING_SAMPLE))
+ elif name == "coding_debug":
+ problems.append(dict(_CODING_SAMPLE))
+ else:
+ raise ValueError(f"Unknown debug dataset '{name}'")
+ return problems
diff --git a/pipelinerl/domains/multidomain/loader.py b/pipelinerl/domains/multidomain/loader.py
new file mode 100644
index 00000000..84c60ecc
--- /dev/null
+++ b/pipelinerl/domains/multidomain/loader.py
@@ -0,0 +1,100 @@
+from collections import defaultdict
+from typing import Dict, Iterable, List, Sequence
+
+from pipelinerl.domains.math.load_datasets import load_datasets as load_math_datasets
+from pipelinerl.domains.guessing.guessing import load_problems as load_guessing_problems
+from pipelinerl.domains.counting.counting import load_problems as load_counting_problems
+from pipelinerl.domains.chartqa.load_datasets import load_problems as load_chartqa_problems
+from pipelinerl.domains.coding.dataset import load_problems as load_coding_problems
+from pipelinerl.domains.fn_calling.dataset import load_problems as load_fn_calling_problems
+from pipelinerl.domains.miniwob.load_tasks import load_tasks as load_miniwob_tasks
+
+
+def _load_math(dataset_names: Sequence[str], *, seed=None, **_: dict) -> List[Dict]:
+ return load_math_datasets(list(dataset_names), seed=seed)
+
+
+def _load_guessing(dataset_names: Sequence[str], **_: dict) -> List[Dict]:
+ return load_guessing_problems(list(dataset_names))
+
+
+def _load_coding(dataset_names: Sequence[str], **loader_kwargs: dict) -> List[Dict]:
+ return load_coding_problems(list(dataset_names), **loader_kwargs)
+
+
+def _load_fn_calling(dataset_names: Sequence[str], **loader_kwargs: dict) -> List[Dict]:
+ return load_fn_calling_problems(list(dataset_names), **loader_kwargs)
+
+
+def _load_counting(dataset_names: Sequence[str], **_: dict) -> List[Dict]:
+ return load_counting_problems(list(dataset_names))
+
+
+def _load_chartqa(dataset_names: Sequence[str], **_: dict) -> List[Dict]:
+ return load_chartqa_problems(list(dataset_names))
+
+
+def _load_miniwob(dataset_names: Sequence[str], **loader_kwargs: dict) -> List[Dict]:
+ return load_miniwob_tasks(list(dataset_names), **loader_kwargs)
+
+
+DOMAIN_LOADERS = {
+ "math": _load_math,
+ "guessing": _load_guessing,
+ "coding": _load_coding,
+ "counting": _load_counting,
+ "chartqa": _load_chartqa,
+ "miniwob": _load_miniwob,
+ "fn_calling": _load_fn_calling,
+}
+
+
+def _parse_entry(entry: str) -> tuple[str, str]:
+ if "::" not in entry:
+ raise ValueError(
+ f"Dataset entry '{entry}' is missing a domain prefix. "
+ "Expected format '::'."
+ )
+ domain, dataset = entry.split("::", 1)
+ return domain.strip(), dataset.strip()
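+
+# For example (hypothetical entries):
+#
+#     _parse_entry("math::math_train")      # -> ("math", "math_train")
+#     _parse_entry("coding::coding@train")  # -> ("coding", "coding@train")
+#
+# Entries without the '::' separator raise a ValueError.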
+
+
+def load_datasets(
+ dataset_names: Iterable[str] | str | None,
+ *,
+ seed: int | None = None,
+ per_domain_params: dict[str, dict] | None = None,
+ **kwargs,
+) -> List[Dict]:
+ if dataset_names is None:
+ return []
+ if isinstance(dataset_names, str):
+ dataset_names = [dataset_names]
+
+ grouped: dict[str, list[str]] = defaultdict(list)
+ for entry in dataset_names:
+ domain, name = _parse_entry(str(entry))
+ grouped[domain].append(name)
+
+ counters: dict[tuple[str, str], int] = defaultdict(int)
+ problems: List[Dict] = []
+ per_domain_params = dict(per_domain_params or {})
+ for domain, names in grouped.items():
+ loader = DOMAIN_LOADERS.get(domain)
+ if loader is None:
+ raise ValueError(f"No loader registered for domain '{domain}'")
+ domain_kwargs = dict(kwargs)
+ if domain in per_domain_params:
+ domain_kwargs.update(dict(per_domain_params[domain] or {}))
+ loaded = loader(names, seed=seed, **domain_kwargs)
+ for sample in loaded:
+ dataset_name = str(sample.get("dataset", names[0] if names else domain))
+ sample.setdefault("domain", domain)
+ if "id" not in sample:
+ key = (sample["domain"], dataset_name)
+ sample["id"] = counters[key]
+ counters[key] += 1
+ if "dataset" not in sample:
+ sample["dataset"] = dataset_name
+ problems.append(sample)
+ return problems
diff --git a/pipelinerl/entrypoints/run_environment.py b/pipelinerl/entrypoints/run_environment.py
index d5e05d2a..1951d4a7 100644
--- a/pipelinerl/entrypoints/run_environment.py
+++ b/pipelinerl/entrypoints/run_environment.py
@@ -1,14 +1,21 @@
import hydra
from omegaconf import DictConfig
-from pipelinerl.utils import better_crashing
+from pipelinerl.utils import better_crashing, select_environment_config
@hydra.main(config_path="../../conf", config_name="base", version_base="1.3.2")
def hydra_entrypoint(cfg: DictConfig):
with better_crashing("environment"):
- environment = hydra.utils.instantiate(cfg.environment)
this_job, = [job for job in cfg.jobs if job["idx"] == cfg.me.job_idx]
+ environment_cfg = select_environment_config(
+ cfg,
+ key=this_job.get("environment_key"),
+ index=this_job.get("environment_index"),
+ )
+ if environment_cfg is None:
+ raise ValueError("No environment configuration found for job")
+ environment = hydra.utils.instantiate(environment_cfg)
port = this_job["port"]
environment.launch(port=port)
diff --git a/pipelinerl/rollouts.py b/pipelinerl/rollouts.py
index dcb27f2d..1200ba23 100644
--- a/pipelinerl/rollouts.py
+++ b/pipelinerl/rollouts.py
@@ -64,3 +64,4 @@ class RolloutResult(BaseModel):
model_version: int | None = None
dataset_name: str | None = None
group_id: str | None = None
+ domain: str | None = None
diff --git a/pipelinerl/utils.py b/pipelinerl/utils.py
index fbcd9926..96c0ceab 100644
--- a/pipelinerl/utils.py
+++ b/pipelinerl/utils.py
@@ -7,11 +7,12 @@
from pathlib import Path
import traceback
from typing import Dict, Mapping, List, Any
+
import numpy as np
-from omegaconf import DictConfig
import psutil
import requests
from importlib.metadata import distributions
+from omegaconf import DictConfig, ListConfig, OmegaConf
from transformers import PreTrainedTokenizer
from pipelinerl.world import Job
@@ -22,6 +23,177 @@
logger = logging.getLogger(__name__)
+_ENV_METADATA_KEYS = {"key", "mode", "replicas_per_actor"}
+
+
+def _strip_environment_metadata(env_cfg: DictConfig | dict | None):
+ if env_cfg is None:
+ return None
+ if isinstance(env_cfg, DictConfig):
+ data = OmegaConf.to_container(env_cfg, resolve=True)
+ elif isinstance(env_cfg, dict):
+ data = dict(env_cfg)
+ else:
+ return env_cfg
+ for meta_key in _ENV_METADATA_KEYS:
+ data.pop(meta_key, None)
+ return OmegaConf.create(data)
+
+
+def _env_cfg_type(env_cfg, field: str):
+ if isinstance(env_cfg, DictConfig):
+ return env_cfg.get(field, None)
+ if isinstance(env_cfg, dict):
+ return env_cfg.get(field)
+ return None
+
+
+def select_environment_config(cfg: DictConfig, *, key: str | None = None, index: int | None = None):
+ env_cfgs = getattr(cfg, "environments", None)
+ if env_cfgs:
+ if isinstance(env_cfgs, (ListConfig, list)):
+ if key is not None:
+ for env_cfg in env_cfgs:
+ env_key = _env_cfg_type(env_cfg, "key") or _env_cfg_type(env_cfg, "name")
+ if env_key is not None and str(env_key) == str(key):
+ return _strip_environment_metadata(env_cfg)
+ if index is not None and 0 <= index < len(env_cfgs):
+ return _strip_environment_metadata(env_cfgs[index])
+ elif isinstance(env_cfgs, (DictConfig, dict)):
+ if key is not None and key in env_cfgs:
+ return _strip_environment_metadata(env_cfgs[key])
+ if index is not None:
+ for idx, env_key in enumerate(env_cfgs):
+ if idx == index:
+ return _strip_environment_metadata(env_cfgs[env_key])
+
+ return getattr(cfg, "environment", None)
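+
+# Behaviour sketch (constructed config, not a project file):
+#
+#     cfg = OmegaConf.create(
+#         {"environments": [{"key": "coding", "mode": "remote", "_target_": "x.Y"}]}
+#     )
+#     select_environment_config(cfg, key="coding")
+#     # -> DictConfig({"_target_": "x.Y"}); the key/mode/replicas_per_actor
+#     #    metadata is stripped so the result can go straight to hydra.utils.instantiate.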
+
+
+def _domain_mix_weights(cfg: DictConfig) -> dict[str, float]:
+ domain_mix_cfg = getattr(getattr(cfg, "actor", None), "domain_mix", None)
+ if not domain_mix_cfg:
+ return {}
+ try:
+ mix_weights = OmegaConf.to_container(domain_mix_cfg, resolve=True)
+ except Exception:
+ return {}
+ if not isinstance(mix_weights, dict):
+ return {}
+ weights: dict[str, float] = {}
+ for key, value in mix_weights.items():
+ try:
+ weight = float(value)
+ except (TypeError, ValueError):
+ continue
+ if weight > 0:
+ weights[str(key)] = weight
+ return weights
+
+
+def _apply_domain_mix_replicas(cfg: DictConfig, specs: list[dict[str, Any]], default_replicas: Any) -> None:
+ weights = _domain_mix_weights(cfg)
+ if not weights:
+ return
+ try:
+ default_value = float(default_replicas)
+ except (TypeError, ValueError):
+ return
+ if default_value <= 0:
+ return
+ weighted_specs = [spec for spec in specs if spec.get("mode") == "remote" and spec.get("key") in weights]
+ if not weighted_specs:
+ return
+ total_weight = sum(weights[spec["key"]] for spec in weighted_specs)
+ if total_weight <= 0:
+ return
+ average_weight = total_weight / len(weighted_specs)
+ if average_weight <= 0:
+ return
+ for spec in weighted_specs:
+ key = spec["key"]
+ scaled = default_value * (weights[key] / average_weight)
+ replicas = max(1, int(round(scaled)))
+ current = spec.get("replicas_per_actor")
+ if current is None or current == default_replicas:
+ spec["replicas_per_actor"] = replicas
+
+
+def collect_environment_specs(cfg: DictConfig) -> list[dict[str, Any]]:
+ specs: list[dict[str, Any]] = []
+ env_cfgs = getattr(cfg, "environments", None)
+ default_mode = str(getattr(cfg.world, "environment_mode", "remote"))
+ default_replicas = getattr(cfg.world, "env_replicas_per_actor", 1)
+
+ if isinstance(env_cfgs, (ListConfig, list)):
+ iterable = list(env_cfgs)
+ for idx, env_cfg in enumerate(iterable):
+ if env_cfg is None:
+ continue
+ key = _env_cfg_type(env_cfg, "key") or _env_cfg_type(env_cfg, "name")
+ mode = _env_cfg_type(env_cfg, "mode")
+ replicas = _env_cfg_type(env_cfg, "replicas_per_actor")
+ specs.append(
+ {
+ "key": str(key) if key is not None else f"environment_{idx}",
+ "mode": str(mode) if mode is not None else default_mode,
+ "replicas_per_actor": replicas,
+ "index": idx,
+ }
+ )
+ elif isinstance(env_cfgs, (DictConfig, dict)):
+ items = env_cfgs.items()
+ for idx, (key, env_cfg) in enumerate(items):
+ if env_cfg is None:
+ continue
+ mode = _env_cfg_type(env_cfg, "mode")
+ replicas = _env_cfg_type(env_cfg, "replicas_per_actor")
+ specs.append(
+ {
+ "key": str(key),
+ "mode": str(mode) if mode is not None else default_mode,
+ "replicas_per_actor": replicas,
+ "index": idx,
+ }
+ )
+ else:
+ single_env = getattr(cfg, "environment", None)
+ if single_env:
+ key = _env_cfg_type(single_env, "key") or _env_cfg_type(single_env, "name")
+ specs.append(
+ {
+ "key": str(key) if key is not None else "default",
+ "mode": default_mode,
+ "replicas_per_actor": default_replicas,
+ "index": 0,
+ }
+ )
+
+ active_domains = _domain_mix_weights(cfg)
+ if active_domains:
+ specs = [spec for spec in specs if spec.get("key") in active_domains]
+
+ _apply_domain_mix_replicas(cfg, specs, default_replicas)
+ return specs
+
+
+def resolve_environment_key(cfg: DictConfig, default: str | None = None) -> str | None:
+ explicit = cfg.get("environment_key", None) if hasattr(cfg, "get") else None
+ if explicit:
+ return str(explicit)
+ specs = collect_environment_specs(cfg)
+ if len(specs) == 1:
+ return specs[0]["key"]
+ return default
+
+
+def get_environment_jobs(cfg: DictConfig, key: str | None = None) -> list[Job]:
+ jobs_cfg = getattr(cfg, "jobs", [])
+ env_jobs = [Job(**job) for job in jobs_cfg if job["kind"] == "environment"]
+ if key is None:
+ return env_jobs
+ filtered = [job for job in env_jobs if getattr(job, "environment_key", None) == key]
+ return filtered or env_jobs
def init_wandb(
cfg: DictConfig,
@@ -237,6 +409,9 @@ def calculate_stats(stats: List | Dict[Any, Any]) -> Dict[str, float]:
if not isinstance(stats, list):
raise TypeError(f"Expected stats to be a list, got {type(stats)}")
+ if len(stats) == 0:
+ return {}
+
aggregated_stats = {
"max": float(max(stats)),
"min": float(min(stats)),
@@ -291,19 +466,27 @@ def wait_for_inference_servers(urls: list[str]):
def wait_for_environments(cfg: DictConfig):
- """
- Wait for the verifier to be ready.
- """
- env_jobs = [Job(**job) for job in cfg.jobs if job.kind == "environment"]
+ """Wait for remote environment servers to report healthy."""
+ specs = collect_environment_specs(cfg)
+ if not any(spec.get("mode") == "remote" for spec in specs):
+ return
+
+ env_jobs = get_environment_jobs(cfg)
+ if not env_jobs:
+ return
for job in env_jobs:
while True:
url = f"http://{job.hostname}:{job.port}/health"
- # use requests
try:
response = requests.get(url)
if response.status_code == 200:
+ logger.info(
+ "Environment %s ready at %s",
+ job.environment_key if job.environment_key is not None else job.replica_idx,
+ url,
+ )
break
- except:
+ except requests.exceptions.RequestException:
logger.info(f"Waiting for environment at {url} to be ready...")
time.sleep(5.0)
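# Illustrative sketch, not part of the patch: how the new helpers in pipelinerl/utils.py
# behave for a multi-environment config with a domain mix. Assumes the patched package
# is importable; the config values and _target_ paths below are hypothetical.
from omegaconf import OmegaConf

from pipelinerl.utils import collect_environment_specs, select_environment_config

cfg = OmegaConf.create(
    {
        "environments": [
            {"key": "math", "mode": "remote", "_target_": "my_pkg.envs.MathEnv"},
            {"key": "coding", "mode": "remote", "_target_": "my_pkg.envs.CodingEnv"},
            {"key": "guessing", "mode": "embedded", "_target_": "my_pkg.envs.GuessingEnv"},
        ],
        "actor": {"domain_mix": {"math": 3, "coding": 1}},
        "world": {"environment_mode": "remote", "env_replicas_per_actor": 2},
    }
)

# Domains without a positive mix weight are dropped, and replicas_per_actor is rescaled
# around the default (2) in proportion to the weights:
print(collect_environment_specs(cfg))
# [{'key': 'math', 'mode': 'remote', 'replicas_per_actor': 3, 'index': 0},
#  {'key': 'coding', 'mode': 'remote', 'replicas_per_actor': 1, 'index': 1}]

# An environment job launched with environment_key="coding" gets its entry back with the
# key/mode/replicas metadata stripped, ready for hydra.utils.instantiate:
print(select_environment_config(cfg, key="coding"))
# {'_target_': 'my_pkg.envs.CodingEnv'}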
diff --git a/pipelinerl/world.py b/pipelinerl/world.py
index f41714e4..aa5ea420 100644
--- a/pipelinerl/world.py
+++ b/pipelinerl/world.py
@@ -26,6 +26,10 @@ class Job(BaseModel):
gpus: list[int] = []
# The URL of the job
url: str = ""
+ # Domain identifier for environment jobs
+ environment_key: str | None = None
+ # Index of this environment in the environments list
+ environment_index: int | None = None
class WorldMap:
@@ -71,8 +75,11 @@ def __init__(self, cfg: DictConfig, verbose: bool = False):
if place_inference_jobs:
self._place_inference_jobs(cfg)
self._place_pipeline_stages(cfg)
- if cfg.environment:
- self._place_environments(cfg)
+ # Lazy import to avoid circular dependency with utils.py
+ from pipelinerl.utils import collect_environment_specs
+ self.environment_specs = collect_environment_specs(cfg)
+ if any(spec["mode"] == "remote" for spec in self.environment_specs):
+ self._place_environments(cfg, self.environment_specs)
# Place the finetune workers on the remaining gpus, take all remaining GPUs
current_finetune_rank = 0
@@ -108,7 +115,7 @@ def __init__(self, cfg: DictConfig, verbose: bool = False):
for job in jobs:
self._log_info(f" {job.kind} {job.replica_idx} on gpus {job.gpus}, local idx {job.local_idx}")
- def add_job(self, node_rank: int, kind: str, replica_idx: int, local_idx: int = 0, port: int | None = None, gpus: list[int] | None = None, cpu_heavy: bool = False, url: str = "") -> Job:
+ def add_job(self, node_rank: int, kind: str, replica_idx: int, local_idx: int = 0, port: int | None = None, gpus: list[int] | None = None, cpu_heavy: bool = False, url: str = "", environment_key: str | None = None, environment_index: int | None = None) -> Job:
"""Add a job to the world map."""
if gpus is None:
gpus = []
@@ -121,7 +128,9 @@ def add_job(self, node_rank: int, kind: str, replica_idx: int, local_idx: int =
hostname=self.address_map[node_rank],
port=port,
gpus=gpus,
- url=url
+ url=url,
+ environment_key=environment_key,
+ environment_index=environment_index,
)
self.job_map[node_rank].append(job)
self.total_jobs += 1
@@ -187,18 +196,35 @@ def _place_pipeline_stages(self, cfg):
self.add_job(kind="actor", replica_idx=worker_idx, node_rank=node, gpus=[], cpu_heavy=True)
self.add_job(kind="preprocessor", replica_idx=worker_idx, node_rank=node, gpus=[], cpu_heavy=True)
- def _place_environments(self, cfg):
- for worker_idx in range(cfg.world.env_replicas):
- node = self.get_least_busy_node()
- envs_at_node = len([job for job in self.job_map[node] if job.kind == "environment"])
- self.add_job(
- kind="environment",
- replica_idx=worker_idx,
- node_rank=node,
- port=cfg.world.environment_start_port + envs_at_node,
- gpus=[],
- cpu_heavy=True,
- )
+ def _place_environments(self, cfg: DictConfig, environment_specs: list[dict]):
+ # Scale remote environment servers with the number of LLM servers (replicas * llms_per_actor), times each spec's replicas_per_actor
+ base_start_port = cfg.world.environment_start_port
+ llms_per_actor = getattr(self, "llms_per_actor", 1) or 1
+ global_replica_idx = 0
+ for spec_idx, spec in enumerate(environment_specs):
+ if spec["mode"] != "remote":
+ continue
+ replicas_per_actor = spec.get("replicas_per_actor")
+ if replicas_per_actor is None:
+ replicas_per_actor = getattr(cfg.world, "env_replicas_per_actor", None)
+ if replicas_per_actor is not None:
+ total_env_replicas = cfg.world.replicas * llms_per_actor * replicas_per_actor
+ else:
+ total_env_replicas = getattr(cfg.world, "env_replicas", cfg.world.replicas * llms_per_actor)
+ for replica_offset in range(total_env_replicas):
+ node = self.get_least_busy_node()
+ envs_at_node = len([job for job in self.job_map[node] if job.kind == "environment"])
+ self.add_job(
+ kind="environment",
+ replica_idx=global_replica_idx,
+ node_rank=node,
+ port=base_start_port + envs_at_node,
+ gpus=[],
+ cpu_heavy=True,
+ environment_key=spec["key"],
+ environment_index=spec.get("index", spec_idx),
+ )
+ global_replica_idx += 1
def _place_inference_jobs(self, cfg):
for _ in range(cfg.world.replicas):