diff --git a/conf/actor/web.yaml b/conf/actor/web.yaml index d2617e95..6925873e 100644 --- a/conf/actor/web.yaml +++ b/conf/actor/web.yaml @@ -3,9 +3,12 @@ llm_max_rollouts: 128 rollout_workers: 1 rollout_policy: pipelinerl.domains.deep_research.tapeagents_rollouts.generate_rollout -environment: - _target_: tapeagents.mcp.MCPEnvironment - config_path: conf/mcp/web.json +environments: + - key: mcp + mode: embedded + _target_: tapeagents.mcp.MCPEnvironment + config_path: conf/mcp/web.json +environment_key: mcp llm: _target_: tapeagents.llms.LiteLLM @@ -105,4 +108,4 @@ only_tasks: #[] # list of (level, task_num) - [1, 4] - [1, 5] - [1, 6] -- [1, 7] \ No newline at end of file +- [1, 7] diff --git a/conf/base.yaml b/conf/base.yaml index e3122f5a..31c0b93b 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -2,6 +2,7 @@ defaults: - finetune: base - rewards: pure_success - streams: files + - domain_mix: null - _self_ seed: 42 @@ -18,6 +19,7 @@ actor: result_queue_size: 64 throughput_window_size: 50 shared_memory_entry_size: 10000000 + domain_mix: null environment: null preprocess: input: actor @@ -135,4 +137,3 @@ wandb: wandb_dir: null # Comma-separated list of keywords to tag the run. tags: [] - diff --git a/conf/coding.yaml b/conf/coding.yaml new file mode 100644 index 00000000..7788bf7b --- /dev/null +++ b/conf/coding.yaml @@ -0,0 +1,55 @@ +defaults: + - base + - _self_ + +actor: + rollout_policy: pipelinerl.domains.coding.generate_coding_rollout + system_prompt: |- + You are an expert Python programmer. When providing code solutions, format your final code inside markdown code blocks using triple backticks with the python language identifier, like this: + ```python + # your code here + ``` + Provide complete, working implementations that pass all test cases. 
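+  # task_template wraps each problem's task text; generate_coding_rollout fills it
+  # via cfg.actor.task_template.format(task=...) when building the user message.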
+ task_template: |- + {task} + task_prompt: "" + ensure_boxed_answers: false + + coding_time_limit_s: 15.0 + coding_per_test_timeout_s: 10.0 + coding_memory_limit_bytes: 1073741824 + coding_compile_timeout_s: 10.0 + coding_sandbox_url: ${oc.env:CODING_SANDBOX_URL, "http://sandbox:8080/run_code"} + +dataset_loader: pipelinerl.domains.coding.dataset.load_problems +dataset_loader_params: + dataset_id: ServiceNow-AI/mixed-training-text-datasets + dataset_config: 80k-if-math-coding-fncalling-stem + split_ratios: + train: 0.9 + validation: 0.05 + test: 0.05 + allowed_call_types: + - assert + - std + max_examples_per_split: 2048 + trust_remote_code: true + huggingface_token: ${oc.env:CODING_HF_TOKEN, null} + +train_dataset_names: + - coding@train + +test_dataset_names: + - coding@validation + +environments: + - key: coding + mode: remote + _target_: pipelinerl.domains.coding.CodingSandboxEnvironment + sandbox_url: ${actor.coding_sandbox_url} + compile_timeout_s: ${actor.coding_compile_timeout_s} + run_timeout_s: ${actor.coding_per_test_timeout_s} + request_timeout_s: ${actor.coding_time_limit_s} + memory_limit_bytes: ${actor.coding_memory_limit_bytes} + +environment_key: coding diff --git a/conf/debug/multi_domain.yaml b/conf/debug/multi_domain.yaml new file mode 100644 index 00000000..03b4c56e --- /dev/null +++ b/conf/debug/multi_domain.yaml @@ -0,0 +1,30 @@ +defaults: + - base + - domain_rollouts: base + - override rewards: success_and_format + - _self_ + +actor: + rollout_policy: pipelinerl.domains.dispatcher.generate_multidomain_rollout + llm_max_rollouts: 2 + rollout_workers: 1 + domain_rollouts: + math: ${domain_rollouts.math} + guessing: ${domain_rollouts.guessing} + coding: ${domain_rollouts.coding} + +dataset_loader: pipelinerl.domains.multidomain.load_problems +train_dataset_names: + - math_debug + - guessing_debug + - coding_debug +test_dataset_names: + - math_debug + - coding_debug + +environment: null +environment_key: null + +world: + env_replicas_per_actor: 0 + environment_mode: embedded diff --git a/conf/domain_mix/README.md b/conf/domain_mix/README.md new file mode 100644 index 00000000..2143b288 --- /dev/null +++ b/conf/domain_mix/README.md @@ -0,0 +1,12 @@ +# Domain mix presets + +Hydra group `domain_mix` stores reusable presets for `actor.domain_mix`. + +Usage examples: + +``` +python main.py --config-name multi_domain/base +domain_mix=math_coding_70_30 +python main.py --config-name multi_domain/base +domain_mix=balanced +``` + +Override or extend these presets by creating new files under `conf/domain_mix/`. 
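+
+Each preset is a small YAML file carrying the `# @package actor.domain_mix` directive and mapping domain names to relative sampling weights; the sampler normalizes the weights, so only their ratios matter. For illustration, a hypothetical `conf/domain_mix/math_heavy.yaml` preset would look like:
+
+```
+# @package actor.domain_mix
+math: 0.8
+coding: 0.2
+```
+
+Domain names must match the `domain` field of the loaded samples: `DomainWeightedSampler` raises an error if a preset omits a domain that is present in the data or lists one that is not.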
diff --git a/conf/domain_mix/balanced.yaml b/conf/domain_mix/balanced.yaml new file mode 100644 index 00000000..7b2a4529 --- /dev/null +++ b/conf/domain_mix/balanced.yaml @@ -0,0 +1,9 @@ +# @package actor.domain_mix + +math: 1.0 +guessing: 1.0 +counting: 1.0 +chartqa: 1.0 +miniwob: 1.0 +coding: 1.0 +fn_calling: 1.0 diff --git a/conf/domain_mix/coding_heavy.yaml b/conf/domain_mix/coding_heavy.yaml new file mode 100644 index 00000000..ff102abd --- /dev/null +++ b/conf/domain_mix/coding_heavy.yaml @@ -0,0 +1,4 @@ +# @package actor.domain_mix + +math: 0.3 +coding: 0.7 diff --git a/conf/domain_mix/main_mix.yaml b/conf/domain_mix/main_mix.yaml new file mode 100644 index 00000000..06038d0b --- /dev/null +++ b/conf/domain_mix/main_mix.yaml @@ -0,0 +1,5 @@ +# @package actor.domain_mix + +math: 0.4 +coding: 0.3 +fn_calling: 0.3 \ No newline at end of file diff --git a/conf/domain_mix/math_coding_70_30.yaml b/conf/domain_mix/math_coding_70_30.yaml new file mode 100644 index 00000000..7ea64468 --- /dev/null +++ b/conf/domain_mix/math_coding_70_30.yaml @@ -0,0 +1,4 @@ +# @package actor.domain_mix + +math: 0.7 +coding: 0.3 diff --git a/conf/domain_rollouts/base.yaml b/conf/domain_rollouts/base.yaml new file mode 100644 index 00000000..e744acba --- /dev/null +++ b/conf/domain_rollouts/base.yaml @@ -0,0 +1,8 @@ +# Mapping between domain identifiers and rollout callables. +math: pipelinerl.domains.math.generate_math_rollout +guessing: pipelinerl.domains.guessing.generate_guessing_rollout +counting: pipelinerl.domains.counting.generate_counting_rollout +miniwob: pipelinerl.domains.miniwob.rollouts.generate_miniwob_rollout +chartqa: pipelinerl.domains.chartqa.generate_chartqa_rollout +coding: pipelinerl.domains.coding.generate_coding_rollout +fn_calling: pipelinerl.domains.fn_calling.generate_fn_calling_rollout diff --git a/conf/fn_calling.yaml b/conf/fn_calling.yaml new file mode 100644 index 00000000..6eeb898c --- /dev/null +++ b/conf/fn_calling.yaml @@ -0,0 +1,36 @@ +defaults: + - base + - _self_ + +actor: + rollout_policy: pipelinerl.domains.fn_calling.generate_fn_calling_rollout + system_prompt: "" + task_template: "{task}" + task_prompt: "" + ensure_boxed_answers: false + +dataset_loader: pipelinerl.domains.fn_calling.dataset.load_problems +dataset_loader_params: + dataset_id: ServiceNow-AI/mixed-training-text-datasets + dataset_config: 80k-if-math-coding-fncalling-stem + split_ratios: + train: 0.9 + validation: 0.05 + test: 0.05 + allowed_call_types: [] + max_examples_per_split: 2048 + trust_remote_code: true + huggingface_token: ${oc.env:CODING_HF_TOKEN, null} + +train_dataset_names: + - fn_calling@train + +test_dataset_names: + - fn_calling@validation + +environments: + - key: fn_calling + mode: remote + _target_: pipelinerl.domains.fn_calling.AgenticToolsEnvironment + +environment_key: fn_calling diff --git a/conf/math.yaml b/conf/math.yaml index 069aa96b..d9a07c02 100644 --- a/conf/math.yaml +++ b/conf/math.yaml @@ -5,10 +5,13 @@ defaults: actor: rollout_policy: pipelinerl.domains.math.generate_math_rollout system_prompt: Please reason step by step, and put your final answer within \boxed{}. 
- task_template: |- - {task} -environment: - _target_: pipelinerl.domains.math.MathEnvironment + task_template: "{task}" + task_prompt: "" +environments: + - key: math + mode: remote + _target_: pipelinerl.domains.math.MathEnvironment +environment_key: math dataset_loader: pipelinerl.domains.math.load_datasets train_dataset_names: - open_reasoner_zero_57k @@ -16,4 +19,4 @@ train_dataset_names: test_dataset_names: - aime_2024 - amc_2023 - - math_500 \ No newline at end of file + - math_500 diff --git a/conf/math_code.yaml b/conf/math_code.yaml new file mode 100644 index 00000000..7a23a7e9 --- /dev/null +++ b/conf/math_code.yaml @@ -0,0 +1,69 @@ +defaults: + - base + - /domain_rollouts@domain_rollouts: base + - domain_mix: math_coding_70_30 + - _self_ + +actor: + rollout_policy: pipelinerl.domains.dispatcher.generate_multidomain_rollout + system_prompt: "" + task_template: |- + {task} + task_prompt: "" + ensure_boxed_answers: false + domain_rollouts: + math: ${domain_rollouts.math} + coding: ${domain_rollouts.coding} + coding_time_limit_s: 15.0 + coding_per_test_timeout_s: 10.0 + coding_memory_limit_bytes: 1073741824 + coding_compile_timeout_s: 10.0 + coding_sandbox_url: ${oc.env:CODING_SANDBOX_URL, "http://sandbox:8080/run_code"} + +dataset_loader: pipelinerl.domains.multidomain.loader.load_datasets +dataset_loader_params: + per_domain_params: + coding: + dataset_id: ServiceNow-AI/mixed-training-text-datasets + dataset_config: 80k-if-math-coding-fncalling-stem + split_ratios: + train: 0.9 + validation: 0.05 + test: 0.05 + allowed_call_types: + - assert + - std + max_examples_per_split: 2048 + trust_remote_code: true + huggingface_token: ${oc.env:CODING_HF_TOKEN, null} + +train_dataset_names: + - math::open_reasoner_zero_57k + - math::open_reasoner_zero_extended_72k + - coding::coding@train + +test_dataset_names: + - math::aime_2024 + - math::amc_2023 + - math::math_500 + - coding::coding@validation + +environments: + - key: math + mode: remote + replicas_per_actor: ${world.env_replicas_per_actor} + _target_: pipelinerl.domains.math.MathEnvironment + - key: coding + mode: remote + replicas_per_actor: ${world.env_replicas_per_actor} + _target_: pipelinerl.domains.coding.CodingSandboxEnvironment + sandbox_url: ${actor.coding_sandbox_url} + compile_timeout_s: ${actor.coding_compile_timeout_s} + run_timeout_s: ${actor.coding_per_test_timeout_s} + request_timeout_s: ${actor.coding_time_limit_s} + memory_limit_bytes: ${actor.coding_memory_limit_bytes} + +environment_key: null + +world: + env_replicas_per_actor: 1 diff --git a/conf/multi_domain/base.yaml b/conf/multi_domain/base.yaml new file mode 100644 index 00000000..cda10faa --- /dev/null +++ b/conf/multi_domain/base.yaml @@ -0,0 +1,92 @@ +# @package _global_ +defaults: + - /domain_rollouts@domain_rollouts: base + - domain_mix: null + +actor: + rollout_policy: pipelinerl.domains.dispatcher.generate_multidomain_rollout + system_prompt: "" + task_template: |- + {task} + task_prompt: "" + ensure_boxed_answers: false + domain_mix: null + # Domain-specific system prompts (used when actor.system_prompt is empty) + domain_system_prompts: + coding: |- + You are an expert Python programmer. When providing code solutions, format your final code inside markdown code blocks using triple backticks with the python language identifier, like this: + ```python + # your code here + ``` + Provide complete, working implementations that pass all test cases. 
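+    # An empty string disables the domain-specific prompt: the rollout then sends
+    # no system message for that domain (see _get_system_prompt in the coding rollouts).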
+ fn_calling: "" + math: "" + domain_rollouts: + math: ${domain_rollouts.math} + guessing: ${domain_rollouts.guessing} + counting: ${domain_rollouts.counting} + chartqa: ${domain_rollouts.chartqa} + miniwob: ${domain_rollouts.miniwob} + coding: ${domain_rollouts.coding} + fn_calling: ${domain_rollouts.fn_calling} + coding_time_limit_s: 15.0 + coding_per_test_timeout_s: 10.0 + coding_memory_limit_bytes: 1073741824 + coding_compile_timeout_s: 10.0 + coding_sandbox_url: ${oc.env:CODING_SANDBOX_URL, "http://sandbox:8080/run_code"} + +dataset_loader: pipelinerl.domains.multidomain.loader.load_datasets +dataset_loader_params: + per_domain_params: + coding: + dataset_id: ServiceNow-AI/mixed-training-text-datasets + dataset_config: 80k-if-math-coding-fncalling-stem + split_ratios: + train: 0.9 + validation: 0.05 + test: 0.05 + allowed_call_types: + - assert + - std + max_examples_per_split: 2048 + trust_remote_code: true + huggingface_token: ${oc.env:CODING_HF_TOKEN, null} + fn_calling: + dataset_id: ServiceNow-AI/mixed-training-text-datasets + dataset_config: 80k-if-math-coding-fncalling-stem + split_ratios: + train: 0.9 + validation: 0.05 + test: 0.05 + allowed_call_types: [] + max_examples_per_split: 2048 + trust_remote_code: true + huggingface_token: ${oc.env:CODING_HF_TOKEN, null} + +train_dataset_names: [] +test_dataset_names: [] + +environments: + - key: math + mode: remote + replicas_per_actor: ${world.env_replicas_per_actor} + _target_: pipelinerl.domains.math.MathEnvironment + - key: coding + mode: remote + replicas_per_actor: ${world.env_replicas_per_actor} + _target_: pipelinerl.domains.coding.CodingSandboxEnvironment + sandbox_url: ${actor.coding_sandbox_url} + compile_timeout_s: ${actor.coding_compile_timeout_s} + run_timeout_s: ${actor.coding_per_test_timeout_s} + request_timeout_s: ${actor.coding_time_limit_s} + memory_limit_bytes: ${actor.coding_memory_limit_bytes} + - key: fn_calling + mode: remote + replicas_per_actor: ${world.env_replicas_per_actor} + _target_: pipelinerl.domains.fn_calling.AgenticToolsEnvironment + max_workers: 4 + +environment_key: null + +world: + env_replicas_per_actor: 1 diff --git a/conf/multi_domain/main_mix.yaml b/conf/multi_domain/main_mix.yaml new file mode 100644 index 00000000..cae97fd6 --- /dev/null +++ b/conf/multi_domain/main_mix.yaml @@ -0,0 +1,10 @@ +defaults: + - base + - domain_mix: main_mix + - _self_ + +actor: + domain_rollouts: + math: ${domain_rollouts.math} + coding: ${domain_rollouts.coding} + fn_calling: ${domain_rollouts.fn_calling} diff --git a/conf/test.yaml b/conf/test.yaml index b86dfa5d..de9399e5 100644 --- a/conf/test.yaml +++ b/conf/test.yaml @@ -3,7 +3,7 @@ defaults: finetune: seq_length: 4000 gradient_accumulation_passes: 6 - max_train_steps: 1 + max_train_steps: 100 train_batch_size: 4 attempts: 4 llm: diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 1c238ff9..60e878c2 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -15,12 +15,13 @@ import aiohttp import hydra import uvloop -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from pydantic import BaseModel, Field import wandb -from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb +from pipelinerl.domain_sampling import DomainWeightedSampler from pipelinerl.finetune_loop import calculate_train_steps +from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb from pipelinerl.llm import TrainableLLM from pipelinerl.rollouts import BaseMetrics, RolloutResult from pipelinerl.shared_memory_array 
import SharedMemoryQueue @@ -36,6 +37,7 @@ from .utils import ( always_or_never_success_stats, calculate_stats, + resolve_environment_key, setup_logging, wait_for_environments, wait_for_inference_servers, @@ -173,7 +175,23 @@ async def rollout_and_maybe_produce_result( llm = llms[llm_index] model_version = trainer_state.propagated_weight_version assert model_version is not None + domain_value: str | None = None + if isinstance(problem, dict): + raw_domain = problem.get("domain") + if raw_domain: + domain_value = str(raw_domain) + elif isinstance(problem, tuple) and len(problem) >= 2 and isinstance(problem[1], dict): + raw_domain = problem[1].get("domain") + if raw_domain: + domain_value = str(raw_domain) + + if not domain_value: + resolved = resolve_environment_key(cfg) + domain_value = str(resolved) if resolved else None + rollout_result = await rollout_policy(cfg, llm, problem, session) + if domain_value and not rollout_result.domain: + rollout_result.domain = domain_value rollout_result.model_version = model_version # Make a group id that will be different from groups made by another rollout maker full_group_id = f"{scheduler_name}_{group_id}" @@ -348,6 +366,7 @@ def init_stats(self): self.latency_list = [] self.model_versions_list = [] self.sliding_stats = defaultdict(list) + self.domain_counts = defaultdict(int) def compute_domain_agnostic_metrics(self, result: RolloutResult) -> Dict[str, float]: metrics = {} @@ -367,6 +386,16 @@ def update_stats(self, rollout_results: List[RolloutResult]): group_id = result.group_id self.latency_list.append(result.latency) self.model_versions_list.append(result.model_version) + domain_key: str | None = None + if getattr(result, "domain", None): + domain_key = str(result.domain) + elif isinstance(dataset_name, str): + domain_key = dataset_name.split("::", 1)[0] + elif dataset_name is not None: + domain_key = str(dataset_name) + + if domain_key: + self.domain_counts[domain_key] += len(result.training_texts) domain_agnostic_metrics = self.compute_domain_agnostic_metrics(result) all_metrics = result.metrics.model_dump() | domain_agnostic_metrics for k, v in all_metrics.items(): @@ -403,8 +432,15 @@ def run(self, dataset: list[tuple[str, dict]]): # If training, we expect to sample infinitely # for train sample, sample random batches infinitely # for test samples, loop through the dataset once + domain_sampler = None if self.is_training: problem_iter = random_iter(dataset) + domain_mix_cfg = getattr(self.cfg.actor, "domain_mix", None) + if domain_mix_cfg: + mix_weights = OmegaConf.to_container(domain_mix_cfg, resolve=True) + if not isinstance(mix_weights, dict): + raise ValueError("actor.domain_mix must be a mapping from domain to weight") + domain_sampler = DomainWeightedSampler(dataset, mix_weights) else: problem_iter = sequential_iter(dataset) assert self.trainer_state.propagated_weight_version is not None @@ -466,7 +502,10 @@ def run(self, dataset: list[tuple[str, dict]]): if not blocked_by_lag and not self.problem_queue.full(): try: try: - problem = next(problem_iter) + if domain_sampler is not None: + problem = domain_sampler.sample() + else: + problem = next(problem_iter) self.problem_queue.put(problem, block=False) submitted_groups += 1 except queue.Full: @@ -491,6 +530,12 @@ def run(self, dataset: list[tuple[str, dict]]): assert isinstance(rollout_results[0], RolloutResult) group_samples = sum(len(r.training_texts) for r in rollout_results) + # Track completions per domain for adaptive sampling + if domain_sampler is not None: + for r in 
rollout_results: + if r.domain: + domain_sampler.record_completion(r.domain) + published_samples += group_samples samples_in_queue = self.result_queue.qsize() * attempts all_text_dumps = [] @@ -570,6 +615,25 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict): ) stats |= loop_stats + + total_domain_samples = sum(self.domain_counts.values()) + if total_domain_samples: + for domain, count in sorted(self.domain_counts.items()): + stats[f"{split_name}domain_mix_count/{domain}"] = count + stats[f"{split_name}domain_mix_actual/{domain}"] = count / total_domain_samples + + domain_mix_cfg = getattr(self.cfg.actor, "domain_mix", None) + if domain_mix_cfg: + mix_weights = OmegaConf.to_container(domain_mix_cfg, resolve=True) + if isinstance(mix_weights, dict): + target_total = sum(float(v) for v in mix_weights.values() if float(v) > 0) + if target_total > 0: + for domain, weight in mix_weights.items(): + stats[f"{split_name}domain_mix_target/{domain}"] = float(weight) / target_total + else: + for domain in mix_weights: + stats[f"{split_name}domain_mix_target/{domain}"] = 0.0 + for k, v in self.sliding_stats.items(): stats[k] = sum(v) / len(v) if v else 0 if self.cfg.wandb.use_wandb: diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py index 4e78ebf9..33391ed3 100644 --- a/pipelinerl/async_llm.py +++ b/pipelinerl/async_llm.py @@ -15,6 +15,18 @@ logger = logging.getLogger(__name__) +def _to_object(obj): + """Recursively convert OmegaConf objects to plain Python containers.""" + if isinstance(obj, (DictConfig, ListConfig)): + return OmegaConf.to_container(obj, resolve=True) + elif isinstance(obj, dict): + return {k: _to_object(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [_to_object(v) for v in obj] + else: + return obj + + def extract_images_from_messages(messages: list[dict]) -> list[Image.Image]: """Extract PIL Images from multimodal messages.""" @@ -44,18 +56,6 @@ def extract_images_from_messages(messages: list[dict]) -> list[Image.Image]: return images -def _to_plain_obj(value): - """convert OmegaConf containers into Python types""" - - if isinstance(value, (DictConfig, ListConfig)): - return OmegaConf.to_container(value, resolve=True) - if isinstance(value, dict): - return {key: _to_plain_obj(val) for key, val in value.items()} - if isinstance(value, (list, tuple)): - return [_to_plain_obj(item) for item in value] - return value - - async def llm_async_generate( llm: TrainableLLM, prompt: Prompt, session: aiohttp.ClientSession ) -> LLMCall: @@ -88,7 +88,9 @@ async def llm_async_generate( logger.debug(f"POST request to {llm.base_url}/v1/chat/completions") - payload = _to_plain_obj({**data, **extra_parameters}) + payload = {**data, **extra_parameters} + # TODO: upgrade omegaconf and use OmegaConf.to_object for recursive conversion + payload = _to_object(payload) async with session.post( url=f"{llm.base_url}/v1/chat/completions", json=payload, @@ -101,6 +103,7 @@ async def llm_async_generate( response.raise_for_status() data = await response.json() + finish_reason: str | None = None try: content = data["choices"][0]["message"]["content"] if not content: diff --git a/pipelinerl/domain_sampling.py b/pipelinerl/domain_sampling.py new file mode 100644 index 00000000..497bb0f8 --- /dev/null +++ b/pipelinerl/domain_sampling.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import logging +import random +from collections import defaultdict +from typing import Mapping + +logger = logging.getLogger(__name__) + +# Minimum completions before dynamic 
adjustment kicks in +_MIN_COMPLETIONS_FOR_ADJUSTMENT = 50 +# Clamp adjustment factors to avoid extreme swings +_MIN_ADJUSTMENT = 0.1 +_MAX_ADJUSTMENT = 10.0 + + +class DomainWeightedSampler: + """Randomly samples problems according to per-domain weights. + + Supports dynamic weight adjustment based on completion tracking to maintain + target domain ratios in the output stream despite varying processing speeds. + """ + + def __init__( + self, + samples: list[dict], + weights: Mapping[str, float], + rng: random.Random | None = None, + adaptive: bool = True, + ): + if not weights: + raise ValueError("domain_mix cannot be empty when provided") + self.random = rng or random + self.adaptive = adaptive + samples_by_domain: dict[str, list[dict]] = defaultdict(list) + for sample in samples: + domain = sample.get("domain") + if not domain: + raise ValueError("Each sample must include a 'domain' field for domain_mix to work") + samples_by_domain[str(domain)].append(sample) + + provided_domains = {str(domain) for domain in weights} + cleaned_weights: dict[str, float] = {} + for domain, value in weights.items(): + val = float(value) + if val < 0: + raise ValueError(f"domain_mix weight for '{domain}' must be non-negative") + if val == 0: + continue + cleaned_weights[str(domain)] = val + + if not cleaned_weights: + raise ValueError("domain_mix must include at least one positive weight") + + # accept zero weights but require the domain to be declared. + missing = set(samples_by_domain) - provided_domains + if missing: + missing_list = ", ".join(sorted(missing)) + raise ValueError( + "domain_mix is missing weights for dataset domains: " + missing_list + ) + + unused = provided_domains - set(samples_by_domain) + if unused: + unused_list = ", ".join(sorted(unused)) + raise ValueError( + "domain_mix specifies domains not present in dataset: " + unused_list + ) + + self.samples_by_domain = samples_by_domain + self.domains: list[str] = [] + self.base_weights: dict[str, float] = {} + self.thresholds: list[float] = [] + total = 0.0 + for domain, weight in cleaned_weights.items(): + total += weight + self.domains.append(domain) + self.base_weights[domain] = weight + self.thresholds.append(total) + if total <= 0: + raise ValueError("Sum of domain_mix weights must be positive") + self.total_weight = total + + # Target ratios (normalized weights) + self.target_ratios = {d: w / total for d, w in self.base_weights.items()} + + # Completion tracking for adaptive sampling + self.completion_counts: dict[str, int] = {d: 0 for d in self.domains} + self.total_completions = 0 + self._last_log_completions = 0 + + def record_completion(self, domain: str) -> None: + """Record that a sample from the given domain has completed processing. + + This enables adaptive weight adjustment to maintain target domain ratios + in the output stream despite varying processing speeds per domain. 
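+
+        Adaptive adjustment only kicks in after _MIN_COMPLETIONS_FOR_ADJUSTMENT
+        completions have been recorded; before that, _pick_domain keeps using the
+        static base weights.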
+ """ + if domain in self.completion_counts: + self.completion_counts[domain] += 1 + self.total_completions += 1 + + # Log periodically + if self.total_completions - self._last_log_completions >= 500: + self._log_domain_stats() + self._last_log_completions = self.total_completions + + def _log_domain_stats(self) -> None: + """Log current domain distribution vs targets.""" + if self.total_completions == 0: + return + parts = [] + for domain in self.domains: + actual = self.completion_counts[domain] / self.total_completions + target = self.target_ratios[domain] + parts.append(f"{domain}={actual:.1%}(target={target:.1%})") + logger.info(f"Domain completion stats ({self.total_completions} total): {', '.join(parts)}") + + def _pick_domain_static(self) -> str: + """Pick domain using static weights (original behavior).""" + r = self.random.random() * self.total_weight + for domain, threshold in zip(self.domains, self.thresholds): + if r < threshold: + return domain + return self.domains[-1] + + def _pick_domain_adaptive(self) -> str: + """Pick domain using dynamically adjusted weights based on completion ratios.""" + # Calculate current completion ratios + current_ratios = { + d: self.completion_counts[d] / self.total_completions + for d in self.domains + } + + # Calculate adjusted weights: boost under-represented, reduce over-represented + adjusted_weights: dict[str, float] = {} + for domain in self.domains: + target = self.target_ratios[domain] + current = current_ratios[domain] + + if current > 0: + # adjustment = target / current + # If current=46%, target=30% → adjustment=0.65 (sample less) + # If current=18%, target=30% → adjustment=1.67 (sample more) + adjustment = target / current + adjustment = max(_MIN_ADJUSTMENT, min(_MAX_ADJUSTMENT, adjustment)) + else: + # No completions yet for this domain, boost sampling + adjustment = _MAX_ADJUSTMENT + + adjusted_weights[domain] = self.base_weights[domain] * adjustment + + # Sample based on adjusted weights + total = sum(adjusted_weights.values()) + r = self.random.random() * total + cumsum = 0.0 + for domain in self.domains: + cumsum += adjusted_weights[domain] + if r < cumsum: + return domain + return self.domains[-1] + + def _pick_domain(self) -> str: + """Pick a domain for the next sample.""" + if not self.adaptive or self.total_completions < _MIN_COMPLETIONS_FOR_ADJUSTMENT: + return self._pick_domain_static() + return self._pick_domain_adaptive() + + def sample(self) -> dict: + domain = self._pick_domain() + return self.random.choice(self.samples_by_domain[domain]) diff --git a/pipelinerl/domains/__init__.py b/pipelinerl/domains/__init__.py index e69de29b..f73dd529 100644 --- a/pipelinerl/domains/__init__.py +++ b/pipelinerl/domains/__init__.py @@ -0,0 +1,8 @@ +"""Domain utilities and rollouts.""" + +from .dispatcher import generate_multidomain_rollout, register_domain_rollout + +__all__ = [ + "generate_multidomain_rollout", + "register_domain_rollout", +] diff --git a/pipelinerl/domains/chartqa/load_datasets.py b/pipelinerl/domains/chartqa/load_datasets.py index 49dc92f7..8265086a 100644 --- a/pipelinerl/domains/chartqa/load_datasets.py +++ b/pipelinerl/domains/chartqa/load_datasets.py @@ -5,6 +5,8 @@ logger = logging.getLogger(__name__) +DOMAIN = "chartqa" + def process_chartqa(dataset, dataset_name: str): """Process ChartQA dataset into standardized format.""" @@ -31,6 +33,7 @@ def add_ids(dataset: list[dict]): """Add sequential IDs to dataset items.""" for i, entry in enumerate(dataset): entry["id"] = i + entry.setdefault("domain", 
DOMAIN) return dataset diff --git a/pipelinerl/domains/coding/__init__.py b/pipelinerl/domains/coding/__init__.py new file mode 100644 index 00000000..f04984b7 --- /dev/null +++ b/pipelinerl/domains/coding/__init__.py @@ -0,0 +1,17 @@ +"""Coding domain rollouts and dataset utilities.""" + +from .rollouts import generate_coding_rollout +from .verifier_api import ( + CodingSandboxEnvironment, + evaluate_coding_prediction, + verify_coding_solution_rpc, +) +from .dataset import load_problems + +__all__ = [ + "CodingSandboxEnvironment", + "evaluate_coding_prediction", + "generate_coding_rollout", + "load_problems", + "verify_coding_solution_rpc", +] diff --git a/pipelinerl/domains/coding/dataset.py b/pipelinerl/domains/coding/dataset.py new file mode 100644 index 00000000..e8b1dbbc --- /dev/null +++ b/pipelinerl/domains/coding/dataset.py @@ -0,0 +1,383 @@ +from __future__ import annotations + +import json +import logging +import math +import random +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Sequence + +import pyarrow as pa +import pyarrow.parquet as pq +from datasets import Dataset, DownloadMode, load_dataset +from datasets.exceptions import DatasetGenerationError +from huggingface_hub import snapshot_download +from omegaconf import DictConfig, OmegaConf + +logger = logging.getLogger(__name__) + +DOMAIN_NAME = "coding" +BASE_DATASET_NAME = "mixed-training-text-datasets" +DATASET_ALIASES = frozenset({DOMAIN_NAME, "code"}) +SUPPORTED_DATASET_NAMES = frozenset({BASE_DATASET_NAME}) +DEFAULT_DATASET_ID = "ServiceNow-AI/mixed-training-text-datasets" +DEFAULT_DATASET_CONFIG = "80k-if-math-coding-fncalling-stem" +DEFAULT_SPLIT_ORDER = ("train", "validation", "test") +DEFAULT_SPLIT_RATIOS = tuple((name, ratio) for name, ratio in zip(DEFAULT_SPLIT_ORDER, (0.9, 0.05, 0.05))) +DEFAULT_CALL_TYPES = ("assert", "std") + + +@dataclass(frozen=True) +class DatasetSpec: + name: str + split: str = "train" + + +@dataclass +@dataclass(frozen=True) +class ResolvedDatasetEntry: + spec: DatasetSpec + label: str + + +@dataclass(frozen=True) +class DatasetResolution: + requested: list[str] + entries: list[ResolvedDatasetEntry] + + +@dataclass +class DatasetOptions: + dataset_id: str = DEFAULT_DATASET_ID + dataset_config: str | None = DEFAULT_DATASET_CONFIG + split_ratios: Sequence[tuple[str, float]] = DEFAULT_SPLIT_RATIOS + trust_remote_code: bool = True + max_examples_per_split: int | None = None + allowed_call_types: Sequence[str] = DEFAULT_CALL_TYPES + huggingface_token: str | None = None + ability_filter: str = "code" + + +CacheKey = tuple[str, str | None, bool, str | None] +_DATASET_CACHE: dict[CacheKey, Dataset] = {} + + +def _normalize_loader_options(loader_kwargs: Dict[str, Any]) -> DatasetOptions: + def _to_native(value: Any) -> Any: + if isinstance(value, DictConfig): + return OmegaConf.to_container(value, resolve=True) + return value + + options = DatasetOptions() + if loader_kwargs: + raw_ratios = _to_native(loader_kwargs.get("split_ratios")) + if isinstance(raw_ratios, dict): + options.split_ratios = tuple(raw_ratios.items()) + dataset_config = loader_kwargs.get("dataset_config") + if dataset_config: + options.dataset_config = str(dataset_config) + dataset_id = loader_kwargs.get("dataset_id") + if dataset_id: + options.dataset_id = str(dataset_id) + if "max_examples_per_split" in loader_kwargs: + value = loader_kwargs["max_examples_per_split"] + options.max_examples_per_split = int(value) if value is not None else None + if "trust_remote_code" 
in loader_kwargs: + options.trust_remote_code = bool(loader_kwargs["trust_remote_code"]) + call_types = _to_native(loader_kwargs.get("allowed_call_types")) + if isinstance(call_types, Iterable) and not isinstance(call_types, (str, bytes)): + options.allowed_call_types = tuple(str(item) for item in call_types) + token = loader_kwargs.get("huggingface_token") or loader_kwargs.get("hf_token") + if token: + options.huggingface_token = str(token) + ability = loader_kwargs.get("ability_filter") or loader_kwargs.get("ability") + if ability: + options.ability_filter = str(ability) + return options + + +def _ability_matches(value: Any, ability: str | None) -> bool: + """Return True when the sample's ability field contains the requested ability. + + The upstream dataset sometimes stores ability as a single string ("code") and + sometimes as a list (e.g., ["agentic_fn_calling"]). Filtering must accept both + shapes or we drop all fn_calling samples. + """ + + if ability is None: + return True + if value is None: + return False + if isinstance(value, str): + return value == ability + if isinstance(value, (list, tuple, set)): + return ability in value + return False + + +def parse_dataset_name(entry: str) -> DatasetSpec: + text = entry.strip() + if "@" not in text: + return DatasetSpec(name=text) + name, split = text.split("@", 1) + return DatasetSpec(name=name.strip(), split=split.strip() or "train") + + +def _normalize_dataset_names(dataset_names: List[str] | str | None) -> list[str]: + if dataset_names is None: + return [] + if isinstance(dataset_names, str): + return [dataset_names] + return [str(entry) for entry in dataset_names] + + +def _resolve_dataset_requests(dataset_names: List[str] | str | None) -> DatasetResolution: + requested = _normalize_dataset_names(dataset_names) + entries: list[ResolvedDatasetEntry] = [] + for entry in requested: + spec = parse_dataset_name(entry) + base_name = BASE_DATASET_NAME if spec.name in DATASET_ALIASES else spec.name + if base_name not in SUPPORTED_DATASET_NAMES: + raise ValueError(f"Unsupported coding dataset '{spec.name}'") + resolved_spec = DatasetSpec(name=base_name, split=spec.split) + entries.append(ResolvedDatasetEntry(spec=resolved_spec, label=f"{spec.name}@{spec.split}")) + return DatasetResolution(requested=requested, entries=entries) + + +def _load_dataset(options: DatasetOptions) -> Dataset: + ability = options.ability_filter + cache_key: CacheKey = ( + options.dataset_id, + options.dataset_config, + options.trust_remote_code, + ability, + ) + if cache_key in _DATASET_CACHE: + return _DATASET_CACHE[cache_key] + + def _materialize_dataset(**extra_kwargs: Any) -> Dataset: + return load_dataset( + options.dataset_id, + options.dataset_config, + split="train", + trust_remote_code=options.trust_remote_code, + token=options.huggingface_token, + **extra_kwargs, + ) + + try: + ds = _materialize_dataset() + except DatasetGenerationError as exc: + logger.warning( + "load_dataset failed for %s (%s): %s. Forcing re-download.", + options.dataset_id, + options.dataset_config, + exc, + ) + try: + ds = _materialize_dataset(download_mode=DownloadMode.FORCE_REDOWNLOAD) + except DatasetGenerationError as redownload_exc: + logger.warning( + "Forced re-download also failed for %s (%s): %s. 
Falling back to streaming mode.", + options.dataset_id, + options.dataset_config, + redownload_exc, + ) + stream = _materialize_dataset(streaming=True) + try: + ds = Dataset.from_list(list(stream)) + except OSError as stream_exc: + logger.warning( + "Streaming fallback also failed for %s (%s): %s. Downloading snapshot locally.", + options.dataset_id, + options.dataset_config, + stream_exc, + ) + ds = _load_snapshot(options) + + if ability: + ds = ds.filter(lambda sample: _ability_matches(sample.get("ability"), ability)) + _DATASET_CACHE[cache_key] = ds + logger.info( + "Loaded %s (%s) with %d coding samples", + options.dataset_id, + options.dataset_config, + len(ds), + ) + return ds + + +def _load_snapshot(options: DatasetOptions) -> Dataset: + if not options.dataset_config: + raise RuntimeError("Snapshot fallback requires a dataset_config but none was provided.") + + snapshot_dir = snapshot_download( + repo_id=options.dataset_id, + repo_type="dataset", + token=options.huggingface_token, + allow_patterns=(f"{options.dataset_config}/*", "dataset_infos.json"), + ) + config_dir = Path(snapshot_dir) / options.dataset_config + parquet_files = sorted(config_dir.glob("train-*.parquet")) + if not parquet_files: + raise RuntimeError( + f"Snapshot for {options.dataset_id} ({options.dataset_config}) contained no train parquet shards" + ) + + tables: list[pa.Table] = [] + for shard in parquet_files: + try: + tables.append(pq.read_table(shard)) + except Exception as exc: # pragma: no cover - defensive fallback + logger.warning("Skipping corrupted parquet shard %s: %s", shard.name, exc) + if not tables: + raise RuntimeError( + f"All locally downloaded shards failed to load for {options.dataset_id} ({options.dataset_config})." + ) + + table = pa.concat_tables(tables) + logger.info( + f"Loaded {table.num_rows} rows via snapshot fallback ({len(tables)} shards)", + ) + return Dataset.from_table(table) + + +def _normalized_split_sequence(ratios: Sequence[tuple[str, float]]) -> list[tuple[str, float]]: + if not ratios: + return [("train", 1.0)] + values = [(name, float(portion)) for name, portion in ratios if float(portion) > 0] + total = sum(portion for _, portion in values) + if math.isclose(total, 0.0): + return [("train", 1.0)] + return [(name, portion / total) for name, portion in values] + + +def _slice_indices(count: int, split: str, ratios: Sequence[tuple[str, float]], seed: int | None) -> list[int]: + normalized = _normalized_split_sequence(ratios) + if split not in {name for name, _ in normalized}: + raise ValueError(f"Split '{split}' is not defined in split_ratios {normalized}") + indices = list(range(count)) + rng = random.Random(seed or 0) + rng.shuffle(indices) + cumulative = 0.0 + start = 0 + selected: dict[str, list[int]] = {} + for idx, (name, portion) in enumerate(normalized): + cumulative += portion + end = count if idx == len(normalized) - 1 else min(count, int(round(cumulative * count))) + selected[name] = indices[start:end] + start = end + return selected.get(split, []) + + +def _decode_extra_info(raw_extra: Any) -> dict[str, Any]: + if isinstance(raw_extra, dict): + return raw_extra + if isinstance(raw_extra, str) and raw_extra.strip(): + try: + return json.loads(raw_extra) + except json.JSONDecodeError: + logger.debug("Failed to decode extra_info: %s", raw_extra[:128]) + return {} + + +def _build_record(sample: dict, dataset_label: str, allowed_call_types: Sequence[str]) -> dict | None: + reward_model = sample.get("reward_model") or {} + reward_raw = reward_model.get("ground_truth") + 
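+    # ground_truth is expected to be a JSON string describing the verification spec
+    # (call_type plus assert_case or inputs/outputs), which the coding sandbox
+    # environment later consumes in evaluate_coding_prediction.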
if reward_raw is None: + return None + try: + reward_context = json.loads(reward_raw) + except (TypeError, json.JSONDecodeError): + return None + if allowed_call_types and reward_context.get("call_type") not in set(allowed_call_types): + return None + + prompt_messages = sample.get("prompt") or [] + if not prompt_messages: + return None + task = prompt_messages[0].get("content") + if not task: + return None + + extra_info = _decode_extra_info(sample.get("extra_info")) + record = { + "dataset": dataset_label, + "task": task, + "reward_context": reward_context, + "extra_info": extra_info, + } + + # For fn_calling samples, keep the full prompt so multi-turn tool-calling + # instructions are available downstream. + if sample.get("ability") == "agentic_fn_calling": + if len(prompt_messages) == 1 and isinstance(prompt_messages[0], list): + record["prompt"] = prompt_messages[0] + else: + record["prompt"] = prompt_messages + + return record + + +def _load_split( + spec: DatasetSpec, + *, + options: DatasetOptions, + seed: int | None, + dataset_label: str | None = None, +) -> list[dict]: + dataset = _load_dataset(options) + indices = _slice_indices(len(dataset), spec.split, options.split_ratios, seed) + if not indices: + logger.warning("Requested split '%s' produced zero samples", spec.split) + return [] + subset = dataset.select(indices) + samples: list[dict] = [] + label = dataset_label or f"{spec.name}@{spec.split}" + for sample in subset: + record = _build_record(sample, label, options.allowed_call_types) + if record is None: + continue + samples.append(record) + if options.max_examples_per_split and len(samples) >= options.max_examples_per_split: + break + logger.info( + "Loaded %d samples for %s", + len(samples), + label, + ) + return samples + + +def _attach_ids(samples: list[dict]) -> list[dict]: + for idx, sample in enumerate(samples): + sample["id"] = idx + return samples + + +def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None, **loader_kwargs: Any) -> List[Dict]: + resolution = _resolve_dataset_requests(dataset_names) + if not resolution.entries: + if dataset_names: + logger.warning("No coding dataset entries were resolved for %s", dataset_names) + return [] + + options = _normalize_loader_options(loader_kwargs) + aggregated: list[dict] = [] + for entry in resolution.entries: + aggregated.extend( + _load_split(entry.spec, options=options, seed=seed, dataset_label=entry.label), + ) + + if not aggregated: + logger.warning("No coding datasets were loaded for entries %s", resolution.requested) + + return _attach_ids(aggregated) + + +def load_problems(dataset_names: List[str] | str | None, **loader_kwargs: dict) -> List[Dict]: + """Hydra entrypoint that mirrors the math domain loader style.""" + + seed = loader_kwargs.pop("seed", None) + return load_datasets(dataset_names, seed=seed, **loader_kwargs) \ No newline at end of file diff --git a/pipelinerl/domains/coding/rollouts.py b/pipelinerl/domains/coding/rollouts.py new file mode 100644 index 00000000..e30060ef --- /dev/null +++ b/pipelinerl/domains/coding/rollouts.py @@ -0,0 +1,218 @@ +"""Rollout generation for the coding domain using the sandbox verifier.""" + +from __future__ import annotations + +import json +import random +import time +from typing import Any, Literal + +import aiohttp +from omegaconf import DictConfig + +from pipelinerl.llm import Prompt, TrainableLLM +from pipelinerl.async_llm import llm_async_generate, make_training_text +from pipelinerl.rollouts import BaseMetrics, RolloutResult +from 
pipelinerl.utils import get_environment_jobs, resolve_environment_key +from pipelinerl.domains.math.rollouts import RewardTable, length_penalty + +from .verifier_api import verify_coding_solution_rpc + + +class CodingMetrics(BaseMetrics): + compile_error: bool = False + runtime_error: bool = False + timeout_error: bool = False + passed: int = 0 + total: int = 0 + + +def _format_task(problem: dict[str, Any]) -> str: + if "task" in problem and problem["task"]: + return str(problem["task"]) + if "question" in problem and problem["question"]: + return str(problem["question"]) + extra_info = problem.get("extra_info") or {} + if isinstance(extra_info, dict) and extra_info.get("question"): + return str(extra_info["question"]) + return str(problem) + + +def _determine_answer_status(verification: dict[str, Any]) -> Literal["correct", "wrong", "no_answer", "unparsable"]: + if verification.get("empty_response"): + return "no_answer" + if verification.get("compile_error") or verification.get("timeout_error"): + return "unparsable" + total = int(verification.get("total") or 0) + passed = int(verification.get("passed") or 0) + if total > 0 and passed == total: + return "correct" + return "wrong" + + +def _compute_reward( + cfg: DictConfig, + rewards: RewardTable, + *, + answer_status: Literal["correct", "wrong", "no_answer", "unparsable"], + finished: bool, + output_tokens: int, + max_tokens: int | None, +) -> tuple[float, float]: + match (answer_status, finished): + case ("wrong", False): + reward = rewards.wrong_answer_not_finished + case ("wrong", True): + reward = rewards.wrong_answer_finished + case ("no_answer", False): + reward = rewards.no_answer_not_finished + case ("no_answer", True): + reward = rewards.no_answer_finished + case ("unparsable", False): + reward = rewards.unparsable_not_finished + case ("unparsable", True): + reward = rewards.unparsable_finished + case ("correct", False): + reward = rewards.correct_answer_not_finished + case ("correct", True): + reward = rewards.correct_answer_finished + case _: + raise ValueError(f"Invalid answer_status/finished combination: {answer_status}/{finished}") + + reward *= cfg.actor.discount_factor ** output_tokens + overlong_penalty = 0.0 + if rewards.buffer_tokens and max_tokens is not None: + overlong_penalty = length_penalty(max_tokens, output_tokens, rewards.buffer_tokens) + reward += overlong_penalty + return reward, overlong_penalty + + +async def _run_verification( + cfg: DictConfig, + *, + session: aiohttp.ClientSession, + prediction: str | None, + reward_context: dict[str, Any] | str | None, + extra_info: dict[str, Any] | str | None, +) -> dict[str, Any]: + env_key = resolve_environment_key(cfg, default="coding") + env_jobs = get_environment_jobs(cfg, env_key) + if not env_jobs: + raise RuntimeError("No coding environment servers registered") + env_job = random.choice(env_jobs) + if env_job.hostname is None or env_job.port is None: + raise RuntimeError("Coding environment job is missing host/port information") + return await verify_coding_solution_rpc( + session, + host=env_job.hostname, + port=env_job.port, + prediction=prediction, + reward_context=reward_context, + extra_info=extra_info, + ) + + +def _coerce_dict(data: dict[str, Any] | str | None) -> dict[str, Any]: + if data is None: + return {} + if isinstance(data, dict): + return data + try: + return json.loads(data) + except Exception: + return {} + + +def _get_system_prompt(cfg: DictConfig) -> str: + """Get the system prompt, preferring domain-specific prompt if global is empty.""" 
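+    # Resolution order: explicit actor.system_prompt, then
+    # actor.domain_system_prompts["coding"] (see conf/multi_domain/base.yaml), else "".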
+ if cfg.actor.system_prompt: + return cfg.actor.system_prompt + # Fall back to domain-specific system prompt + domain_prompts = getattr(cfg.actor, "domain_system_prompts", None) + if domain_prompts: + return domain_prompts.get("coding", "") or "" + return "" + + +async def generate_coding_rollout( + cfg: DictConfig, + llm: TrainableLLM, + problem: dict[str, Any], + session: aiohttp.ClientSession, +) -> RolloutResult: + messages = [] + system_prompt = _get_system_prompt(cfg) + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + user_task = cfg.actor.task_template.format(task=_format_task(problem)) + messages.append({"role": "user", "content": user_task}) + prompt = Prompt(messages=messages) + + start_time = time.time() + llm_call = await llm_async_generate(llm, prompt, session) + latency = time.time() - start_time + assert llm_call.output.content is not None + trace = make_training_text(llm, llm_call) + + reward_context = _coerce_dict(problem.get("reward_context")) + extra_info = _coerce_dict(problem.get("extra_info")) + verification = await _run_verification( + cfg, + session=session, + prediction=llm_call.output.content, + reward_context=reward_context, + extra_info=extra_info, + ) + + rewards_cfg = RewardTable(**dict(cfg.rewards)) + answer_status = _determine_answer_status(verification) + try: + max_tokens = llm.parameters["max_tokens"] + except (KeyError, TypeError): + max_tokens = None + reward, _ = _compute_reward( + cfg, + rewards_cfg, + answer_status=answer_status, + finished=trace.finished, + output_tokens=llm_call.output_length_tokens, + max_tokens=max_tokens, + ) + trace.reward = reward + + coding_metadata = { + "passed": verification.get("passed"), + "total": verification.get("total"), + "compile_error": verification.get("compile_error"), + "runtime_error": verification.get("runtime_error"), + "timeout_error": verification.get("timeout_error"), + "empty_response": verification.get("empty_response"), + "call_type": verification.get("call_type"), + "fn_name": verification.get("fn_name"), + "error": verification.get("error"), + "tests": (verification.get("tests") or [])[:5], + } + trace.metadata.setdefault("coding", {}).update(coding_metadata) + + metrics = CodingMetrics( + reward=reward, + success=(verification.get("passed") == verification.get("total") and verification.get("total", 0) > 0), + no_error=not ( + verification.get("compile_error") + or verification.get("runtime_error") + or verification.get("timeout_error") + ), + no_answer=bool(verification.get("empty_response")), + compile_error=bool(verification.get("compile_error")), + runtime_error=bool(verification.get("runtime_error")), + timeout_error=bool(verification.get("timeout_error")), + passed=int(verification.get("passed") or 0), + total=int(verification.get("total") or 0), + ) + + return RolloutResult( + training_texts=[trace], + metrics=metrics, + latency=latency, + dataset_name=problem.get("dataset"), + ) diff --git a/pipelinerl/domains/coding/test_sandbox.py b/pipelinerl/domains/coding/test_sandbox.py new file mode 100644 index 00000000..2b7ef1d8 --- /dev/null +++ b/pipelinerl/domains/coding/test_sandbox.py @@ -0,0 +1,158 @@ +import requests +import json + +fn_name = "add_one" +generation = """ +def add_one(x): + return x + 1 + +if __name__ == '__main__': + # Get input from stdin + import sys + input_value = int(sys.stdin.read().strip()) + result = add_one(input_value) + print(result) +""" + +fn_name = "add_one" +generation = """ +def add_one(x): + return x + 1 + +assert add_one(1) 
== 2 +assert add_one(5) == 8 +assert add_one(-3) == -2 +""" + +wrapper_code = f""" +import traceback +from string import * +from re import * +from datetime import * +from collections import * +from heapq import * +from bisect import * +from copy import * +from math import * +from random import * +from statistics import * +from itertools import * +from functools import * +from operator import * +from io import * +from sys import * +from json import * +from builtins import * +from typing import * +import string +import re +import datetime +import collections +import heapq +import bisect +import copy +import math +import random +import statistics +import itertools +import functools +import operator +import io +import sys +import json + +# === User's Original Code START === +{generation} +# === User's Original Code END === + +_SANDBOX_FN_NAME = "{fn_name}" + +def _execute_user_function(): + # --- Input Parsing --- + _raw_input_str = sys.stdin.read() + _args = [] + if _raw_input_str.strip(): # If there's input + try: + _args = [json.loads(line) for line in _raw_input_str.split('\\n')] + except json.JSONDecodeError as _je: + sys.stderr.write(f"WrapperError: Invalid JSON input for '{{_SANDBOX_FN_NAME}}': {{_je}}\\nInput was: " + f"{{_raw_input_str[:200]}}\\n") + return None, True # result, error_occurred + + # --- Function Location and Execution --- + try: + _target_callable = None + # Try global scope first + if _SANDBOX_FN_NAME in globals(): + _target_callable = globals()[_SANDBOX_FN_NAME] + # Else, if 'Solution' class exists, try to get its method + elif 'Solution' in globals(): + _Solution_class = globals()['Solution'] + # Attempt to instantiate and get method. + # Errors (e.g., Solution not a class, instantiation fails, method missing) + # will be caught by the broad except block below. + _solution_instance = _Solution_class() + _target_callable = getattr(_solution_instance, _SANDBOX_FN_NAME) + + if not _target_callable: + sys.stderr.write(f"WrapperError: Function or method '{{_SANDBOX_FN_NAME}}' not found.\\n") + return None, True # result, error_occurred + + _fn_result = _target_callable(*_args) + return _fn_result, False # result, no_error + except Exception: # Catches errors from Solution instantiation, getattr, or function call + sys.stderr.write(f"Error during setup or execution of '{{_SANDBOX_FN_NAME}}':\\n{{traceback.format_exc()}}\\n") + return None, True # result, error_occurred + +if __name__ == '__main__': + _result, _error_occurred = _execute_user_function() + + if not _error_occurred: + # Serialize result to stdout + if isinstance(_result, (dict, list, tuple)) or _result is None or isinstance(_result, bool): + print(json.dumps(_result)) + elif isinstance(_result, (int, float, str)): + print(str(_result)) # Ensure string conversion for print + else: + # For other types, default to string representation. 
+ print(str(_result)) + # Optional: To explicitly exit with an error code if the sandbox relies on it + # else: + # sys.exit(1) +""" +current_generation_code = wrapper_code + +# current_generation_code = generation +compile_timeout = 10 +run_timeout = 10 +request_timeout = 10 +code = current_generation_code +stdin = "15" +memory_limit_mb=1024 +sandbox_fusion_url="http://dns-4943305a-c17f-44b1-b767-9536529eb8bc-sandbox:8080/run_code" +language="python" + + +payload = json.dumps( + { + "compile_timeout": compile_timeout, + "run_timeout": run_timeout, + "code": code, + "stdin": stdin, + "memory_limit_MB": memory_limit_mb, + "language": language, # Use the passed language parameter + "files": {}, + "fetch_files": [], + } +) + +headers = {"Content-Type": "application/json", "Accept": "application/json"} + +response = requests.post( + sandbox_fusion_url, + headers=headers, + data=payload, + timeout=request_timeout, # Use the calculated timeout +) + +print(response.json()) \ No newline at end of file diff --git a/pipelinerl/domains/coding/verifier_api.py b/pipelinerl/domains/coding/verifier_api.py new file mode 100644 index 00000000..490b0a2e --- /dev/null +++ b/pipelinerl/domains/coding/verifier_api.py @@ -0,0 +1,452 @@ +"""Sandbox-backed verification utilities for the coding domain.""" + +from __future__ import annotations + +import asyncio +import json +import logging +import math +import os +import re +import time +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import Any + +import requests +import aiohttp +import uvicorn +from fastapi import FastAPI +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + +_CODE_FENCE_RE = re.compile(r"```(?:python|py)?\s*([\s\S]*?)```", re.IGNORECASE) +_DEFAULT_SANDBOX_URL = os.environ.get("CODING_SANDBOX_URL", "http://sandbox:8080/run_code") + + +class CodingVerificationRequest(BaseModel): + """Payload accepted by the coding verifier environment.""" + + prediction: str | None = None + reward_context: dict[str, Any] | str | None = Field(default_factory=dict) + extra_info: dict[str, Any] | str | None = None + + +@dataclass +class CodingTestResult: + index: int + kind: str + status: str + input: str | None = None + expected: str | None = None + assertion: str | None = None + stdout: str = "" + stderr: str = "" + elapsed: float | None = None + + def to_dict(self) -> dict[str, Any]: + return { + "index": self.index, + "kind": self.kind, + "status": self.status, + "input": self.input, + "expected": self.expected, + "assertion": self.assertion, + "stdout": self.stdout, + "stderr": self.stderr, + "elapsed": self.elapsed, + } + + +@dataclass +class CodingVerificationSummary: + passed: int = 0 + total: int = 0 + compile_error: bool = False + runtime_error: bool = False + timeout_error: bool = False + empty_response: bool = False + error: str | None = None + call_type: str | None = None + fn_name: str | None = None + tests: list[CodingTestResult] = field(default_factory=list) + + def to_payload(self) -> dict[str, Any]: + return { + "passed": self.passed, + "total": self.total, + "compile_error": self.compile_error, + "runtime_error": self.runtime_error, + "timeout_error": self.timeout_error, + "empty_response": self.empty_response, + "error": self.error, + "call_type": self.call_type, + "fn_name": self.fn_name, + "tests": [test.to_dict() for test in self.tests], + } + + +def _extract_code(prediction: str | None) -> str: + if not prediction: + 
return "" + match = _CODE_FENCE_RE.findall(prediction) + if match: + return match[-1].strip() + return prediction.strip() + + +def _ensure_dict(data: dict[str, Any] | str | None) -> dict[str, Any]: + if data is None: + return {} + if isinstance(data, dict): + return data + try: + return json.loads(data) + except Exception as exc: # pragma: no cover - defensive + logger.warning("Failed to parse reward context: %s", exc) + return {} + + +def _normalize_output(text: str | None) -> str: + if text is None: + return "" + return text.strip() + + +def _compose_script(user_code: str, extra_snippet: str | None) -> str: + if not extra_snippet: + return user_code + return f"{user_code.rstrip()}\n\n{extra_snippet.strip()}\n" + + +def _convert_bytes_to_mb(value: int) -> int: + if value <= 0: + return 0 + return max(16, int(math.ceil(value / (1024 * 1024)))) + + +def _has_compile_error(response: dict[str, Any]) -> bool: + compile_result = response.get("compile_result") or {} + if not compile_result: + return False + stderr = str(compile_result.get("stderr", "")) + return_code = compile_result.get("return_code") + if return_code is not None and return_code != 0: + return True + return bool(stderr.strip()) + + +def _is_timeout(response: dict[str, Any]) -> bool: + run_result = response.get("run_result") or {} + status = (run_result.get("status") or response.get("status") or "").lower() + return status == "timeout" + + +def _has_runtime_error(response: dict[str, Any]) -> bool: + if _has_compile_error(response) or _is_timeout(response): + return False + run_result = response.get("run_result") or {} + return_code = run_result.get("return_code") + if return_code is None: + return False + if return_code != 0: + return True + stderr = str(run_result.get("stderr", "")) + return "Traceback" in stderr + + +def _get_run_field(response: dict[str, Any], field: str) -> str: + run_result = response.get("run_result") or {} + value = run_result.get(field) + if value is None: + return "" + return str(value) + + +def _get_status_text(response: dict[str, Any]) -> str: + run_result = response.get("run_result") or {} + status = run_result.get("status") + if status: + return str(status) + overall = response.get("status") + return str(overall or "") + + +def _post_to_sandbox( + *, + code: str, + stdin: str, + sandbox_url: str, + compile_timeout_s: float, + run_timeout_s: float, + request_timeout_s: float, + memory_limit_mb: int, + language: str, +) -> tuple[dict[str, Any], float]: + payload = { + "compile_timeout": compile_timeout_s, + "run_timeout": run_timeout_s, + "code": code, + "stdin": stdin, + "memory_limit_MB": memory_limit_mb, + "language": language, + "files": {}, + "fetch_files": [], + } + headers = {"Content-Type": "application/json", "Accept": "application/json"} + start = time.perf_counter() + try: + response = requests.post( + sandbox_url, + headers=headers, + data=json.dumps(payload), + timeout=request_timeout_s, + ) + response.raise_for_status() + data = response.json() + except requests.Timeout as exc: + data = { + "status": "Timeout", + "message": str(exc), + "compile_result": None, + "run_result": {"status": "Timeout", "stdout": "", "stderr": str(exc)}, + } + except requests.RequestException as exc: # pragma: no cover - network failures + data = { + "status": "NetworkError", + "message": str(exc), + "compile_result": None, + "run_result": {"status": "Error", "stdout": "", "stderr": str(exc)}, + } + elapsed = time.perf_counter() - start + return data, elapsed + + +def evaluate_coding_prediction( + prediction: 
str | None, + reward_context: dict[str, Any] | str | None, + *, + extra_info: dict[str, Any] | str | None = None, + sandbox_url: str | None = None, + compile_timeout_s: float = 5.0, + run_timeout_s: float = 5.0, + request_timeout_s: float = 15.0, + memory_limit_mb: int = 512, + language: str = "python", +) -> CodingVerificationSummary: + """Run generated code inside the sandbox and collect pass/fail statistics.""" + + sandbox_target = sandbox_url or _DEFAULT_SANDBOX_URL + context = _ensure_dict(reward_context) + summary = CodingVerificationSummary( + call_type=context.get("call_type"), + fn_name=context.get("fn_name"), + ) + + candidate_code = _extract_code(prediction) + if not candidate_code: + summary.empty_response = True + summary.error = "empty_prediction" + return summary + + call_type = (context.get("call_type") or "assert").lower().strip() + tests: list[dict[str, Any]] = [] + if call_type == "assert": + for idx, assertion in enumerate(context.get("assert_case", []) or []): + if not assertion: + continue + tests.append({"kind": "assert", "assertion": assertion, "index": idx}) + elif call_type == "std": + inputs = context.get("inputs") or [] + outputs = context.get("outputs") or [] + total = min(len(inputs), len(outputs)) + for idx in range(total): + tests.append( + { + "kind": "std", + "input": inputs[idx], + "expected": outputs[idx], + "index": idx, + } + ) + else: + summary.error = f"unsupported_call_type:{call_type}" + return summary + + if not tests: + summary.error = "no_tests" + return summary + + for raw_test in tests: + idx = summary.total + summary.total += 1 + if raw_test["kind"] == "assert": + script = _compose_script(candidate_code, raw_test["assertion"]) + stdin = "" + else: + script = candidate_code + stdin = raw_test.get("input", "") or "" + + response, elapsed = _post_to_sandbox( + code=script, + stdin=stdin, + sandbox_url=sandbox_target, + compile_timeout_s=compile_timeout_s, + run_timeout_s=run_timeout_s, + request_timeout_s=request_timeout_s, + memory_limit_mb=memory_limit_mb, + language=language, + ) + + if _has_compile_error(response): + status = "compile_error" + summary.compile_error = True + summary.error = "compile_error" + elif _is_timeout(response): + status = "timeout" + summary.timeout_error = True + summary.error = "timeout" + elif _has_runtime_error(response): + status = "runtime_error" + summary.runtime_error = True + summary.error = "runtime_error" + else: + run_status_text = _get_status_text(response).lower() + if raw_test["kind"] == "std": + produced = _normalize_output(_get_run_field(response, "stdout")) + expected = _normalize_output(raw_test.get("expected")) + status = "passed" if produced == expected else "failed" + else: + succeeded = run_status_text in ("finished", "success", "") + status = "passed" if succeeded else "failed" + if not succeeded: + summary.runtime_error = True + summary.error = summary.error or f"run_status:{run_status_text or 'unknown'}" + + stdout = _get_run_field(response, "stdout") + stderr = _get_run_field(response, "stderr") + test_result = CodingTestResult( + index=idx, + kind=raw_test["kind"], + status=status, + input=raw_test.get("input"), + expected=raw_test.get("expected"), + assertion=raw_test.get("assertion"), + stdout=stdout, + stderr=stderr, + elapsed=elapsed, + ) + summary.tests.append(test_result) + if status == "passed": + summary.passed += 1 + if status == "compile_error": + break + + return summary + + +def _rpc_failure_summary(reason: str, *, status: int | None = None, body: str | None = None) -> dict[str, 
Any]: + details = reason + if status is not None: + details = f"{reason}:{status}" + if body: + details = f"{details}:{body[:256]}" + return { + "passed": 0, + "total": 0, + "compile_error": False, + "runtime_error": True, + "timeout_error": False, + "empty_response": False, + "error": f"verifier_rpc_error:{details}", + "call_type": None, + "fn_name": None, + "tests": [], + } + + +async def verify_coding_solution_rpc( + session: aiohttp.ClientSession, + host: str, + port: int, + *, + prediction: str | None, + reward_context: dict[str, Any] | str | None, + extra_info: dict[str, Any] | str | None, +) -> dict[str, Any]: + """Call a remote coding verifier via HTTP RPC.""" + + payload = { + "prediction": prediction, + "reward_context": reward_context, + "extra_info": extra_info, + } + url = f"http://{host}:{port}/verify_solution" + try: + async with session.post(url, json=payload) as response: + if response.status != 200: + text = await response.text() + logger.warning( + "Coding verifier RPC failed with %s: %s", response.status, text[:256] + ) + return _rpc_failure_summary("http_status", status=response.status, body=text) + return await response.json() + except (aiohttp.ClientError, asyncio.TimeoutError) as exc: + logger.warning("Coding verifier RPC request error: %s", exc) + return _rpc_failure_summary("client_error", body=str(exc)) + + +class CodingSandboxEnvironment: + """Environment server that proxies requests to the shared sandbox executor.""" + + def __init__( + self, + *, + sandbox_url: str | None = None, + compile_timeout_s: float = 5.0, + run_timeout_s: float = 5.0, + request_timeout_s: float = 15.0, + memory_limit_bytes: int = 512 * 1024 * 1024, + language: str = "python", + max_workers: int = 4, + ) -> None: + self.sandbox_url = sandbox_url or _DEFAULT_SANDBOX_URL + self.compile_timeout_s = compile_timeout_s + self.run_timeout_s = run_timeout_s + self.request_timeout_s = request_timeout_s + self.memory_limit_mb = _convert_bytes_to_mb(memory_limit_bytes) + self.language = language + self.max_workers = max_workers + + def launch(self, port: int) -> None: + app = FastAPI() + executor = ThreadPoolExecutor(max_workers=self.max_workers) + + @app.post("/verify_solution") + async def verify_endpoint(request: CodingVerificationRequest): + loop = asyncio.get_running_loop() + + def _evaluate() -> dict[str, Any]: + summary = evaluate_coding_prediction( + prediction=request.prediction, + reward_context=request.reward_context, + extra_info=request.extra_info, + sandbox_url=self.sandbox_url, + compile_timeout_s=self.compile_timeout_s, + run_timeout_s=self.run_timeout_s, + request_timeout_s=self.request_timeout_s, + memory_limit_mb=self.memory_limit_mb, + language=self.language, + ) + return summary.to_payload() + + result = await loop.run_in_executor(executor, _evaluate) + return JSONResponse(content=result) + + @app.get("/health") + async def health() -> dict[str, str]: + return {"status": "ok"} + + uvicorn.run(app, host="0.0.0.0", port=port, timeout_keep_alive=60) \ No newline at end of file diff --git a/pipelinerl/domains/counting/counting.py b/pipelinerl/domains/counting/counting.py index c3afb45f..b05d1746 100644 --- a/pipelinerl/domains/counting/counting.py +++ b/pipelinerl/domains/counting/counting.py @@ -11,6 +11,9 @@ from pipelinerl.rollouts import BaseMetrics, RolloutResult +DOMAIN = "counting" + + async def generate_counting_rollout( cfg: DictConfig, llm: TrainableLLM, @@ -81,6 +84,7 @@ def load_problems(dataset_names: list[str]): if not isinstance(problem, dict) or "letter" not in 
problem or "word" not in problem or "count" not in problem: raise ValueError(f"Problem {problem} in dataset {name} is invalid.") problem["dataset"] = name + problem.setdefault("domain", DOMAIN) problems.append(problem) return problems diff --git a/pipelinerl/domains/dispatcher.py b/pipelinerl/domains/dispatcher.py new file mode 100644 index 00000000..c226247c --- /dev/null +++ b/pipelinerl/domains/dispatcher.py @@ -0,0 +1,103 @@ +import importlib +import inspect +from functools import lru_cache +from typing import Awaitable, Callable, Iterable, Mapping + +import aiohttp +from omegaconf import DictConfig + +from pipelinerl.llm import TrainableLLM +from pipelinerl.rollouts import RolloutResult +from pipelinerl.utils import resolve_environment_key + +RolloutCallable = Callable[[DictConfig, TrainableLLM, dict, aiohttp.ClientSession], Awaitable[RolloutResult] | RolloutResult] + +_RUNTIME_DOMAIN_ROLLOUTS: dict[str, str] = {} + + +def _iter_domain_overrides(raw_overrides: Mapping | Iterable | None): + if raw_overrides is None: + return + + if isinstance(raw_overrides, Mapping): + for domain, path in raw_overrides.items(): + yield str(domain), str(path) + return + + # Support Hydra list syntax: [{domain: math, rollout: path}, ...] or ["math:path"] + for item in raw_overrides: + if isinstance(item, Mapping): + if "domain" in item and "rollout" in item: + yield str(item["domain"]), str(item["rollout"]) + else: + for domain, path in item.items(): + yield str(domain), str(path) + else: + text = str(item) + if ":" in text: + domain, path = text.split(":", 1) + yield domain.strip(), path.strip() + + +def _build_domain_mapping(cfg: DictConfig) -> dict[str, str]: + overrides = getattr(cfg.actor, "domain_rollouts", None) + mapping: dict[str, str] = {} + if overrides: + for domain, path in _iter_domain_overrides(overrides): + mapping[domain] = path + + if _RUNTIME_DOMAIN_ROLLOUTS: + mapping.update(_RUNTIME_DOMAIN_ROLLOUTS) + + if not mapping: + raise ValueError("`actor.domain_rollouts` produced an empty mapping and no runtime registrations were found") + return mapping + + +@lru_cache(maxsize=None) +def _import_callable(path: str) -> RolloutCallable: + module_name, func_name = path.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, func_name) + + +def _get_rollout_callable(cfg: DictConfig, domain: str) -> RolloutCallable: + mapping = _build_domain_mapping(cfg) + if domain not in mapping: + raise ValueError(f"No rollout policy registered for domain '{domain}'") + return _import_callable(mapping[domain]) + + +async def generate_multidomain_rollout( + cfg: DictConfig, + llm: TrainableLLM, + problem: dict, + session: aiohttp.ClientSession, +) -> RolloutResult: + domain = problem.get("domain") + if not domain: + domain = resolve_environment_key(cfg) + if not domain: + raise ValueError("Problem is missing 'domain' and no default could be resolved from config") + + rollout_fn = _get_rollout_callable(cfg, domain) + result = rollout_fn(cfg, llm, problem, session) + if inspect.isawaitable(result): + result = await result # type: ignore[assignment] + + # Ensure domain is set on the result for completion tracking + if result.domain is None: + result.domain = domain + + return result # type: ignore[return-value] + + +def register_domain_rollout(domain: str, target: str) -> None: + domain_key = str(domain).strip() + target_path = str(target).strip() + if not domain_key: + raise ValueError("Domain key for registration cannot be empty") + if not target_path or "." 
not in target_path: + raise ValueError(f"Target '{target}' must be a fully-qualified callable path") + _RUNTIME_DOMAIN_ROLLOUTS[domain_key] = target_path + _import_callable.cache_clear() diff --git a/pipelinerl/domains/fn_calling/__init__.py b/pipelinerl/domains/fn_calling/__init__.py new file mode 100644 index 00000000..1a241972 --- /dev/null +++ b/pipelinerl/domains/fn_calling/__init__.py @@ -0,0 +1,18 @@ +"""Fn-calling domain toolkit.""" + +from .dataset import load_datasets, load_problems +from .rollouts import generate_fn_calling_rollout +from .verifier_api import ( + AgenticToolsEnvironment, + evaluate_fn_calling_answer, + verify_fn_calling_answer_rpc, +) + +__all__ = [ + "AgenticToolsEnvironment", + "evaluate_fn_calling_answer", + "generate_fn_calling_rollout", + "load_datasets", + "load_problems", + "verify_fn_calling_answer_rpc", +] diff --git a/pipelinerl/domains/fn_calling/dataset.py b/pipelinerl/domains/fn_calling/dataset.py new file mode 100644 index 00000000..02cd31d0 --- /dev/null +++ b/pipelinerl/domains/fn_calling/dataset.py @@ -0,0 +1,88 @@ +"""Dataset loader for the fn_calling domain.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Dict, List + +from pipelinerl.domains.coding import dataset as mixed_dataset + +logger = logging.getLogger(__name__) + +DOMAIN_NAME = "fn_calling" +ALIAS_NAMES = frozenset({DOMAIN_NAME, "fn-calling", "fncalling"}) +BASE_DATASET_NAME = "mixed-training-text-datasets" +DEFAULT_DATASET_CONFIG = mixed_dataset.DEFAULT_DATASET_CONFIG +ABILITY_FILTER = "agentic_fn_calling" + + +@dataclass(frozen=True) +class DatasetResolution: + requested: list[str] + resolved: list[str] + alias_map: dict[str, str] + + +def _resolve_dataset_requests(dataset_names: List[str] | str | None) -> DatasetResolution: + if dataset_names is None: + requested = [DOMAIN_NAME] + elif isinstance(dataset_names, str): + requested = [dataset_names] + else: + requested = [str(entry) for entry in dataset_names] + + resolved: list[str] = [] + alias_map: dict[str, str] = {} + for entry in requested: + spec = mixed_dataset.parse_dataset_name(entry) + base_name = BASE_DATASET_NAME if spec.name in ALIAS_NAMES else spec.name + resolved_entry = f"{base_name}@{spec.split}" + resolved.append(resolved_entry) + alias_map.setdefault(resolved_entry, f"{spec.name}@{spec.split}") + return DatasetResolution(requested=requested, resolved=resolved, alias_map=alias_map) + + +def load_datasets( + dataset_names: List[str] | str | None, + seed: int | None = None, + **loader_kwargs: Any, +) -> List[Dict]: + resolution = _resolve_dataset_requests(dataset_names) + defaults = { + "dataset_id": mixed_dataset.DEFAULT_DATASET_ID, + "dataset_config": DEFAULT_DATASET_CONFIG, + # fn_calling loads all call_types (no filtering), unlike coding which filters to assert/std + "allowed_call_types": (), + "ability_filter": ABILITY_FILTER, + } + options = {**defaults, **loader_kwargs} + + # Validate that allowed_call_types is empty for fn_calling domain + final_call_types = options.get("allowed_call_types") + if final_call_types: + logger.warning( + "fn_calling domain received non-empty allowed_call_types=%s. " + "This may filter out valid fn_calling samples. 
" + "Consider using allowed_call_types=[] for this domain.", + final_call_types, + ) + + samples = mixed_dataset.load_datasets(resolution.resolved, seed=seed, **options) + + for sample in samples: + dataset_label = str(sample.get("dataset", "")) + alias_label = resolution.alias_map.get(dataset_label) + if alias_label: + sample["dataset"] = alias_label + elif dataset_label.startswith(BASE_DATASET_NAME): + sample["dataset"] = dataset_label.replace(BASE_DATASET_NAME, DOMAIN_NAME, 1) + sample["domain"] = DOMAIN_NAME + if not samples: + logger.warning("fn_calling loader returned zero samples for entries: %s", resolution.requested) + return samples + + +def load_problems(dataset_names: List[str] | str | None, **loader_kwargs: Any) -> List[Dict]: + seed = loader_kwargs.pop("seed", None) + return load_datasets(dataset_names, seed=seed, **loader_kwargs) \ No newline at end of file diff --git a/pipelinerl/domains/fn_calling/rollouts.py b/pipelinerl/domains/fn_calling/rollouts.py new file mode 100644 index 00000000..81f42c16 --- /dev/null +++ b/pipelinerl/domains/fn_calling/rollouts.py @@ -0,0 +1,199 @@ +"""Rollout generation for the fn_calling domain.""" + +from __future__ import annotations + +import json +import random +import time +from typing import Any + +import aiohttp +from omegaconf import DictConfig + +from pipelinerl.llm import Prompt, TrainableLLM +from pipelinerl.async_llm import llm_async_generate, make_training_text +from pipelinerl.domains.math.rollouts import RewardTable, length_penalty +from pipelinerl.domains.fn_calling.verifier_api import verify_fn_calling_answer_rpc +from pipelinerl.rollouts import BaseMetrics, RolloutResult +from pipelinerl.utils import get_environment_jobs, resolve_environment_key + + +class FnCallingMetrics(BaseMetrics): + penalty: float + + +def _format_task(problem: dict[str, Any]) -> str: + if problem.get("task"): + return str(problem["task"]) + if problem.get("question"): + return str(problem["question"]) + return str(problem) + + +def _normalize_messages(raw_messages: Any) -> list[dict[str, str]]: + """Convert dataset prompts into chat messages compatible with vLLM.""" + + if not isinstance(raw_messages, list): + return [] + if len(raw_messages) == 1 and isinstance(raw_messages[0], list): + raw_messages = raw_messages[0] + + messages: list[dict[str, str]] = [] + pending_assistant: list[str] = [] + + def flush_pending() -> None: + nonlocal pending_assistant + if not pending_assistant: + return + content = "\n\n".join(text for text in pending_assistant if text.strip()) + if content: + messages.append({"role": "assistant", "content": content}) + pending_assistant = [] + + for entry in raw_messages: + if not isinstance(entry, dict): + continue + role = entry.get("role") + content = entry.get("content") + if not isinstance(content, str): + continue + + match role: + case "system": + flush_pending() + messages.append({"role": "system", "content": content}) + case "user": + flush_pending() + messages.append({"role": "user", "content": content}) + case "assistant" | "output_text": + pending_assistant.append(content) + case "thinking" | "plan" | "replan": + # Drop intermediate reasoning to keep prompts shorter. 
+ continue + case "tool": + flush_pending() + name = entry.get("name") + tool_id = entry.get("tool_call_id") or entry.get("id") + if not isinstance(name, str) or not name: + continue + messages.append( + { + "role": "tool", + "name": name, + "content": content, + **({"tool_call_id": tool_id} if isinstance(tool_id, str) else {}), + } + ) + case _: + # For function-call style entries nested under assistant. + if role == "function" and isinstance(entry.get("name"), str): + pending_assistant.append(content) + continue + + flush_pending() + return messages + + +def _coerce_dict(data: dict[str, Any] | str | None) -> dict[str, Any]: + if data is None: + return {} + if isinstance(data, dict): + return data + if isinstance(data, str) and data.strip(): + try: + return json.loads(data) + except json.JSONDecodeError: + return {} + return {} + + +async def generate_fn_calling_rollout( + cfg: DictConfig, + llm: TrainableLLM, + problem: dict[str, Any], + session: aiohttp.ClientSession, +) -> RolloutResult: + provided_prompt = problem.get("prompt") + messages = _normalize_messages(provided_prompt) + if messages: + if cfg.actor.system_prompt and messages[0]["role"] != "system": + messages.insert(0, {"role": "system", "content": cfg.actor.system_prompt}) + else: + messages = [] + if cfg.actor.system_prompt: + messages.append({"role": "system", "content": cfg.actor.system_prompt}) + messages.append({"role": "user", "content": cfg.actor.task_template.format(task=_format_task(problem))}) + prompt = Prompt(messages=messages) + + start_time = time.time() + llm_call = await llm_async_generate(llm, prompt, session) + latency = time.time() - start_time + assert llm_call.output.content is not None + + env_key = resolve_environment_key(cfg, default="fn_calling") + env_jobs = get_environment_jobs(cfg, env_key) + if not env_jobs: + raise RuntimeError("No environment servers available for fn_calling domain") + env_job = random.choice(env_jobs) + if env_job.hostname is None or env_job.port is None: + raise RuntimeError("fn_calling environment job is missing host/port information") + + reward_context = _coerce_dict(problem.get("reward_context")) + extra_info = _coerce_dict(problem.get("extra_info")) + answer_status = await verify_fn_calling_answer_rpc( + session=session, + host=env_job.hostname, + port=env_job.port, + generation=llm_call.output.content, + reward_context=reward_context, + extra_info=extra_info, + ) + + rewards = RewardTable(**dict(cfg.rewards)) + trace = make_training_text(llm, llm_call) + match (answer_status, trace.finished): + case ("wrong", False): + reward = rewards.wrong_answer_not_finished + case ("wrong", True): + reward = rewards.wrong_answer_finished + case ("no_answer", False): + reward = rewards.no_answer_not_finished + case ("no_answer", True): + reward = rewards.no_answer_finished + case ("unparsable", False): + reward = rewards.unparsable_not_finished + case ("unparsable", True): + reward = rewards.unparsable_finished + case ("correct", False): + reward = rewards.correct_answer_not_finished + case ("correct", True): + reward = rewards.correct_answer_finished + case _: + raise ValueError(f"Unexpected fn_calling answer status '{answer_status}'") + + reward *= cfg.actor.discount_factor ** llm_call.output_length_tokens + overlong_penalty = 0.0 + try: + max_tokens = llm.parameters["max_tokens"] + except (KeyError, TypeError): + max_tokens = None + if rewards.buffer_tokens > 0 and max_tokens is not None: + overlong_penalty = length_penalty(max_tokens, llm_call.output_length_tokens, 
rewards.buffer_tokens) +    reward += overlong_penalty +    trace.reward = reward +    trace.metadata.setdefault("fn_calling", {}).update({"answer_status": answer_status}) + +    metrics = FnCallingMetrics( +        reward=reward, +        success=answer_status == "correct", +        no_error=answer_status != "unparsable", +        no_answer=answer_status == "no_answer", +        penalty=overlong_penalty, +    ) + +    return RolloutResult( +        training_texts=[trace], +        metrics=metrics, +        latency=latency, +        dataset_name=problem.get("dataset"), +    ) diff --git a/pipelinerl/domains/fn_calling/verifier_api.py b/pipelinerl/domains/fn_calling/verifier_api.py new file mode 100644 index 00000000..35b5361b --- /dev/null +++ b/pipelinerl/domains/fn_calling/verifier_api.py @@ -0,0 +1,252 @@ +"""Verifier utilities for the fn_calling domain.""" + +from __future__ import annotations + +import asyncio +import importlib +import json +import logging +import os +import re +from collections import Counter +from concurrent.futures import ProcessPoolExecutor +from functools import lru_cache +from typing import Any, Callable, Iterable, Literal + +import aiohttp +import uvicorn +from fastapi import FastAPI, HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field + +LOGGER = logging.getLogger(__name__) + +AnswerStatus = Literal["correct", "wrong", "no_answer", "unparsable"] +_VALID_STATUSES: set[str] = {"correct", "wrong", "no_answer", "unparsable"} +_DEFAULT_REWARD_FN_ENV = "FN_CALLING_REWARD_FN" +_TOOL_BLOCK = re.compile(r"<tool_calls>(.*?)</tool_calls>", re.DOTALL | re.IGNORECASE) + + +def _json_loads(value: str) -> Any: +    return json.loads(value) + + +def _normalize_args(value: Any) -> Any: +    if value is None or value == "": +        return {} +    if isinstance(value, (dict, list, bool, int, float)): +        return value +    if isinstance(value, str): +        return _json_loads(value) +    raise ValueError(f"Unsupported argument payload: {type(value)}") + + +def _canonicalize(entries: Iterable[Any]) -> list[tuple[str, str]]: +    canon: list[tuple[str, str]] = [] +    for entry in entries: +        if not isinstance(entry, dict): +            raise ValueError("Tool call entry must be a mapping") +        fn_block = entry.get("function") if isinstance(entry.get("function"), dict) else None +        name = entry.get("name") or (fn_block or {}).get("name") +        if not isinstance(name, str) or not name.strip(): +            raise ValueError("Tool call missing name") +        args = entry.get("arguments") +        if args is None and fn_block: +            args = fn_block.get("arguments") +        canon_args = json.dumps(_normalize_args(args), sort_keys=True, ensure_ascii=False) +        canon.append((name.strip(), canon_args)) +    return canon + + +def _parse_prediction(generation: str | None) -> list[tuple[str, str]]: +    if not generation: +        return [] +    block = _TOOL_BLOCK.findall(generation) +    if not block: +        return [] +    payload = block[-1].strip() +    if not payload: +        return [] +    data = _json_loads(payload) +    if not isinstance(data, list): +        raise ValueError("tool_calls block must be a list") +    return _canonicalize(data) + + +def _parse_label(tool_calls: Iterable[Any] | None) -> list[tuple[str, str]]: +    if not tool_calls: +        return [] +    return _canonicalize(tool_calls) + + +def _has_irrelevant_flag(reward_context: dict[str, Any], extra_info: dict[str, Any]) -> bool: +    for source in (reward_context, extra_info): +        flag = source.get("irrelevant_tool_call") +        if isinstance(flag, bool): +            return flag +        if isinstance(flag, str) and flag.lower() in {"1", "true", "yes"}: +            return True +    return "irrelevant_tool_call" in json.dumps(reward_context, ensure_ascii=False) + + +def 
evaluate_fn_calling_answer( + *, + generation: str, + reward_context: dict[str, Any] | None, + extra_info: dict[str, Any] | None = None, +) -> AnswerStatus: + reward_context = reward_context or {} + extra_info = extra_info or {} + + try: + expected = _parse_label(reward_context.get("tool_calls")) + predicted = _parse_prediction(generation) + except ValueError as exc: + LOGGER.debug("Failed to parse fn_calling sample: %s", exc) + return "unparsable" + + if _has_irrelevant_flag(reward_context, extra_info): + return "wrong" if predicted else "correct" + + expected_counter = Counter(expected) + predicted_counter = Counter(predicted) + + if not predicted_counter and not expected_counter: + return "correct" + if not predicted_counter and expected_counter: + return "no_answer" + if predicted_counter and not expected_counter: + return "wrong" + return "correct" if predicted_counter == expected_counter else "wrong" + + +def _ensure_mapping(data: dict[str, Any] | str | None) -> dict[str, Any]: + if data is None: + return {} + if isinstance(data, dict): + return data + if isinstance(data, str) and data.strip(): + try: + return json.loads(data) + except json.JSONDecodeError: + LOGGER.debug("Failed to decode fn_calling payload: %s", data[:128]) + return {} + return {} + + +class FnCallingVerificationRequest(BaseModel): + """Payload accepted by the fn_calling verifier service.""" + + generation: str + reward_context: dict[str, Any] = Field(default_factory=dict) + extra_info: dict[str, Any] = Field(default_factory=dict) + + +@lru_cache(maxsize=None) +def _import_callable(path: str) -> Callable[..., Any]: + module_name, attr_name = path.rsplit(".", 1) + module = importlib.import_module(module_name) + fn = getattr(module, attr_name) + if not callable(fn): # pragma: no cover - defensive guard + raise TypeError(f"Object at '{path}' is not callable") + return fn + + +def _invoke_reward_fn( + reward_fn_path: str, + generation: str, + reward_context: dict[str, Any], + extra_info: dict[str, Any], +) -> AnswerStatus: + fn = _import_callable(reward_fn_path) + try: + result = fn(generation=generation, reward_context=reward_context, extra_info=extra_info) + except TypeError: + result = fn(generation, reward_context) + status = str(result).strip().lower() + if status not in _VALID_STATUSES: + raise ValueError(f"Reward function returned invalid status '{result}'") + return status # type: ignore[return-value] + + +def _execute_reward_job( + reward_fn_path: str | None, + generation: str, + reward_context: dict[str, Any], + extra_info: dict[str, Any], +) -> AnswerStatus: + if reward_fn_path: + return _invoke_reward_fn(reward_fn_path, generation, reward_context, extra_info) + return evaluate_fn_calling_answer( + generation=generation, + reward_context=reward_context, + extra_info=extra_info, + ) + + +async def verify_fn_calling_answer_rpc( + *, + session: aiohttp.ClientSession, + host: str, + port: int, + generation: str, + reward_context: dict[str, Any] | str | None, + extra_info: dict[str, Any] | str | None = None, +) -> AnswerStatus: + payload = { + "generation": generation, + "reward_context": _ensure_mapping(reward_context), + "extra_info": _ensure_mapping(extra_info), + } + async with session.post(f"http://{host}:{port}/verify_answer", json=payload) as response: + body = await response.text() + if response.status != 200: + LOGGER.error("fn_calling verifier returned %s: %s", response.status, body[:512]) + raise ValueError("fn_calling verifier request failed") + data = json.loads(body) + status = 
str(data.get("answer_status", "")).strip().lower() + if status not in _VALID_STATUSES: + raise ValueError(f"fn_calling verifier produced invalid status '{status}'") + return status # type: ignore[return-value] + + +class AgenticToolsEnvironment: + """FastAPI wrapper that exposes a deterministic fn_calling verifier.""" + + def __init__( + self, + *, + reward_fn_path: str | None = None, + max_workers: int = 4, + keepalive_timeout_s: int = 60, + ) -> None: + self._reward_fn_path = reward_fn_path or os.environ.get(_DEFAULT_REWARD_FN_ENV) + self._max_workers = max_workers + self._keepalive_timeout_s = keepalive_timeout_s + + def launch(self, port: int) -> None: + app = FastAPI() + + with ProcessPoolExecutor(max_workers=self._max_workers) as process_pool: + @app.post("/verify_answer") + async def verify(request: FnCallingVerificationRequest): + loop = asyncio.get_running_loop() + try: + answer_status = await loop.run_in_executor( + process_pool, + _execute_reward_job, + self._reward_fn_path, + request.generation, + dict(request.reward_context), + dict(request.extra_info), + ) + except Exception as exc: # pragma: no cover - server-side diagnostics + LOGGER.exception("fn_calling reward function failed") + raise HTTPException(status_code=500, detail=str(exc)) + return JSONResponse(content={"answer_status": answer_status}) + + @app.get("/health") + async def health(): + return {"status": "ok"} + + uvicorn.run(app, host="0.0.0.0", port=port, timeout_keep_alive=self._keepalive_timeout_s) \ No newline at end of file diff --git a/pipelinerl/domains/guessing/guessing.py b/pipelinerl/domains/guessing/guessing.py index a1ede13c..40e64c7e 100644 --- a/pipelinerl/domains/guessing/guessing.py +++ b/pipelinerl/domains/guessing/guessing.py @@ -9,6 +9,9 @@ from pipelinerl.rollouts import BaseMetrics, RolloutResult +DOMAIN = "guessing" + + async def generate_guessing_rollout( cfg: DictConfig, llm: TrainableLLM, @@ -91,10 +94,10 @@ def load_problems(dataset_names: list[str]): for name in dataset_names: if name == "train": problems.extend([ - {"answer": (2 * i * c) % n + 1, "dataset": "train"} for i in range(512) + {"answer": (2 * i * c) % n + 1, "dataset": "train", "domain": DOMAIN} for i in range(512) ]) elif name == "test": problems.extend([ - {"answer": ((2 * i + 1) * c) % n + 1, "dataset": "test"} for i in range(512) + {"answer": ((2 * i + 1) * c) % n + 1, "dataset": "test", "domain": DOMAIN} for i in range(512) ]) return problems diff --git a/pipelinerl/domains/math/load_datasets.py b/pipelinerl/domains/math/load_datasets.py index 4b44dfb6..fbfaea97 100644 --- a/pipelinerl/domains/math/load_datasets.py +++ b/pipelinerl/domains/math/load_datasets.py @@ -2,7 +2,8 @@ import logging import random import re -from typing import Dict, List, Tuple +from pathlib import Path +from typing import Dict, Iterable, List, Sequence, Tuple import datasets import hydra @@ -152,6 +153,26 @@ def load_math(split): return datasets.Dataset.from_list(data) +def _load_aime_2025_opencompass_dataset(upsample_factor: int = 0) -> list[dict]: + configs = ["AIME2025-I", "AIME2025-II"] + dataset_name = "aime_2025" + ("" if upsample_factor > 0 else "_original") + + samples: list[dict] = [] + for config_name in configs: + ds = load_dataset("opencompass/AIME2025", config_name, split="test") + samples.extend([s for s in process_math(ds, dataset_name) if s is not None]) + + original_size = len(samples) + if upsample_factor > 0: + samples *= upsample_factor + + logger.info( + f"Loading aime 2025 (OpenCompass) dataset: {len(samples)} samples" + + 
(f" (upsampled from {original_size})" if upsample_factor > 0 else "") + ) + return add_ids(samples) + + def _load_aime_dataset(year: int, upsample_factor: int = 0) -> list[dict]: aime_dataset = load_dataset("AI-MO/aimo-validation-aime", split="train", trust_remote_code=True) aime_dataset = aime_dataset.filter(lambda x: str(year) in x["url"]) @@ -194,18 +215,93 @@ def add_ids(dataset: list[dict]): return dataset +def _resolve_custom_path(relative_paths: str | Sequence[str]) -> Path: + """ + Resolve a path for locally generated datasets. + + Hydra jobs may change the working directory, so we check both the current + directory and the repository root. + """ + if isinstance(relative_paths, str): + relative_paths = [relative_paths] + + resolved = Path(__file__).resolve() + base_candidates = [Path.cwd()] + if len(resolved.parents) >= 5: + base_candidates.append(resolved.parents[4]) + + candidates: List[Path] = [] + for rel in relative_paths: + rel_path = Path(rel) + candidates.append(rel_path) + for base in base_candidates: + if base == Path.cwd(): + continue + candidates.append(base / rel_path) + + for candidate in candidates: + if candidate.exists(): + return candidate + raise FileNotFoundError( + f"Custom dataset not found. Tried: {[str(path) for path in candidates]}" + ) + + +def _load_custom_dataset(dataset_name: str) -> list[dict]: + """ + Load a locally generated dataset by name. + + The loader searches under `datasets/custom/` and `datasets/custom_runs/` for either + `` or `.jsonl`. + """ + candidate_names: List[str] = [] + if dataset_name.endswith(".jsonl"): + candidate_names.append(dataset_name) + else: + candidate_names.extend([dataset_name, f"{dataset_name}.jsonl"]) + + search_paths: List[str] = [] + for name in candidate_names: + search_paths.extend( + [ + f"datasets/custom/{name}", + f"datasets/custom_runs/{name}", + name, + ] + ) + + dataset_path = _resolve_custom_path(search_paths) + with dataset_path.open("r", encoding="utf-8") as handle: + samples = [json.loads(line) for line in handle if line.strip()] + + dataset_label = dataset_name[:-6] if dataset_name.endswith(".jsonl") else dataset_name + + for idx, sample in enumerate(samples): + sample.setdefault("source_dataset", sample.get("dataset", dataset_label)) + sample.setdefault("source_id", sample.get("id")) + sample["dataset"] = dataset_label + sample["id"] = idx + + logger.info(f"Loading custom dataset {dataset_name}: {len(samples)} samples from {dataset_path}") + return samples + + def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None) -> List[Tuple[str, Dict]]: if dataset_names is None: return [] if isinstance(dataset_names, str): dataset_names = [dataset_names] + # Preserve order while de-duplicating + dataset_names = list(dict.fromkeys(dataset_names)) datasets = [] + remaining = set(dataset_names) if "eurus_train" in dataset_names: dataset = load_dataset("PRIME-RL/Eurus-2-RL-Data", split="train", trust_remote_code=True) samples = [s for s in process_eurus(dataset) if s is not None] logger.info(f"Loading eurus train dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("eurus_train") # great for debugging since its much smaller than eurus train if "eurus_validation" in dataset_names: @@ -213,6 +309,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_eurus(dataset) if s is not None] logger.info(f"Loading eurus validation dataset: {len(samples)} samples") datasets += add_ids(samples) + 
remaining.discard("eurus_validation") if "math_train" in dataset_names: # math_dataset = load_math("train") @@ -220,6 +317,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_math(dataset, "math_train") if s is not None] logger.info(f"Loading math train dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("math_train") if "math_simplerl_train" in dataset_names: # SimpleRL MATH dataset @@ -234,6 +332,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_math(dataset, "math_simplerl_train") if s is not None] logger.info(f"Loading math simplerl train dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("math_simplerl_train") if "simplerl_math_subset_1000" in dataset_names: # SimpleRL MATH dataset subset @@ -252,12 +351,14 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = samples[:1000] logger.info(f"Loading math simplerl subset test dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("simplerl_math_subset_1000") if "deepscaler_preview" in dataset_names: dataset = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train", trust_remote_code=True) samples = [s for s in process_math(dataset, "deepscaler") if s is not None] logger.info(f"Loading deepscaler preview train dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("deepscaler_preview") if "math_test" in dataset_names: # math_dataset = load_math("test") @@ -265,36 +366,42 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_math(dataset, "math_test") if s is not None] logger.info(f"Loading math test dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("math_test") if "omni_math_500" in dataset_names: dataset = load_dataset("reliable-agents/Omni-MATH-500", split="test", trust_remote_code=True) samples = [s for s in process_math(dataset, "omni_math_500") if s is not None] logger.info(f"Loading omni math 500 dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("omni_math_500") if "math_500" in dataset_names: dataset = load_dataset("HuggingFaceH4/MATH-500", split="test", trust_remote_code=True) samples = [s for s in process_math(dataset, "math_500") if s is not None] logger.info(f"Loading math 500 dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("math_500") if "open_r1_math_220k" in dataset_names: dataset = load_dataset("open-r1/OpenR1-Math-220k", split="default", trust_remote_code=True) samples = [s for s in process_math(dataset, "open_r1_math_220k") if s is not None] logger.info(f"Loading open r1 math 220k dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("open_r1_math_220k") if "gpqa_main" in dataset_names: dataset = load_dataset("hendrydong/gpqa_main", split="test", trust_remote_code=True) samples = [s for s in process_gpqa(dataset, "gpqa_main") if s is not None] logger.info(f"Loading gpqa main dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("gpqa_main") if "gpqa_diamond" in dataset_names: dataset = load_dataset("hendrydong/gpqa_diamond", split="test", trust_remote_code=True) samples = [s for s in process_gpqa(dataset, "gpqa_diamond") if s is not None] logger.info(f"Loading gpqa diamond dataset: {len(samples)} samples") 
datasets += add_ids(samples) + remaining.discard("gpqa_diamond") if "gpqa_diamond" in dataset_names: pass @@ -304,49 +411,70 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_gsm8k(dataset, "gsm8k_train") if s is not None] logger.info(f"Loading gsm8k train dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("gsm8k_train") if "gsm8k_test" in dataset_names: dataset = load_dataset("openai/gsm8k", "main", split="test", trust_remote_code=True) samples = [s for s in process_gsm8k(dataset, "gsm8k_test") if s is not None] logger.info(f"Loading gsm8k test dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("gsm8k_test") if "limo" in dataset_names: dataset = load_dataset("GAIR/LIMO", split="train", trust_remote_code=True) samples = [s for s in process_limo(dataset) if s is not None] logger.info(f"Loading limo dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("limo") if "aime_2022" in dataset_names: datasets += _load_aime_dataset(2022, upsample_factor=16) + remaining.discard("aime_2022") if "aime_2022_original" in dataset_names: datasets += _load_aime_dataset(2022) + remaining.discard("aime_2022_original") if "aime_2023" in dataset_names: datasets += _load_aime_dataset(2023, upsample_factor=16) + remaining.discard("aime_2023") if "aime_2023_original" in dataset_names: datasets += _load_aime_dataset(2023) + remaining.discard("aime_2023_original") if "aime_2024" in dataset_names: datasets += _load_aime_dataset(2024, upsample_factor=16) + remaining.discard("aime_2024") if "aime_2024_original" in dataset_names: datasets += _load_aime_dataset(2024) + remaining.discard("aime_2024_original") + + if "aime_2025" in dataset_names: + datasets += _load_aime_2025_opencompass_dataset(upsample_factor=16) + remaining.discard("aime_2025") + + if "aime_2025_original" in dataset_names: + datasets += _load_aime_2025_opencompass_dataset() + remaining.discard("aime_2025_original") if "amc_2022" in dataset_names: # TODO: AMC 2022 is 43 problems, is that to be expected? 
datasets += _load_amc_dataset(2022, upsample_factor=16) + remaining.discard("amc_2022") if "amc_2022_original" in dataset_names: datasets += _load_amc_dataset(2022) + remaining.discard("amc_2022_original") if "amc_2023" in dataset_names: datasets += _load_amc_dataset(2023, upsample_factor=16) + remaining.discard("amc_2023") if "amc_2023_original" in dataset_names: datasets += _load_amc_dataset(2023) + remaining.discard("amc_2023_original") if "sometimes_success_data" in dataset_names: PATH = "data/sometimes_success_data/data.jsonl" @@ -354,6 +482,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [json.loads(line) for line in f] logger.info(f"Loading easy data dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("sometimes_success_data") if "open_reasoner_zero_57k" in dataset_names: dataset = load_dataset( @@ -365,6 +494,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_open_reasoner(dataset, "open_reasoner_zero_57k") if s is not None] logger.info(f"Loading Open Reasoner Zero dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("open_reasoner_zero_57k") if "open_reasoner_zero_extended_72k" in dataset_names: dataset = load_dataset( @@ -376,6 +506,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_open_reasoner(dataset, "open_reasoner_zero_extended_72k") if s is not None] logger.info(f"Loading Open Reasoner Zero extended dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("open_reasoner_zero_extended_72k") if "open_reasoner_zero_hard_13k" in dataset_names: dataset = load_dataset( @@ -387,6 +518,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_open_reasoner(dataset, "open_reasoner_zero_hard_13k") if s is not None] logger.info(f"Loading Open Reasoner Zero hard dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("open_reasoner_zero_hard_13k") for dataset_name in dataset_names: test_matched = re.match(r"multiplication_(\d+)_by_(\d+)_(\d+)_test", dataset_name) @@ -407,6 +539,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None ] logger.info(f"Loading multiplication {num_digits_1}_by_{num_digits_2} dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard(dataset_name) elif train_matched: upto_prefix = train_matched.group(1) or "" num_digits_1 = int(train_matched.group(2)) @@ -428,6 +561,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None f"Loading multiplication {upto_prefix}_{num_digits_1}_by_{num_digits_2} dataset: {len(samples)} samples" ) datasets += add_ids(samples) + remaining.discard(dataset_name) if "countdown" in dataset_names: dataset = load_dataset( @@ -436,6 +570,19 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_countdown(dataset) if s is not None] logger.info(f"Loading countdown dataset: {len(samples)} samples") datasets += samples + remaining.discard("countdown") + + # resolve any remaining names as local custom datasets. 
+ unresolved: List[str] = [] + for dataset_name in list(remaining): + try: + datasets += _load_custom_dataset(dataset_name) + remaining.discard(dataset_name) + except FileNotFoundError: + unresolved.append(dataset_name) + + if unresolved: + raise ValueError(f"Unknown dataset(s): {unresolved}") if len(datasets) == 0: raise ValueError("No datasets loaded") diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py index 62758c7e..4e27f23c 100644 --- a/pipelinerl/domains/math/rollouts.py +++ b/pipelinerl/domains/math/rollouts.py @@ -4,6 +4,8 @@ import aiohttp from omegaconf import DictConfig from pydantic import BaseModel +from pipelinerl.rollouts import RolloutResult, BaseMetrics +from pipelinerl.utils import get_environment_jobs, resolve_environment_key from pipelinerl.async_llm import llm_async_generate, make_training_text from pipelinerl.llm import Prompt, TrainableLLM @@ -55,9 +57,10 @@ async def generate_math_rollout( rewards = RewardTable(**dict(cfg.rewards)) discount_factor = cfg.actor.discount_factor - # math_verify is a fast environment, no support for environment replicas for now - env_jobs = [Job(**job) for job in cfg.jobs if job["kind"] == "environment"] - # choose the job randomly + env_key = resolve_environment_key(cfg, default="math") + env_jobs = get_environment_jobs(cfg, env_key) + if not env_jobs: + raise RuntimeError("No environment servers available for math domain") env_job = random.choice(env_jobs) assert env_job.port is not None answer_status = await verify_answer_rpc( diff --git a/pipelinerl/domains/multidomain/__init__.py b/pipelinerl/domains/multidomain/__init__.py new file mode 100644 index 00000000..10392710 --- /dev/null +++ b/pipelinerl/domains/multidomain/__init__.py @@ -0,0 +1,54 @@ +"""Debug utilities for multi-domain experimentation.""" + +from collections.abc import Iterable +from typing import Any + +DOMAIN = "multi" + + +_MATH_SAMPLE = { + "id": 0, + "dataset": "math_debug", + "task": "Compute 2 + 2.", + "answer": "\\boxed{4}", + "domain": "math", +} + +_GUESSING_SAMPLE = { + "id": 0, + "dataset": "guessing_debug", + "task": "Hidden number between 1 and 3. 
Start guessing.", +    "answer": 2, +    "domain": "guessing", +} + +_CODING_SAMPLE = { +    "id": 0, +    "dataset": "coding_debug", +    "question": "Implement class Solution with method addTwoNumbers(a: int, b: int) that returns a + b.", +    "starter_code": "class Solution:\n    def addTwoNumbers(self, a: int, b: int) -> int:\n        pass", +    "tests": [ +        {"id": 0, "type": "functional", "input": "1\n2", "output": "3"}, +        {"id": 1, "type": "functional", "input": "5\n7", "output": "12"}, +    ], +    "entry_point": "addTwoNumbers", +    "domain": "coding", +} + + +def load_problems(dataset_names: Iterable[str] | None = None, **_: Any) -> list[dict]: +    """Return tiny synthetic problems for smoke-testing multi-domain dispatch.""" +    if dataset_names is None: +        dataset_names = ["math_debug", "guessing_debug"] + +    problems: list[dict] = [] +    for name in dataset_names: +        if name == "math_debug": +            problems.append(dict(_MATH_SAMPLE)) +        elif name == "guessing_debug": +            problems.append(dict(_GUESSING_SAMPLE)) +        elif name == "coding_debug": +            problems.append(dict(_CODING_SAMPLE)) +        else: +            raise ValueError(f"Unknown debug dataset '{name}'") +    return problems diff --git a/pipelinerl/domains/multidomain/loader.py b/pipelinerl/domains/multidomain/loader.py new file mode 100644 index 00000000..84c60ecc --- /dev/null +++ b/pipelinerl/domains/multidomain/loader.py @@ -0,0 +1,100 @@ +from collections import defaultdict +from typing import Dict, Iterable, List, Sequence + +from pipelinerl.domains.math.load_datasets import load_datasets as load_math_datasets +from pipelinerl.domains.guessing.guessing import load_problems as load_guessing_problems +from pipelinerl.domains.counting.counting import load_problems as load_counting_problems +from pipelinerl.domains.chartqa.load_datasets import load_problems as load_chartqa_problems +from pipelinerl.domains.coding.dataset import load_problems as load_coding_problems +from pipelinerl.domains.fn_calling.dataset import load_problems as load_fn_calling_problems +from pipelinerl.domains.miniwob.load_tasks import load_tasks as load_miniwob_tasks + + +def _load_math(dataset_names: Sequence[str], *, seed=None, **_: dict) -> List[Dict]: +    return load_math_datasets(list(dataset_names), seed=seed) + + +def _load_guessing(dataset_names: Sequence[str], **_: dict) -> List[Dict]: +    return load_guessing_problems(list(dataset_names)) + + +def _load_coding(dataset_names: Sequence[str], **loader_kwargs: dict) -> List[Dict]: +    return load_coding_problems(list(dataset_names), **loader_kwargs) + + +def _load_fn_calling(dataset_names: Sequence[str], **loader_kwargs: dict) -> List[Dict]: +    return load_fn_calling_problems(list(dataset_names), **loader_kwargs) + + +def _load_counting(dataset_names: Sequence[str], **_: dict) -> List[Dict]: +    return load_counting_problems(list(dataset_names)) + + +def _load_chartqa(dataset_names: Sequence[str], **_: dict) -> List[Dict]: +    return load_chartqa_problems(list(dataset_names)) + + +def _load_miniwob(dataset_names: Sequence[str], **loader_kwargs: dict) -> List[Dict]: +    return load_miniwob_tasks(list(dataset_names), **loader_kwargs) + + +DOMAIN_LOADERS = { +    "math": _load_math, +    "guessing": _load_guessing, +    "coding": _load_coding, +    "counting": _load_counting, +    "chartqa": _load_chartqa, +    "miniwob": _load_miniwob, +    "fn_calling": _load_fn_calling, +} + + +def _parse_entry(entry: str) -> tuple[str, str]: +    if "::" not in entry: +        raise ValueError( +            f"Dataset entry '{entry}' is missing a domain prefix. " +            "Expected format '<domain>::<dataset>'." 
+ ) + domain, dataset = entry.split("::", 1) + return domain.strip(), dataset.strip() + + +def load_datasets( + dataset_names: Iterable[str] | str | None, + *, + seed: int | None = None, + per_domain_params: dict[str, dict] | None = None, + **kwargs, +) -> List[Dict]: + if dataset_names is None: + return [] + if isinstance(dataset_names, str): + dataset_names = [dataset_names] + + grouped: dict[str, list[str]] = defaultdict(list) + for entry in dataset_names: + domain, name = _parse_entry(str(entry)) + grouped[domain].append(name) + + counters: dict[tuple[str, str], int] = defaultdict(int) + problems: List[Dict] = [] + per_domain_params = dict(per_domain_params or {}) + for domain, names in grouped.items(): + loader = DOMAIN_LOADERS.get(domain) + if loader is None: + raise ValueError(f"No loader registered for domain '{domain}'") + domain_kwargs = dict(kwargs) + if domain in per_domain_params: + domain_kwargs.update(dict(per_domain_params[domain] or {})) + loaded = loader(names, seed=seed, **domain_kwargs) + for sample in loaded: + dataset_name = str(sample.get("dataset", names[0] if names else domain)) + sample.setdefault("domain", domain) + if "id" not in sample: + key = (sample["domain"], dataset_name) + sample["id"] = counters[key] + counters[key] += 1 + if "dataset" not in sample: + sample["dataset"] = dataset_name + problems.append(sample) + return problems diff --git a/pipelinerl/entrypoints/run_environment.py b/pipelinerl/entrypoints/run_environment.py index d5e05d2a..1951d4a7 100644 --- a/pipelinerl/entrypoints/run_environment.py +++ b/pipelinerl/entrypoints/run_environment.py @@ -1,14 +1,21 @@ import hydra from omegaconf import DictConfig -from pipelinerl.utils import better_crashing +from pipelinerl.utils import better_crashing, select_environment_config @hydra.main(config_path="../../conf", config_name="base", version_base="1.3.2") def hydra_entrypoint(cfg: DictConfig): with better_crashing("environment"): - environment = hydra.utils.instantiate(cfg.environment) this_job, = [job for job in cfg.jobs if job["idx"] == cfg.me.job_idx] + environment_cfg = select_environment_config( + cfg, + key=this_job.get("environment_key"), + index=this_job.get("environment_index"), + ) + if environment_cfg is None: + raise ValueError("No environment configuration found for job") + environment = hydra.utils.instantiate(environment_cfg) port = this_job["port"] environment.launch(port=port) diff --git a/pipelinerl/rollouts.py b/pipelinerl/rollouts.py index dcb27f2d..1200ba23 100644 --- a/pipelinerl/rollouts.py +++ b/pipelinerl/rollouts.py @@ -64,3 +64,4 @@ class RolloutResult(BaseModel): model_version: int | None = None dataset_name: str | None = None group_id: str | None = None + domain: str | None = None diff --git a/pipelinerl/utils.py b/pipelinerl/utils.py index fbcd9926..96c0ceab 100644 --- a/pipelinerl/utils.py +++ b/pipelinerl/utils.py @@ -7,11 +7,12 @@ from pathlib import Path import traceback from typing import Dict, Mapping, List, Any + import numpy as np -from omegaconf import DictConfig import psutil import requests from importlib.metadata import distributions +from omegaconf import DictConfig, ListConfig, OmegaConf from transformers import PreTrainedTokenizer from pipelinerl.world import Job @@ -22,6 +23,177 @@ logger = logging.getLogger(__name__) +_ENV_METADATA_KEYS = {"key", "mode", "replicas_per_actor"} + + +def _strip_environment_metadata(env_cfg: DictConfig | dict | None): + if env_cfg is None: + return None + if isinstance(env_cfg, DictConfig): + data = 
OmegaConf.to_container(env_cfg, resolve=True) + elif isinstance(env_cfg, dict): + data = dict(env_cfg) + else: + return env_cfg + for meta_key in _ENV_METADATA_KEYS: + data.pop(meta_key, None) + return OmegaConf.create(data) + + +def _env_cfg_type(env_cfg, field: str): + if isinstance(env_cfg, DictConfig): + return env_cfg.get(field, None) + if isinstance(env_cfg, dict): + return env_cfg.get(field) + return None + + +def select_environment_config(cfg: DictConfig, *, key: str | None = None, index: int | None = None): + env_cfgs = getattr(cfg, "environments", None) + if env_cfgs: + if isinstance(env_cfgs, (ListConfig, list)): + if key is not None: + for env_cfg in env_cfgs: + env_key = _env_cfg_type(env_cfg, "key") or _env_cfg_type(env_cfg, "name") + if env_key is not None and str(env_key) == str(key): + return _strip_environment_metadata(env_cfg) + if index is not None and 0 <= index < len(env_cfgs): + return _strip_environment_metadata(env_cfgs[index]) + elif isinstance(env_cfgs, (DictConfig, dict)): + if key is not None and key in env_cfgs: + return _strip_environment_metadata(env_cfgs[key]) + if index is not None: + for idx, env_key in enumerate(env_cfgs): + if idx == index: + return _strip_environment_metadata(env_cfgs[env_key]) + + return getattr(cfg, "environment", None) + + +def _domain_mix_weights(cfg: DictConfig) -> dict[str, float]: + domain_mix_cfg = getattr(getattr(cfg, "actor", None), "domain_mix", None) + if not domain_mix_cfg: + return {} + try: + mix_weights = OmegaConf.to_container(domain_mix_cfg, resolve=True) + except Exception: + return {} + if not isinstance(mix_weights, dict): + return {} + weights: dict[str, float] = {} + for key, value in mix_weights.items(): + try: + weight = float(value) + except (TypeError, ValueError): + continue + if weight > 0: + weights[str(key)] = weight + return weights + + +def _apply_domain_mix_replicas(cfg: DictConfig, specs: list[dict[str, Any]], default_replicas: Any) -> None: + weights = _domain_mix_weights(cfg) + if not weights: + return + try: + default_value = float(default_replicas) + except (TypeError, ValueError): + return + if default_value <= 0: + return + weighted_specs = [spec for spec in specs if spec.get("mode") == "remote" and spec.get("key") in weights] + if not weighted_specs: + return + total_weight = sum(weights[spec["key"]] for spec in weighted_specs) + if total_weight <= 0: + return + average_weight = total_weight / len(weighted_specs) + if average_weight <= 0: + return + for spec in weighted_specs: + key = spec["key"] + scaled = default_value * (weights[key] / average_weight) + replicas = max(1, int(round(scaled))) + current = spec.get("replicas_per_actor") + if current is None or current == default_replicas: + spec["replicas_per_actor"] = replicas + + +def collect_environment_specs(cfg: DictConfig) -> list[dict[str, Any]]: + specs: list[dict[str, Any]] = [] + env_cfgs = getattr(cfg, "environments", None) + default_mode = str(getattr(cfg.world, "environment_mode", "remote")) + default_replicas = getattr(cfg.world, "env_replicas_per_actor", 1) + + if isinstance(env_cfgs, (ListConfig, list)): + iterable = list(env_cfgs) + for idx, env_cfg in enumerate(iterable): + if env_cfg is None: + continue + key = _env_cfg_type(env_cfg, "key") or _env_cfg_type(env_cfg, "name") + mode = _env_cfg_type(env_cfg, "mode") + replicas = _env_cfg_type(env_cfg, "replicas_per_actor") + specs.append( + { + "key": str(key) if key is not None else f"environment_{idx}", + "mode": str(mode) if mode is not None else default_mode, + 
"replicas_per_actor": replicas, + "index": idx, + } + ) + elif isinstance(env_cfgs, (DictConfig, dict)): + items = env_cfgs.items() + for idx, (key, env_cfg) in enumerate(items): + if env_cfg is None: + continue + mode = _env_cfg_type(env_cfg, "mode") + replicas = _env_cfg_type(env_cfg, "replicas_per_actor") + specs.append( + { + "key": str(key), + "mode": str(mode) if mode is not None else default_mode, + "replicas_per_actor": replicas, + "index": idx, + } + ) + else: + single_env = getattr(cfg, "environment", None) + if single_env: + key = _env_cfg_type(single_env, "key") or _env_cfg_type(single_env, "name") + specs.append( + { + "key": str(key) if key is not None else "default", + "mode": default_mode, + "replicas_per_actor": default_replicas, + "index": 0, + } + ) + + active_domains = _domain_mix_weights(cfg) + if active_domains: + specs = [spec for spec in specs if spec.get("key") in active_domains] + + _apply_domain_mix_replicas(cfg, specs, default_replicas) + return specs + + +def resolve_environment_key(cfg: DictConfig, default: str | None = None) -> str | None: + explicit = cfg.get("environment_key", None) if hasattr(cfg, "get") else None + if explicit: + return str(explicit) + specs = collect_environment_specs(cfg) + if len(specs) == 1: + return specs[0]["key"] + return default + + +def get_environment_jobs(cfg: DictConfig, key: str | None = None) -> list[Job]: + jobs_cfg = getattr(cfg, "jobs", []) + env_jobs = [Job(**job) for job in jobs_cfg if job["kind"] == "environment"] + if key is None: + return env_jobs + filtered = [job for job in env_jobs if getattr(job, "environment_key", None) == key] + return filtered or env_jobs def init_wandb( cfg: DictConfig, @@ -237,6 +409,9 @@ def calculate_stats(stats: List | Dict[Any, Any]) -> Dict[str, float]: if not isinstance(stats, list): raise TypeError(f"Expected stats to be a list, got {type(stats)}") + if len(stats) == 0: + return {} + aggregated_stats = { "max": float(max(stats)), "min": float(min(stats)), @@ -291,19 +466,27 @@ def wait_for_inference_servers(urls: list[str]): def wait_for_environments(cfg: DictConfig): - """ - Wait for the verifier to be ready. 
-    """
-    env_jobs = [Job(**job) for job in cfg.jobs if job.kind == "environment"]
+    """Wait for remote environment servers to report healthy."""
+    specs = collect_environment_specs(cfg)
+    if not any(spec.get("mode") == "remote" for spec in specs):
+        return
+
+    env_jobs = get_environment_jobs(cfg)
+    if not env_jobs:
+        return
     for job in env_jobs:
         while True:
             url = f"http://{job.hostname}:{job.port}/health"
-            # use requests
             try:
                 response = requests.get(url)
                 if response.status_code == 200:
+                    logger.info(
+                        "Environment %s ready at %s",
+                        job.environment_key if job.environment_key is not None else job.replica_idx,
+                        url,
+                    )
                     break
-            except:
+            except requests.exceptions.RequestException:
                 logger.info(f"Waiting for environment at {url} to be ready...")
             time.sleep(5.0)

diff --git a/pipelinerl/world.py b/pipelinerl/world.py
index f41714e4..aa5ea420 100644
--- a/pipelinerl/world.py
+++ b/pipelinerl/world.py
@@ -26,6 +26,10 @@ class Job(BaseModel):
     gpus: list[int] = []
     # The URL of the job
     url: str = ""
+    # Domain identifier for environment jobs
+    environment_key: str | None = None
+    # Index of the environment in the environments list
+    environment_index: int | None = None


 class WorldMap:
@@ -71,8 +75,11 @@ def __init__(self, cfg: DictConfig, verbose: bool = False):
         if place_inference_jobs:
             self._place_inference_jobs(cfg)
         self._place_pipeline_stages(cfg)
-        if cfg.environment:
-            self._place_environments(cfg)
+        # Lazy import to avoid circular dependency with utils.py
+        from pipelinerl.utils import collect_environment_specs
+        self.environment_specs = collect_environment_specs(cfg)
+        if any(spec["mode"] == "remote" for spec in self.environment_specs):
+            self._place_environments(cfg, self.environment_specs)

         # Place the finetune workers on the remaining gpus, take all remaining GPUs
         current_finetune_rank = 0
@@ -108,7 +115,7 @@ def __init__(self, cfg: DictConfig, verbose: bool = False):
             for job in jobs:
                 self._log_info(f"  {job.kind} {job.replica_idx} on gpus {job.gpus}, local idx {job.local_idx}")

-    def add_job(self, node_rank: int, kind: str, replica_idx: int, local_idx: int = 0, port: int | None = None, gpus: list[int] | None = None, cpu_heavy: bool = False, url: str = "") -> Job:
+    def add_job(self, node_rank: int, kind: str, replica_idx: int, local_idx: int = 0, port: int | None = None, gpus: list[int] | None = None, cpu_heavy: bool = False, url: str = "", environment_key: str | None = None, environment_index: int | None = None) -> Job:
         """Add a job to the world map."""
         if gpus is None:
             gpus = []
@@ -121,7 +128,9 @@ def add_job(self, node_rank: int, kind: str, replica_idx: int, local_idx: int =
             hostname=self.address_map[node_rank],
             port=port,
             gpus=gpus,
-            url=url
+            url=url,
+            environment_key=environment_key,
+            environment_index=environment_index,
         )
         self.job_map[node_rank].append(job)
         self.total_jobs += 1
@@ -187,18 +196,35 @@ def _place_pipeline_stages(self, cfg):
                 self.add_job(kind="actor", replica_idx=worker_idx, node_rank=node, gpus=[], cpu_heavy=True)
                 self.add_job(kind="preprocessor", replica_idx=worker_idx, node_rank=node, gpus=[], cpu_heavy=True)

-    def _place_environments(self, cfg):
-        for worker_idx in range(cfg.world.env_replicas):
-            node = self.get_least_busy_node()
-            envs_at_node = len([job for job in self.job_map[node] if job.kind == "environment"])
-            self.add_job(
-                kind="environment",
-                replica_idx=worker_idx,
-                node_rank=node,
-                port=cfg.world.environment_start_port + envs_at_node,
-                gpus=[],
-                cpu_heavy=True,
-            )
+    def _place_environments(self, cfg: DictConfig, environment_specs: list[dict]):
+        # Scale the number of environment servers with the number of LLM servers per actor
+        base_start_port = cfg.world.environment_start_port
+        llms_per_actor = getattr(self, "llms_per_actor", 1) or 1
+        global_replica_idx = 0
+        for spec_idx, spec in enumerate(environment_specs):
+            if spec["mode"] != "remote":
+                continue
+            replicas_per_actor = spec.get("replicas_per_actor")
+            if replicas_per_actor is None:
+                replicas_per_actor = getattr(cfg.world, "env_replicas_per_actor", None)
+            if replicas_per_actor is not None:
+                total_env_replicas = cfg.world.replicas * llms_per_actor * replicas_per_actor
+            else:
+                total_env_replicas = getattr(cfg.world, "env_replicas", cfg.world.replicas * llms_per_actor)
+            for replica_offset in range(total_env_replicas):
+                node = self.get_least_busy_node()
+                envs_at_node = len([job for job in self.job_map[node] if job.kind == "environment"])
+                self.add_job(
+                    kind="environment",
+                    replica_idx=global_replica_idx,
+                    node_rank=node,
+                    port=base_start_port + envs_at_node,
+                    gpus=[],
+                    cpu_heavy=True,
+                    environment_key=spec["key"],
+                    environment_index=spec.get("index", spec_idx),
+                )
+                global_replica_idx += 1

     def _place_inference_jobs(self, cfg):
         for _ in range(cfg.world.replicas):