diff --git a/.github/workflows/test-arealite.yml b/.github/workflows/test-arealite.yml new file mode 100644 index 000000000..41ae60b8f --- /dev/null +++ b/.github/workflows/test-arealite.yml @@ -0,0 +1,50 @@ +name: Test AReaLite + +on: + push: + paths: + - .github/workflows/test-arealite.yml + - arealite/** + - ci/** + workflow_dispatch: + +jobs: + test-arealite: + runs-on: ubuntu-latest + concurrency: + group: test-arealite + steps: + - uses: actions/checkout@v4 + + - uses: appleboy/ssh-action@v1 + env: + GIT_REPO_URL: https://github.bibk.top/${{ github.repository }} + GIT_COMMIT_SHA: ${{ github.sha }} + with: + host: ${{ secrets.CI_NODE_ADDR }} + username: ${{ secrets.CI_NODE_USER }} + key: ${{ secrets.REMOTE_SSH_KEY }} + envs: GIT_REPO_URL,GIT_COMMIT_SHA + script_path: ci/clone_repo.sh + + - uses: appleboy/ssh-action@v1 + env: + GIT_COMMIT_SHA: ${{ github.sha }} + with: + host: ${{ secrets.CI_NODE_ADDR }} + username: ${{ secrets.CI_NODE_USER }} + key: ${{ secrets.REMOTE_SSH_KEY }} + command_timeout: 2h + envs: GIT_COMMIT_SHA + script_path: ci/build_env_image.sh + + - uses: appleboy/ssh-action@v1 + env: + GIT_COMMIT_SHA: ${{ github.sha }} + with: + host: ${{ secrets.CI_NODE_ADDR }} + username: ${{ secrets.CI_NODE_USER }} + key: ${{ secrets.REMOTE_SSH_KEY }} + command_timeout: 1h + envs: GIT_COMMIT_SHA + script_path: ci/test_arealite.sh diff --git a/arealite/README.md b/arealite/README.md new file mode 100644 index 000000000..caa7640ec --- /dev/null +++ b/arealite/README.md @@ -0,0 +1,213 @@ +# AReaLite Design Doc + +## Motivation + +AReaL is too heavy for AI researchers to use, understand, and develop with for several reasons. The most important issue is that its code architecture is *system-centric* rather than *AI-centric* — the RL algorithm workflow consists of multiple *workers* that run consecutive *model function calls*, neither of which are well-known concepts for AI researchers. As a result, users must first understand these concepts before they can develop workflows and algorithms for their own use cases. + +Additionally, due to historical reasons, AReaL's code is not clean. There are large pieces of code inherited from previous projects that are not useful but significantly increase the burden on users and developers. Sometimes debugging is difficult even for core developers like myself. + +Since the tools for building RL workflows are becoming increasingly mature, implementing a framework that achieves comparable efficiency requires much fewer lines of code. Now is the proper time to revisit the API design and distill the giant codebase into a neat and clean one. The distilled codebase does not need to be ultra-efficient. Instead, we want to deliver 90% functionality of the original AReaL while minimizing the lines of code and the burden on potential users. Our aim is to build an RL training framework that is fast to use, fast to read, and fast to execute. Here comes the lite version of AReaL — AReaLite. + +AReaLite is the first step in AReaL's refactoring process. It is not only a standalone training library with shallow interfaces, but will also provide the core API definitions to be used by AReaL in the future. AReaL will essentially transform its current worker-based architecture into an AI-centric architecture like AReaLite. AReaL will **extend** AReaLite's APIs and implementations to support more backends for efficient large-scale training. + +## Expectations of AReaLite + +### Highlights + ++ Fully asynchronous training with decoupled inference and training. ++ Elastic inference device scaling — users can shut down or launch more inference processes independently during training. ++ Full SFT/RL algorithmic functionality matching AReaL. ++ Arbitrary agentic rollout collector customization in a single file. ++ Easy navigation to implementation references via Ctrl+click in VSCode. ++ Support for distributed launching with Ray/SLURM/torchrun. + +### AReaLite's Scope + ++ Not bound to Ray. ++ Only supports SGLang and PyTorch FSDP2 with SPMD launching. ++ No customized data structures like `SequenceSample`. All data are PyTorch tensors. ++ Uses HuggingFace (models, datasets) and PyTorch (FSDP, data structures) as much as possible. + +## Architecture + +### Core Components + +``` +arealite/ +├── api/ # Abstract interfaces and data structures +├── impl/ # Concrete implementations +├── cli/ # Command-line interfaces +├── config/ # Configuration templates +└── tests/ # Standalone test scripts +``` + +#### 1. API Layer (`api/`) + +The API layer defines abstract interfaces and data structures that provide a clean contract between different components: + +- **`engine_api.py`**: Defines `SPMDWrapper` for SPMD-based training backends (FSDP) and `EngineFactory` +- **`trainer_api.py`**: Defines `Trainer` base class for different training algorithms and `TrainerFactory` +- **`rollout_api.py`**: Defines `RolloutCollector`, `Agent`, `Environment` for RL data collection and `RolloutCollectorFactory` +- **`cli_args.py`**: Defines configuration dataclasses for all components + +#### 2. Implementation Layer (`impl/`) + +The implementation layer contains concrete implementations of the API interfaces: + +- **`fsdp_wrapper.py`**: FSDP-based training engine using PyTorch FSDP2 +- **`trainer/grpo.py`**: GRPO trainer implementation for reinforcement learning +- **`rollout_controller.py`**: Coordinates rollout data collection across workers +- **`rlvr/`**: RLVR collector implementations +- **`agentic/`**: Agentic collector implementations (math, code tasks) + +#### 3. CLI Layer (`cli/`) + +The CLI layer provides user-facing commands: + +- **`main.py`**: Main entry point for launching complete training pipelines +- **`launch_server.py`**: Utility for launching standalone LLM servers + +### Data Flow Architecture + +AReaLite uses an **async producer-consumer pattern**: + +``` +┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐ +│ LLM Servers │◄──►│ Rollout Workers │───►│ Data Buffer │ +│ (SGLang) │ │ (Async Batch) │ │ │ +└─────────────────┘ └──────────────────┘ └─────────────────┘ + ▲ │ + │ ▼ +┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐ +│ Checkpoints │◄───│ FSDP Trainer │◄───│ Training Loop │ +│ │ │ (Sync Batch) │ │ │ +└─────────────────┘ └──────────────────┘ └─────────────────┘ +``` + +### Key Design Principles + +#### 1. **AI-Centric API Design** +Unlike the original AReaL's system-centric approach with workers and model functions, AReaLite uses familiar ML concepts: +- `Agent` and `Environment` (from RL literature) +- `RolloutCollector` (combines multiple agents and the environment to generate rollout data) +- `Trainer` (from HuggingFace/PyTorch, fetches rollout data and updates model parameters) + +#### 2. **Factory Pattern for Extensibility** +Each major component uses a factory pattern for easy customization: +- `EngineFactory` creates training backends +- `TrainerFactory` creates training algorithms +- `RolloutCollectorFactory` creates rollout collectors + +#### 3. **Configuration-Driven Architecture** +All components are configured through dataclasses defined in `cli_args.py`, enabling: +- Type-safe configuration +- Easy CLI argument generation +- Clear documentation of available options + + +## Implementation Details + +### Training Pipeline + +1. **Initialization**: Factory classes create configured instances of engines, trainers, and rollout collectors +2. **Rollout Phase**: `RolloutController` coordinates async data collection across multiple `RolloutWorker` instances +3. **Training Phase**: `Trainer` performs synchronous gradient updates using collected data +4. **Weight Updates**: Updated model weights are pushed to LLM servers via `update_weights_to()` + +### Rollout System + +The rollout system supports arbitrary agentic rollout paradigms, implemented as `RolloutCollector` instances. `RolloutCollector` exposes a `run_episode` method for users to implement the logic of collecting a complete agentic trajectory. Users can implement gymnasium-compatible `Agent` and `Environment` interfaces first and combine them as a collector as in normal RL literature (in `arealite/impl/agentic/`), or users can implement the collector directly if the agent-environment interfaces are not compatible with the desired use cases (in `arealite/impl/rlvr/`). + +## Expected Usage + +### Basic RL Training +```bash +python3 arealite.cli.main \ + experiment_name=my-exp trial_name=my-trial \ + trainer.type=grpo \ + trainer.grpo.actor.path=Qwen/Qwen2-0.5B +``` + +### Rollout-Only Evaluation +```bash +python3 arealite.cli.main \ + trainer.type=null \ + valid_dataset.path=huggingface/dataset +``` + +### Distributed Training +```bash +python3 arealite.cli.main \ + mode=ray \ + allocation_mode=sglang.d16p1m1+d32p2m1 \ + trainer.type=grpo +``` + +## Customization Guide + +### Adding New Trainers + +1. **Implement trainer class** in `impl/trainer/`: +```python +from arealite.api.trainer_api import Trainer + +class MyTrainer(Trainer): + def train(self, resume_from_checkpoint=None): + # Implementation here + pass +``` + +2. **Add configuration** in `cli_args.py`: +```python +@dataclass +class MyTrainerConfig: + learning_rate: float = 1e-4 +``` + +3. **Register in factory** in `trainer_api.py`: +```python +def make_trainer(self, config: TrainerConfig) -> Trainer: + if config.type == "my_trainer": + return MyTrainer(...) +``` + +### Adding New Rollout Collectors + +1. **Implement collector** in `impl/`: +```python +from arealite.api.rollout_api import RolloutCollector + +class MyCollector(RolloutCollector): + async def arun_episode(self, gconfig, env_option=None, seed=None): + # Implementation here + pass +``` + +2. **Register in factory** in `rollout_api.py`: +```python +def make_collector(self, config: RolloutCollectorConfig): + if config.type == "my_collector": + return MyCollector(...) +``` + +## Roadmap + +- [ ] Finalize API design. (In-progress) +- [x] Implement standalone SGLang server (`impl/sglang_server.py`). +- [x] Implement SGLang client generation (`impl/sglang_client.py`). +- [x] Rollout pipeline (`tests/test_rollout.py`). +- [x] SGLang rollout interruption. +- [x] Asynchronous RL system-wide utilities (e.g., `RolloutController`). +- [ ] Various launching scripts: ray, torchrun, slurm. +- [ ] FSDP2 engine with transformers models. (In-progress) +- [ ] SFT trainer. (In-progress) +- [ ] SGLang update weights. (In-progress) +- [x] GRPO trainer. +- [ ] Add benchmarking against original AReaL +- [ ] CI and unittests. +- [ ] Other RL algorithms (DPO, REINFORCE, etc.) +- [ ] Support for multi-modal models +- [ ] User guide for transitioning from v0.3.0. +- [ ] Advanced agentic collectors (tool use, planning) +- [ ] Examples of training GSM8K, TLDR, and a search agent. +- [ ] Allow external persistent SGLang servers for debugging purposes. diff --git a/arealite/api/cli_args.py b/arealite/api/cli_args.py new file mode 100644 index 000000000..e2fd8dad1 --- /dev/null +++ b/arealite/api/cli_args.py @@ -0,0 +1,721 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 +import argparse +import os +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import uvloop + +uvloop.install() +from hydra import compose as hydra_compose +from hydra import initialize as hydra_init +from omegaconf import MISSING, OmegaConf + +from realhf.api.cli_args import ( + ClusterSpecConfig, + ExperimentSaveEvalControl, + MicroBatchSpec, + OptimizerConfig, + TensorBoardConfig, + WandBConfig, +) + + +@dataclass(unsafe_hash=True) +class ParallelismConfig: + """Configuration for 3D parallelism (tensor, pipeline, and data parallelism). + + Note: + Sequence parallelism is only used in combination with tensor-model parallelism. + """ + + tensor_parallel_size: int = field( + default=1, metadata={"help": "Size of tensor-model parallelism"} + ) + pipeline_parallel_size: int = field( + default=1, metadata={"help": "Number of pipeline parallel stages"} + ) + data_parallel_size: int = field( + default=1, metadata={"help": "Data parallelism size for ZeRO optimization"} + ) + use_sequence_parallel: bool = field( + default=False, + metadata={ + "help": "Enable sequence parallelism. Only used with tensor-model parallelism in Megatron", + }, + ) + + def __str__(self): + """Returns compact string representation: 'Parallel(mp=X,pp=Y,dp=Z)'.""" + return ( + f"Parallel(mp={self.tensor_parallel_size}," + f"pp={self.pipeline_parallel_size}," + f"dp={self.data_parallel_size})" + ) + + @staticmethod + def parallelism_eq(this, other): + """Compare parallelism configurations (excluding sequence parallelism). + + Note: + Implemented as static method to avoid OmegaConf compatibility issues. + """ + return ( + (this.tensor_parallel_size == other.tensor_parallel_size) + and (this.pipeline_parallel_size == other.pipeline_parallel_size) + and (this.data_parallel_size == other.data_parallel_size) + ) + + +@dataclass +class GenerationHyperparameters: + """Controls text generation behavior for RL training.""" + + n_samples: int = field( + default=1, metadata={"help": "Number of sequences to generate per prompt."} + ) + max_new_tokens: int = field( + default=16384, metadata={"help": "Maximum number of tokens to generate."} + ) + min_new_tokens: int = field( + default=0, metadata={"help": "Minimum number of tokens to generate."} + ) + greedy: bool = field( + default=False, + metadata={"help": "Whether to use greedy decoding (max probability)."}, + ) + top_p: float = field( + default=1.0, + metadata={"help": "Nucleus sampling probability threshold (0.0, 1.0]."}, + ) + top_k: int = field( + default=int(1e8), + metadata={"help": "Number of highest probability tokens to consider."}, + ) + temperature: float = field( + default=1.0, + metadata={"help": "Sampling temperature. Higher values increase diversity."}, + ) + stop_token_ids: List[int] = field( + default_factory=list, + metadata={"help": "Stop generation when encoutering these token ids."}, + ) + + def new(self, **kwargs): + args = asdict(self) + args.update(kwargs) + return GenerationHyperparameters(**args) + + +## Inference config for clients and servers. ## + + +@dataclass +class SGLangConfig: + """Configuration for SGLang runtime. Refer to: + https://github.com/sgl-project/sglang for detailed documentation. + """ + + disable_cuda_graph: bool = False + disable_radix_cache: bool = False + disable_cuda_graph_padding: bool = False + enable_nccl_nvls: bool = False + disable_outlines_disk_cache: bool = False + disable_custom_all_reduce: bool = False + disable_overlap_schedule: bool = False + enable_mixed_chunk: bool = False + enable_dp_attention: bool = False + enable_ep_moe: bool = False + enable_torch_compile: bool = False + torch_compile_max_bs: int = 32 + cuda_graph_max_bs: Optional[int] = None + cuda_graph_bs: Optional[List[int]] = None + torchao_config: str = "" + enable_nan_detection: bool = False + enable_p2p_check: bool = False + triton_attention_reduce_in_fp32: bool = False + triton_attention_num_kv_splits: int = 8 + num_continuous_decode_steps: int = 1 + enable_memory_saver: bool = False + allow_auto_truncate: bool = False + # NOTE: to avoid the illegal memory access error + attention_backend: Optional[str] = "flashinfer" + sampling_backend: Optional[str] = None + context_length: Optional[int] = 32768 + mem_fraction_static: Optional[float] = 0.9 + max_running_requests: Optional[int] = None + # NOTE: chunked_prefill_size is by default 8192 on GPUs with 80GB mem in SGLang, + # but we disable it to avoid precision issues + chunked_prefill_size: Optional[int] = -1 + max_prefill_tokens: int = 32768 + schedule_policy: str = "lpm" + schedule_conservativeness: float = 1.0 + cpu_offload_gb: int = 0 + + dtype: str = "float16" + kv_cache_dtype: str = "auto" + + # logging + log_level: str = "warning" + log_level_http: Optional[str] = "warning" + log_requests: bool = False + log_requests_level: int = 0 + show_time_cost: bool = False + enable_metrics: bool = True # Exports Prometheus-like metrics + # The interval (in decoding iterations) to log throughput + # and update prometheus metrics + decode_log_interval: int = 1 + + # Use staticmethod to make OmegaConf happy. + @staticmethod + def build_cmd( + sglang_config: "SGLangConfig", + model_path, + tp_size, + base_gpu_id, + dist_init_addr: Optional[str] = None, + served_model_name: Optional[str] = None, + skip_tokenizer_init: bool = True, + ): + from realhf.base import network, pkg_version, seeding + from realhf.experiments.common.utils import asdict as conf_as_dict + + args: Dict = conf_as_dict(sglang_config) + args["random_seed"] = seeding.get_seed() + + if served_model_name is None: + served_model_name = model_path + host_ip = network.gethostip() + host = "localhost" if not sglang_config.enable_metrics else host_ip + args = dict( + host=host, + model_path=model_path, + # Model and tokenizer + tokenizer_path=model_path, + tokenizer_mode="auto", + load_format="auto", + trust_remote_code=True, + device="cuda", + served_model_name=served_model_name, + is_embedding=False, + skip_tokenizer_init=skip_tokenizer_init, + # Other runtime options + tp_size=tp_size, + # Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process + base_gpu_id=base_gpu_id, + nnodes=1, + node_rank=0, + dist_init_addr=dist_init_addr, + **args, + ) + + if pkg_version.is_version_less("sglang", "0.4.4"): + args.pop("log_requests_level") + if pkg_version.is_version_less("sglang", "0.4.3"): + args.pop("enable_nccl_nvls") + args.pop("triton_attention_num_kv_splits") + args.pop("cuda_graph_bs") + args.pop("enable_memory_saver") + args.pop("allow_auto_truncate") + args.pop("file_storage_path") + + flags = [] + for k, v in args.items(): + if v is None or v is False or v == "": + continue + if v is True: + flags.append(f"--{k.replace('_','-')} ") + continue + if isinstance(v, list): + values = " ".join(map(str, v)) + flags.append(f"--{k.replace('_','-')} {values}") + continue + flags.append(f"--{k.replace('_','-')} {v}") + flags = " ".join(flags) + return f"python3 -m sglang.launch_server {flags}" + + +@dataclass +class LLMServiceConfig: + served_model_name: Optional[str] = field( + default=None, metadata={"help": "Name of the served model"} + ) + health_check_interval: int = field( + default=5, metadata={"help": "Health check interval in seconds"} + ) + startup_timeout: int = field( + default=90, metadata={"help": "Startup timeout in seconds"} + ) + max_unhealth_count: int = field( + default=3, metadata={"help": "Max unhealthy count before restart"} + ) + graceful_shutdown_on_unhealthy: bool = field( + default=True, metadata={"help": "Enable graceful shutdown when unhealthy"} + ) + + +@dataclass +class LLMClientConfig: + schedule_policy: str = field( + default="round_robin", + metadata={"help": "Request scheduling policy", "choices": ["round_robin"]}, + ) + request_timeout: int = field( + default=3600, metadata={"help": "Request timeout in seconds"} + ) + request_retries: int = field( + default=3, metadata={"help": "Number of retries for each request"} + ) + + +## Dataset configs. ## + + +@dataclass +class GSM8KPreprocessor: + reward_mode: str = "strict" + + +@dataclass +class DatasetPreprocessor: + type: Optional[str] = field( + default=None, + metadata={ + "help": "Number of retries for each request", + }, + ) + gsm8k: Optional[GSM8KPreprocessor] = None + + +@dataclass +class DatasetConfig: + path: str = field( + default="", metadata={"help": "Path or HuggingFace identifier to the dataset"} + ) + name: Optional[str] = field( + default=None, metadata={"help": "Dataset name (for HuggingFace datasets)"} + ) + split: Optional[str] = field( + default=None, metadata={"help": "Dataset split to use (e.g., 'train', 'test')"} + ) + data_files: Optional[str] = field( + default=None, metadata={"help": "Specific data files to load"} + ) + batch_size: int = field( + default=1, metadata={"help": "Batch size of the dataloader"} + ) + shuffle: bool = field( + default=True, metadata={"help": "Whether to shuffle the dataset"} + ) + pin_memory: bool = field( + default=False, + metadata={ + "help": "Pin memory for faster data loading (set True for GPU training)" + }, + ) + num_workers: int = field( + default=0, metadata={"help": "Number of worker processes for data loading"} + ) + preprocessor: Optional[DatasetPreprocessor] = field( + default=None, + metadata={"help": "Dataset preprocessor config. None means no preprocessing."}, + ) + + +## Training backend configs. ## + + +@dataclass +class FSDPWrapPolicy: + transformer_layer_cls_to_wrap: Optional[List[str]] = field( + default=None, + metadata={"help": "A list of transformer layer names for FSDP to wrap."}, + ) + + +@dataclass +class FSDPConfig: + wrap_policy: Optional[FSDPWrapPolicy] = field( + default=None, + metadata={"help": "FSDP wrap policy, specifying model layers to wrap."}, + ) + offload_params: bool = field( + default=False, + metadata={"help": "Whether to offload FSDP parameters to CPU."}, + ) + + +@dataclass +class EngineBackendConfig: + type: str = field( + default="hf", + metadata={"help": "Training backend", "choices": ["fsdp", "hf"]}, + ) + fsdp: Optional[FSDPConfig] = field( + default=None, metadata={"help": "FSDP configuration (if using FSDP backend)"} + ) + + +@dataclass +class EngineConfig: + # Model Architecture Configuration + path: str = field(default="", metadata={"help": "Path to HuggingFace checkpoint"}) + init_from_scratch: bool = field( + default=False, metadata={"help": "Initialize model weights randomly"} + ) + init_critic_from_actor: bool = field( + default=False, + metadata={"help": "Initialize critic/reward model from LM checkpoint"}, + ) + + # Training Backend Configuration + gradient_checkpointing: bool = field( + default=True, metadata={"help": "Enable gradient checkpointing"} + ) + bf16: bool = field(default=False, metadata={"help": "Use bf16 precision"}) + optimizer: Optional[OptimizerConfig] = field( + default=None, metadata={"help": "Optimizer configuration"} + ) + backend: EngineBackendConfig = field( + default_factory=EngineBackendConfig, + metadata={"help": "Training backend configuration"}, + ) + + +## Agent configurations. ## + + +@dataclass +class MathCodeSingleStepConfig: + solution_path: str = field(default="", metadata={"help": "Path to solutions"}) + + +@dataclass +class RLVRConfig: + reward_type: str = field( + default="areal-math", + metadata={ + "help": "The type of the reward function", + "choices": ["areal-math", "areal-code", "gsm8k"], + }, + ) + solution_path: str = field( + default="", metadata={"help": "Path to solutions. Required by areal-math/code."} + ) + + +@dataclass +class RolloutCollectorConfig: + type: str = field( + default="rlvr", + metadata={ + "help": "Rollout collector type", + "choices": ["rlvr", "math_code_single_step"], + }, + ) + rlvr: Optional[RLVRConfig] = field( + default=None, + metadata={"help": "The configuration for the RLVR collector"}, + ) + math_code_single_step: Optional[MathCodeSingleStepConfig] = field( + default=None, + metadata={"help": "The configuration for the single-step math/code collector"}, + ) + + +## Rollout configurations. ## + + +@dataclass +class RolloutConfig: + server_backend: str = field( + default="sglang", + metadata={"help": "Backend for serving", "choices": ["sglang", "vllm"]}, + ) + model_path: str = field(default="", metadata={"help": "Path to the rollout model"}) + collector: RolloutCollectorConfig = field( + default_factory=RolloutCollectorConfig, + metadata={"help": "Rollout collector configuration."}, + ) + num_workers: int = field( + default=1, metadata={"help": "Number of rollout worker processes"} + ) + max_concurrent_rollouts: Optional[int] = field( + default=None, + metadata={ + "help": "Maximum number of concurrent rollouts. Defaults to train batch size." + }, + ) + max_head_offpolicyness: int = field( + default=0, + metadata={"help": "Maximum off-policyness tolerance for the first token"}, + ) + filter_reward_lb: float = field( + default=-float("inf"), metadata={"help": "Lower bound for reward filtering"} + ) + filter_reward_ub: float = field( + default=float("inf"), metadata={"help": "Upper bound for reward filtering"} + ) + llm_client: LLMClientConfig = field( + default_factory=LLMClientConfig, + metadata={"help": "LLM client configuration for rollouts"}, + ) + gconfig: GenerationHyperparameters = field( + default_factory=GenerationHyperparameters, + metadata={"help": "Generation hyperparameters for rollouts"}, + ) + llm_service: LLMServiceConfig = field( + default_factory=LLMServiceConfig, metadata={"help": "LLM server configuration"} + ) + sglang: Optional[SGLangConfig] = field( + default_factory=SGLangConfig, + metadata={"help": "SGLang configuration (if using SGLang backend)"}, + ) + + +## Trainer configurations. ## + + +@dataclass +class SFTTrainerConfig: + model: EngineConfig = field( + default_factory=EngineConfig, + metadata={"help": "Model configuration for SFT training"}, + ) + mb_spec: MicroBatchSpec = field( + default_factory=MicroBatchSpec, + metadata={"help": "Micro-batch specification for SFT training"}, + ) + + +@dataclass +class GRPOTrainerConfig: + async_training: bool = field( + default=True, metadata={"help": "Enable asynchronous training mode"} + ) + actor: EngineConfig = field( + default_factory=EngineConfig, + metadata={"help": "Actor model configuration"}, + ) + ref: Optional[EngineConfig] = field( + default=None, metadata={"help": "Reference model configuration"} + ) + mb_spec: MicroBatchSpec = field( + default_factory=MicroBatchSpec, + metadata={"help": "Micro-batch specification"}, + ) + + # Core PPO/GRPO Parameters + group_adv_norm: bool = field( + default=False, + metadata={ + "help": "Normalize advantages within each prompt group rather than globally" + }, + ) + ppo_n_minibatches: int = field( + default=4, metadata={"help": "Number of minibatches for each PPO update"} + ) + eps_clip: float = field( + default=0.2, metadata={"help": "Clipping factor for policy ratio"} + ) + c_clip: Optional[float] = field( + default=None, + metadata={ + "help": "Dual clipping factor for policy ratio, must > 1.0. None disables dual clipping." + }, + ) + actor_sample_reuse: int = field( + default=1, metadata={"help": "The data reuse (aka PPO epoch) for actor."} + ) + + # Reward + group_reward_norm: bool = field( + default=False, + metadata={ + "help": "Normalize final reward of each sequence (GRPO-style) to reduce length bias" + }, + ) + reward_scaling: float = field( + default=1.0, metadata={"help": "Reward scaling factor"} + ) + reward_bias: float = field(default=0.0, metadata={"help": "Reward bias"}) + max_reward_clip: float = field( + default=20.0, metadata={"help": "Maximum absolute value for reward clipping"} + ) + mask_no_eos_with_zero: bool = field( + default=False, + metadata={ + "help": "Mask truncated generations (no EOS token) and exclude from training" + }, + ) + + # Advantage Estimation + discount: float = field( + default=1.0, metadata={"help": "Discount factor for future rewards"} + ) + gae_lambda: float = field( + default=1.0, metadata={"help": "Lambda parameter for GAE"} + ) + adv_norm: bool = field( + default=True, metadata={"help": "Enable advantage normalization"} + ) + + # KL Control + kl_ctl: float = field(default=0.1, metadata={"help": "KL divergence coefficient"}) + + # Asynchronous PPO + recompute_logprob: bool = field( + default=False, + metadata={"help": "Recompute logp and replace the logp returned by inference."}, + ) + use_decoupled_loss: bool = field( + default=False, + metadata={"help": "Use the decoupled loss. recompute_logprob must be True."}, + ) + behav_imp_weight_cap: Optional[float] = field( + default=None, + metadata={ + "help": "We filter out the tokens where behav_imp_weight exceeds behav_imp_weight_cap when computing the loss, must be > 1.0, use_decoupled_loss must be true" + }, + ) + + +@dataclass +class TrainerConfig: + type: str = field( + default="grpo", + metadata={"help": "Trainer type", "choices": ["grpo", "sft", "null"]}, + ) + grpo: Optional[GRPOTrainerConfig] = field( + default=None, metadata={"help": "GRPO trainer configuration (if using GRPO)"} + ) + sft: Optional[SFTTrainerConfig] = field( + default=None, metadata={"help": "SFT trainer configuration (if using SFT)"} + ) + + +## Entrypoint. ## + + +@dataclass +class TrainingArgs: + experiment_name: str = field( + default=MISSING, + metadata={"help": "Name of the experiment (no '_' or '/'). Required."}, + ) + trial_name: str = field( + default=MISSING, + metadata={"help": "Name of the trial (no '-' or '/'). Required."}, + ) + mode: str = field( + default="slurm", + metadata={ + "help": "Experiment launching mode.", + "choices": ["slurm", "local", "ray"], + }, + ) + wandb: WandBConfig = field( + default_factory=WandBConfig, + metadata={"help": "Weights & Biases configuration."}, + ) + tensorboard: TensorBoardConfig = field( + default_factory=TensorBoardConfig, + metadata={"help": "TensorBoard configuration. Only 'path' field required."}, + ) + allocation_mode: str = field( + default="sglang.d1p1t1+d1p1t1", + metadata={ + "help": "GPU parallel strategy allocation mode. " + "Options: manual/heuristic or pattern-based." + }, + ) + ray_temp_path: str = field( + default="/tmp/ray", metadata={"help": "Absolute path for Ray's log."} + ) + n_nodes: int = field( + default=1, metadata={"help": "Number of nodes for experiment."} + ) + n_gpus_per_node: int = field( + default=8, metadata={"help": "Number of GPUs per node for this experiment."} + ) + nodelist: Optional[str] = field( + default=None, + metadata={ + "help": "SLURM nodelist for manual allocation. " + "Format: 'slurmd-01:0,1,2,3' or 'slurmd-[01-02,03,07],COM08'." + }, + ) + exclude: Optional[str] = field( + default=None, + metadata={ + "help": "SLURM nodelist to exclude from allocation. " + "Format: 'slurmd-01:0,1,2,3' or 'slurmd-[01-02,03,07],COM08'." + }, + ) + seed: int = field(default=1, metadata={"help": "Random seed for reproducibility."}) + exp_ctrl: ExperimentSaveEvalControl = field( + default_factory=ExperimentSaveEvalControl, + metadata={"help": "Experiment save/evaluation control configuration."}, + ) + shutdown_server_on_exit: bool = field( + default=False, + metadata={"help": "Whether to shut down the LLM generation server on exit."}, + ) + cluster: ClusterSpecConfig = field( + default_factory=ClusterSpecConfig, + metadata={"help": "Cluster specification. Mainly used by slurm."}, + ) + train_dataset: DatasetConfig = field( + default_factory=DatasetConfig, metadata={"help": "Train dataset configuration"} + ) + valid_dataset: Optional[DatasetConfig] = field( + default=None, metadata={"help": "Validation dataset configuration"} + ) + rollout: Optional[RolloutConfig] = field( + default_factory=RolloutConfig, + metadata={"help": "Rollout controller configuration for RL training"}, + ) + trainer: Optional[TrainerConfig] = field( + default=None, metadata={"help": "Trainer configuration"} + ) + cpu_per_inf_proc: int = 16 + mem_per_inf_proc: int = 100000 + cpu_per_train_proc: int = 16 + mem_per_train_proc: int = 100000 + + +def prepare_training_args(argv: List[str]) -> Tuple[TrainingArgs, str]: + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", help="The path of the main configuration file", required=True + ) + args, overrides = parser.parse_known_args(argv) + + # Initialize hydra config + config_file = Path(args.config).absolute() + assert config_file.exists() + # hydra only recognize relative paths + relpath = Path( + os.path.relpath( + str(config_file), (Path(__file__).parent.parent / "cli").absolute() + ) + ) + hydra_init(config_path=str(relpath.parent), job_name="app", version_base=None) + cfg = hydra_compose( + config_name=str(relpath.name).rstrip(".yaml"), + overrides=overrides, + ) + + # Merge with the default configuration + default_cfg = OmegaConf.structured(TrainingArgs) + cfg = OmegaConf.merge(default_cfg, cfg) + cfg: TrainingArgs = OmegaConf.to_object(cfg) + + # Setup environment + from realhf.base import constants, name_resolve + + constants.set_experiment_trial_names(cfg.experiment_name, cfg.trial_name) + name_resolve.reconfigure(cfg.cluster.name_resolve) + return cfg, str(config_file) diff --git a/arealite/api/dataset_api.py b/arealite/api/dataset_api.py new file mode 100644 index 000000000..9dad8adf8 --- /dev/null +++ b/arealite/api/dataset_api.py @@ -0,0 +1,61 @@ +from dataclasses import dataclass + +from datasets import Dataset, load_dataset +from datasets.distributed import split_dataset_by_node + +from arealite.api.cli_args import DatasetConfig +from arealite.utils import TrainingArgs + + +def create_distributed_dataset(cfg: DatasetConfig, rank, world_size): + dataset = load_dataset( + cfg.path, + name=cfg.name, + split=cfg.split, + data_files=cfg.data_files, + ) + dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size) + return dataset + + +@dataclass +class DatasetFactory: + args: TrainingArgs + + def make_dataset( + self, config: DatasetConfig, rank: int, world_size: int + ) -> Dataset: + dataset = create_distributed_dataset(config, rank, world_size) + if config.preprocessor.type == "gsm8k_rl": + from arealite.impl.dataset.gsm8k import process_gsm8k_rl_dataset + + tokenizer_path = self.args.rollout.model_path + assert self.args.rollout.model_path is not None + from realhf.api.core.data_api import load_hf_tokenizer + + tokenizer = load_hf_tokenizer(tokenizer_path) + return process_gsm8k_rl_dataset( + dataset, + tokenizer=tokenizer, + reward_mode=config.preprocessor.gsm8k.reward_mode, + ) + if config.preprocessor.type == "gsm8k_sft": + from arealite.impl.dataset.gsm8k import process_gsm8k_sft_dataset + + tokenizer_path = self.args.trainer.sft.model.path + from realhf.api.core.data_api import load_hf_tokenizer + + tokenizer = load_hf_tokenizer(tokenizer_path) + return process_gsm8k_sft_dataset(dataset, tokenizer=tokenizer) + if config.preprocessor.type == "areal": + tokenizer_path = self.args.rollout.model_path + assert self.args.rollout.model_path is not None + from realhf.api.core.data_api import load_hf_tokenizer + + tokenizer = load_hf_tokenizer(tokenizer_path) + from arealite.impl.dataset.areal import process_areal_dataset + + return process_areal_dataset(dataset, tokenizer=tokenizer) + raise NotImplementedError( + f"Unknown dataset preprocessor type: {config.preprocessor.type}" + ) diff --git a/arealite/api/engine_api.py b/arealite/api/engine_api.py new file mode 100644 index 000000000..362c4d982 --- /dev/null +++ b/arealite/api/engine_api.py @@ -0,0 +1,117 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import abc +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +import torch +import transformers + +from arealite.api.cli_args import EngineConfig, MicroBatchSpec, TrainingArgs +from arealite.api.io_struct import FinetuneSpec +from arealite.api.llm_client_api import LLMClient +from realhf.api.cli_args import ParallelismConfig + + +class SPMDWrapper(abc.ABC): + """A wrapper over the training/inference backends (e.g., FSDP, SGLang). + We following the design of existing libraries, such as Megatron-LM and + pytorch FSDP, which are mostly SPMD-based. + """ + + def __init__(self, args: TrainingArgs, engine_config: EngineConfig): + self.args = args + self.engine_config = engine_config + + def init_distributed(self, config: ParallelismConfig, ft_spec: FinetuneSpec): + """Initialize distributed communication groups and models. + + Models may not be loaded during __init__, but when calling this method. + """ + raise NotImplementedError() + + def train(self, mode: bool = True): + """Set the module in training mode.""" + raise NotImplementedError() + + def eval(self): + """Set the module in evaluation mode.""" + return self.train(False) + + def train_batch( + self, + input_: Dict, + mb_spec: MicroBatchSpec, + loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor], + loss_weight_fn: Callable[[Dict], float], + ) -> Dict: + """Update the model with a batch of data and a loss function.""" + raise NotImplementedError() + + @torch.no_grad() + def eval_batch( + self, + input_: Dict, + mb_spec: MicroBatchSpec, + loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor], + loss_weight_fn: Callable[[Dict], float], + ) -> torch.Tensor | None: + """Evaluate the model using the forward pass and loss function.""" + raise NotImplementedError() + + def forward( + self, + input_: Dict, + mb_spec: MicroBatchSpec, + output_seqlens: List[int] | None = None, + post_hook: Callable[[torch.Tensor, Dict], Any] | None = None, + aggregate_fn: Callable[[List[Any]], Any] = torch.cat, + ) -> Any | None: + """Run the forward pass or inference on the model.""" + raise NotImplementedError() + + def step_lr_scheduler(self): + """Step learning rate scheduler.""" + raise NotImplementedError() + + def save_model_to_hf( + self, + path: str, + tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None, + base_model_path: Optional[str] = None, + ): + raise NotImplementedError() + + def load_model_from_hf(self, path: str): + raise NotImplementedError() + + def save_optimizer_state(self, path: str): + """Save the optimizer state in a folder.""" + raise NotImplementedError() + + def load_optimizer_state(self, path: str): + """Load the optimizer state in a folder.""" + raise NotImplementedError() + + def update_weights_to(self, llm_client: LLMClient): + """Update the weights to the server by sending requests to the client.""" + raise NotImplementedError() + + +@dataclass +class EngineFactory: + args: TrainingArgs + + def make_engine(self, engine_config: EngineConfig) -> SPMDWrapper: + """Create an engine based on the configuration.""" + if engine_config.backend.type == "fsdp": + from arealite.impl.engine.fsdp_wrapper import FSDPEngine + + return FSDPEngine(self.args, engine_config) + elif engine_config.backend.type == "hf": + from arealite.impl.engine.hf_wrapper import HFEngine + + return HFEngine(self.args, engine_config) + else: + raise ValueError(f"Unsupported engine type: {engine_config.backend.type}") diff --git a/arealite/api/io_struct.py b/arealite/api/io_struct.py new file mode 100644 index 000000000..cfff165a0 --- /dev/null +++ b/arealite/api/io_struct.py @@ -0,0 +1,234 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import enum +import itertools +import re +import uuid +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional + +import torch +from gymnasium.core import ActType, ObsType + +from arealite.api.cli_args import GenerationHyperparameters + +if TYPE_CHECKING: + from arealite.api.llm_client_api import LLMClient + + +@dataclass +class LLMServerInfo: + server_id: str + host: str + port: int + status: str = "healthy" + last_heartbeat: float = 0 + load: float = 0.0 + version: int = 0 + + +@dataclass +class LLMRequest: + rid: str = field(default_factory=lambda: str(uuid.uuid4())) + text: Optional[str] = None + input_ids: List[int] = field(default_factory=list) + gconfig: GenerationHyperparameters = field( + default_factory=GenerationHyperparameters + ) + metadata: Dict[str, Any] = field(default_factory=dict) + model_id: Optional[str] = None + + +@dataclass +class LLMResponse: + # outputs + completion: Any + input_tokens: List[int] = field(default_factory=list) + output_tokens: List[int] = field(default_factory=list) + output_logprobs: List[float] = field(default_factory=list) + output_versions: List[int] = field(default_factory=list) + stop_reason: Literal["length", "stop", "interrupt"] = "stop" + + # statistics + latency: float = float("inf") + ttft: float = float("inf") # Time to first token + itl: List[float] = field(default_factory=list) # List of inter-token latencies + + +@dataclass +class AgentInferInput: + obs: ObsType + llm_client: "LLMClient" + gconfig: GenerationHyperparameters + + +@dataclass +class AgentInferOutput: + action: ActType + llm_req: LLMRequest + llm_resp: LLMResponse + + +@dataclass +class TrajStats: + start_time: float = 0.0 + total_reward: float = 0.0 + episode_length: int = 0 + info: Dict = field(default_factory=dict) + + +@dataclass +class Trajectory: + prompt: Dict[str, Any] + data: Dict[str, torch.Tensor] + stats: TrajStats + + def to_json_compatible(self): + return { + "prompt": self.prompt, + "data": {k: v.cpu().numpy().tolist() for k, v in self.data.items()}, + "stats": { + "start_time": self.stats.start_time, + "total_reward": self.stats.total_reward, + "episode_length": self.stats.episode_length, + "info": self.stats.info, + }, + } + + @classmethod + def from_json_compatible(cls, data: Dict[str, Any]) -> "Trajectory": + return cls( + prompt=data["prompt"], + data={k: torch.tensor(v) for k, v in data["data"].items()}, + stats=TrajStats( + start_time=data["stats"]["start_time"], + total_reward=data["stats"]["total_reward"], + episode_length=data["stats"]["episode_length"], + info=data["stats"]["info"], + ), + ) + + +@dataclass +class FinetuneSpec: + total_train_epochs: int + dataset_size: int + train_batch_size: int + + @property + def total_train_steps(self): + # assuming drop_last + return self.total_train_epochs * (self.dataset_size // self.train_batch_size) + + +@dataclass +class WeightUpdateGroupMeta: + master_address: str + master_port: int + rank_offset: int + world_size: int + group_name: str = "weight_update_group" + backend: str = "nccl" + + +@dataclass +class WeightMeta: + param_name: str + shape: List[str] + dtype: str + group_name: str = "weight_update_group" + + +class AllocationType(enum.Enum): + DECOUPLED_vLLM = 1 + DECOUPLED_SGLANG = 2 + + +@dataclass +class AllocationMode: + type_: AllocationType + parallel_strat: None | Dict[str, Dict[str, int]] + + @property + def gen_tp_size(self) -> int: + return self.parallel_strat["gen"]["t"] + + @property + def gen_pp_size(self) -> int: + return self.parallel_strat["gen"]["p"] + + @property + def gen_dp_size(self) -> int: + return self.parallel_strat["gen"]["d"] + + @property + def gen_world_size(self) -> int: + return self.gen_dp_size * self.gen_pp_size * self.gen_tp_size + + @property + def train_tp_size(self) -> int: + return self.parallel_strat["*"]["t"] + + @property + def train_pp_size(self) -> int: + return self.parallel_strat["*"]["p"] + + @property + def train_dp_size(self) -> int: + return self.parallel_strat["*"]["d"] + + @property + def train_world_size(self) -> int: + return self.train_dp_size * self.train_pp_size * self.train_tp_size + + @classmethod + def from_str(cls, allocation_mode: str): + alloc_decoupled = AllocationMode.extract_decoupled_alloc(allocation_mode) + if "vllm" in allocation_mode: + return cls(AllocationType.DECOUPLED_vLLM, alloc_decoupled) + elif "sglang" in allocation_mode: + return cls(AllocationType.DECOUPLED_SGLANG, alloc_decoupled) + raise NotImplementedError(f"Failed to parse allocation: {allocation_mode}") + + @staticmethod + def extract_3d_alloc(allocation_mode: str) -> Dict | None: + for x, y, z in itertools.permutations(["d", "t", "p"]): + pattern = rf"{x}(\d+){y}(\d+){z}(\d+)" + m = re.match(pattern, allocation_mode) + if not m: + continue + a, b, c = map(int, m.groups()) + # to be consistent with the key-value pattern + return { + "*": { + x: a, + y: b, + z: c, + } + } + + @staticmethod + def extract_decoupled_alloc(allocation_mode: str) -> Dict | None: + pattern = re.compile( + r"(?:(?:vllm|sglang)\.(.+?)\+(.+))|(?:(.+?)\+(?:vllm|sglang)\.(.+))" + ) + m = pattern.match(allocation_mode) + if not m: + return + if m.group(1): + gen_alloc = m.group(1) + other_alloc = m.group(2) + else: + gen_alloc = m.group(4) + other_alloc = m.group(3) + gen_alloc = AllocationMode.extract_3d_alloc(gen_alloc) + if not gen_alloc: + return + other_alloc = AllocationMode.extract_3d_alloc( + other_alloc + ) or AllocationMode.extract_key_value_alloc(other_alloc) + if not other_alloc: + return + other_alloc.update({"gen": gen_alloc["*"]}) + return other_alloc diff --git a/arealite/api/llm_client_api.py b/arealite/api/llm_client_api.py new file mode 100644 index 000000000..ec05fa362 --- /dev/null +++ b/arealite/api/llm_client_api.py @@ -0,0 +1,142 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import abc +import asyncio +import random +import time +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import aiohttp +import requests +import transformers + +from arealite.api.cli_args import LLMClientConfig, TrainingArgs +from arealite.api.io_struct import ( + LLMRequest, + LLMResponse, + LLMServerInfo, + WeightMeta, + WeightUpdateGroupMeta, +) +from arealite.api.llm_server_api import LLMServiceRegistry +from realhf.api.core.data_api import load_hf_tokenizer + + +class LLMClient(abc.ABC): + def __init__(self, args: TrainingArgs, client_config: LLMClientConfig): + self.args = args + self.client_config = client_config + + self.registry = LLMServiceRegistry(args.experiment_name, args.trial_name) + self.tokenizer: transformers.PreTrainedTokenizerFast = load_hf_tokenizer( + args.rollout.model_path + ) + + def select_server(self): + """Get an available healthy server.""" + servers = self.get_healthy_servers() + min_load = min([server.load for server in servers]) + servers = [server for server in servers if server.load == min_load] + return random.choice(servers) + + def get_healthy_servers(self): + servers = self.registry.get_healthy_servers() + if not servers: + raise RuntimeError("No healthy SGLang servers available") + return servers + + def wait_until_servers_ready(self): + while len(self.registry.get_healthy_servers()) == 0: + time.sleep(10) + + async def arequest_with_retry( + self, + endpoint: str, + payload: Optional[Dict[str, Any]] = None, + method: str = "POST", + max_retries: Optional[int] = None, + timeout: Optional[float] = None, + retry_delay: float = 1.0, + target_server: Optional[LLMServerInfo] = None, + ) -> tuple[aiohttp.ClientResponse, LLMServerInfo]: + timeout = timeout or self.client_config.request_timeout + last_exception = None + max_retries = max_retries or self.client_config.request_retries + + # Try with retries + for _ in range(max_retries): + if target_server is None: + server_info = self.select_server() + else: + server_info = target_server + base_url = f"http://{server_info.host}:{server_info.port}" + url = f"{base_url}{endpoint}" + + for attempt in range(max_retries): + try: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout( + total=timeout, + sock_connect=30, + sock_read=timeout, + ) + ) as session: + if method.upper() == "GET": + response = await session.get(url) + elif method.upper() == "POST": + response = await session.post(url, json=payload) + elif method.upper() == "PUT": + response = await session.put(url, json=payload) + elif method.upper() == "DELETE": + response = await session.delete(url) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + response.raise_for_status() + return response, server_info + + except ( + aiohttp.ClientError, + aiohttp.ClientResponseError, + asyncio.TimeoutError, + ) as e: + last_exception = e + if attempt < max_retries - 1: + await asyncio.sleep(retry_delay) + continue + raise RuntimeError( + f"Failed after {max_retries} retries each. " f"Last error: {last_exception}" + ) + + async def agenerate(self, req: LLMRequest) -> LLMResponse: + raise NotImplementedError() + + async def aupdate_weights_from_disk(self, server_info: LLMServerInfo, path: str): + raise NotImplementedError() + + async def ainit_weight_update_group( + self, server_info: LLMServerInfo, group_meta: WeightUpdateGroupMeta + ): + raise NotImplementedError() + + async def aupdate_weights_from_distributed( + self, server_info: LLMServerInfo, weight_meta: WeightMeta + ): + raise NotImplementedError() + + +@dataclass +class LLMClientFactory: + """Factory class to create LLMClient instances.""" + + args: TrainingArgs + + def make_client(self, config: LLMClientConfig) -> LLMClient: + """Create an instance of LLMClient based on the specified type.""" + if self.args.rollout.server_backend == "sglang": + from arealite.system.sglang_client import SGLangClient + + return SGLangClient(self.args, config) + raise ValueError(f"Unknown LLMClient type: {self.args.rollout.server_backend}") diff --git a/arealite/api/llm_server_api.py b/arealite/api/llm_server_api.py new file mode 100644 index 000000000..edb7674d7 --- /dev/null +++ b/arealite/api/llm_server_api.py @@ -0,0 +1,265 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import json +import subprocess +import sys +import threading +import time +import traceback +import uuid +from dataclasses import asdict, dataclass +from datetime import datetime +from typing import List, Optional + +from arealite.api.cli_args import LLMServiceConfig, TrainingArgs +from arealite.api.io_struct import LLMServerInfo +from realhf.base import logging, name_resolve, names + +logger = logging.getLogger("LLM Server") + + +class LLMServiceRegistry: + """A registry class for dynamic server discovery.""" + + def __init__(self, expr_name: str, trial_name: str): + self.expr_name = expr_name + self.trial_name = trial_name + self.heartbeat_timeout = 30 + + def get_server_key(self, server_id: str) -> str: + return names.gen_server(self.expr_name, self.trial_name, server_id) + + def register_server(self, server_info: LLMServerInfo): + server_info.last_heartbeat = datetime.now().timestamp() + key = self.get_server_key(server_info.server_id) + name_resolve.add( + key, + json.dumps(asdict(server_info)), + keepalive_ttl=self.heartbeat_timeout, + replace=False, + ) + + def unregister_server(self, server_id: str): + try: + name_resolve.delete(self.get_server_key(server_id)) + except name_resolve.NameEntryNotFoundError: + pass + + def update_heartbeat( + self, server_id: str, status: str, load: float = 0.0, version: int = 0 + ): + try: + key = self.get_server_key(server_id) + server_data = name_resolve.get(key) + server_info = LLMServerInfo(**json.loads(server_data)) + server_info.last_heartbeat = datetime.now().timestamp() + server_info.load = load + server_info.status = status + server_info.version = version + name_resolve.add( + key, + json.dumps(asdict(server_info)), + keepalive_ttl=self.heartbeat_timeout, + replace=True, + ) + except (name_resolve.NameEntryNotFoundError, json.JSONDecodeError): + pass + + def get_healthy_servers(self) -> List[LLMServerInfo]: + servers = [] + current_time = time.time() + try: + root = names.gen_server_root(self.expr_name, self.trial_name) + server_infos = name_resolve.get_subtree(root) + for server_data in server_infos: + try: + server_info = LLMServerInfo(**json.loads(server_data)) + if ( + current_time - server_info.last_heartbeat + < self.heartbeat_timeout + and server_info.status == "healthy" + ): + servers.append(server_info) + except (json.JSONDecodeError, TypeError): + continue + except name_resolve.NameEntryNotFoundError: + pass + return servers + + +class LLMServer: + def __init__(self, args: TrainingArgs, service_config: LLMServiceConfig): + self.args = args + self.server_id = str(uuid.uuid4()) + self.registry = LLMServiceRegistry(args.experiment_name, args.trial_name) + self.running = False + self.load = 0.0 + self.process: Optional[subprocess.Popen] = None + self.service_config = service_config + + def launch_server(self) -> Optional[LLMServerInfo]: + """Launch the LLM server subprocess. Returns server info or None if failed.""" + raise NotImplementedError() + + def check_health(self) -> bool: + """Check if the server is healthy.""" + raise NotImplementedError() + + def start(self): + """Main entry point - start server and run until exit""" + try: + self._startup() + self._run() + except Exception as e: + logger.error(f"Server error: {e}") + logger.error(traceback.format_exc()) + self._graceful_exit(1) + + def _startup(self): + """Initialize and start the server""" + self.running = True + + # Launch server process + server_info = self.launch_server() + if not server_info or not self.process: + raise RuntimeError("Failed to launch server") + + logger.info(f"Server {self.server_id} starting") + + # Wait for server to be ready + if not self._wait_for_ready(): + raise RuntimeError( + f"Server failed to become ready in {self.service_config.startup_timeout}s" + ) + + # Register with service registry + self.registry.register_server(server_info) + + # Start health monitoring + health_thread = threading.Thread(target=self._health_monitor, daemon=True) + health_thread.start() + + logger.info( + f"Server {self.server_id} ready and registered at http://{server_info.host}:{server_info.port}" + ) + + def _wait_for_ready(self) -> bool: + """Wait for server to become healthy""" + start_time = time.time() + while time.time() - start_time < self.service_config.startup_timeout: + if not self.running or (self.process and self.process.poll() is not None): + return False + if self.check_health(): + return True + time.sleep(2) + return False + + def _run(self): + """Main server loop""" + try: + while self.running: + # Check if subprocess died + if self.process and self.process.poll() is not None: + logger.error( + f"Server process died (code: {self.process.returncode})" + ) + self._graceful_exit(1) + time.sleep(1) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received") + self._graceful_exit(0) + + def _health_monitor(self): + """Monitor server health and exit if unhealthy""" + failures = 0 + max_failures = self.service_config.max_unhealth_count + + while self.running: + try: + # Check process first + if self.process and self.process.poll() is not None: + logger.error("Server process died") + self._graceful_exit(1) + break + + # Check health + if self.check_health(): + failures = 0 + self.registry.update_heartbeat(self.server_id, "healthy", self.load) + else: + failures += 1 + logger.warning(f"Health check failed ({failures}/{max_failures})") + + if failures >= max_failures: + logger.error("Too many health check failures") + self.registry.update_heartbeat( + self.server_id, "unhealthy", self.load + ) + if self.service_config.graceful_shutdown_on_unhealthy: + self._graceful_exit(1) + break + + except Exception as e: + logger.error(f"Health monitor error: {e}") + logger.error(traceback.format_exc()) + failures += 1 + if ( + failures >= max_failures + and self.service_config.graceful_shutdown_on_unhealthy + ): + self._graceful_exit(1) + break + + time.sleep(self.service_config.health_check_interval) + + def _graceful_exit(self, exit_code: int): + """Clean shutdown and exit""" + if not self.running: + return + + logger.info(f"Graceful shutdown initiated (exit code: {exit_code})") + self.running = False + + # Cleanup registry + try: + self.registry.unregister_server(self.server_id) + except Exception as e: + logger.warning(f"Registry cleanup failed: {e}") + logger.warning(traceback.format_exc()) + + # Stop process + if self.process and self.process.poll() is None: + try: + self.process.terminate() + self.process.wait(timeout=5) + logger.info("Server terminated gracefully") + except subprocess.TimeoutExpired: + logger.warning("Force killing server") + try: + self.process.kill() + self.process.wait() + except (ProcessLookupError, OSError): + pass + except Exception as e: + logger.error(f"Process cleanup failed: {e}") + logger.error(traceback.format_exc()) + + if exit_code != 0: + sys.exit(exit_code) + + +@dataclass +class LLMServerFactory: + args: TrainingArgs + + def make_server(self, server_config: LLMServiceConfig) -> LLMServer: + """Create an LLM server instance based on the configuration.""" + if self.args.rollout.server_backend == "sglang": + from arealite.system.sglang_server import SGLangServer + + return SGLangServer(self.args, server_config) + else: + raise ValueError( + f"Unsupported server backend: {self.args.rollout.server_backend}" + ) diff --git a/arealite/api/rollout_api.py b/arealite/api/rollout_api.py new file mode 100644 index 000000000..619c30a31 --- /dev/null +++ b/arealite/api/rollout_api.py @@ -0,0 +1,144 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import abc +import functools +from dataclasses import dataclass +from typing import Any, Callable, Optional, SupportsFloat + +from gymnasium import Env +from gymnasium.core import ActType, ObsType +from gymnasium.utils import seeding + +from arealite.api.cli_args import ( + GenerationHyperparameters, + RolloutCollectorConfig, + TrainingArgs, +) +from arealite.api.io_struct import AgentInferInput, AgentInferOutput, Trajectory +from arealite.api.llm_client_api import LLMClient + + +class Agent(abc.ABC): + def __init__(self, args: TrainingArgs): + self.args = args + + async def aact(self, inp: AgentInferInput) -> AgentInferOutput: + """Async version of act. Given an observation, return an action and data used for RL training.""" + raise NotImplementedError() + + async def areset(self) -> None: + """Async version of reset. Resets the agent's memory.""" + raise NotImplementedError() + + +# Re-export the gymnasium environment class +class Environment(abc.ABC, Env): + def __init__(self, args: TrainingArgs): + self.args = args + + @abc.abstractmethod + def step( + self, action: ActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: + raise NotImplementedError() + + @abc.abstractmethod + def reset( + self, + *, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[ObsType, dict[str, Any]]: # type: ignore + # Initialize the RNG if the seed is manually passed + if seed is not None: + self._np_random, self._np_random_seed = seeding.np_random(seed) + + +class RolloutCollector(abc.ABC): + + def __init__( + self, + args: TrainingArgs, + config: RolloutCollectorConfig, + agent: Agent | None = None, + env: Environment | None = None, + reward_func: Callable | None = None, + ): + self.args = args + self.config = config + + # Used in agentic scenarios + self.agent = agent + self.env = env + + # Used in RLVR + self.reward_func = reward_func + + async def arun_episode( + self, + llm_client: LLMClient, + gconfig: GenerationHyperparameters, + env_option: Optional[Any] = None, + seed: Optional[int] = None, + ) -> Trajectory: + """Async version of run_episode. Run a single episode and return the trajectory.""" + raise NotImplementedError() + + +@dataclass +class RolloutCollectorFactory: + args: TrainingArgs + + def make_collector(self, config: RolloutCollectorConfig) -> RolloutCollector: + if config.type == "rlvr": + from arealite.impl.rlvr.rlvr_collector import RlvrCollector + + rlvr_config = config.rlvr + assert rlvr_config is not None + if rlvr_config.reward_type == "areal-math": + from arealite.impl.rlvr.rewards.areal_math import math_reward + + reward_fn = functools.partial( + math_reward, dataset_path=rlvr_config.solution_path + ) + elif rlvr_config.reward_type == "areal-code": + from arealite.impl.rlvr.rewards.areal_code import code_reward + + reward_fn = functools.partial( + code_reward, dataset_path=rlvr_config.solution_path + ) + elif rlvr_config.reward_type == "gsm8k": + from arealite.impl.rlvr.rewards.gsm8k import ( + gsm8k_reward_fn as reward_fn, + ) + else: + raise NotImplementedError( + f"Unknown reward type: {rlvr_config.reward_type}" + ) + + return RlvrCollector( + self.args, + config=config, + reward_fn=reward_fn, + ) + if config.type == "math_code_single_step": + from arealite.impl.agentic.math_code_single_step import ( + MathCodeAgent, + MathCodeSingleStepCollector, + MathCodeSingleStepEnv, + ) + + agent = MathCodeAgent(self.args) + env = MathCodeSingleStepEnv( + self.args, + solution_path=config.math_code_single_step.solution_path, + ) + + return MathCodeSingleStepCollector( + self.args, + config=config, + agent=agent, + env=env, + ) + raise NotImplementedError(f"Unknown agent type: {config.type}") diff --git a/arealite/api/trainer_api.py b/arealite/api/trainer_api.py new file mode 100644 index 000000000..b1d836a4f --- /dev/null +++ b/arealite/api/trainer_api.py @@ -0,0 +1,133 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import abc +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Union + +import torch.distributed as dist +from datasets import Dataset +from torchdata.stateful_dataloader import StatefulDataLoader + +from arealite.api.cli_args import TrainerConfig, TrainingArgs +from realhf.base import constants + +if TYPE_CHECKING: + from arealite.system.rollout_controller import RolloutController + +# 4. use huggingface.trainerstate +# TODO: how to do checkpointing? + +# follow the signature of transformers.Trainer if possible + + +class Trainer(abc.ABC): + def __init__( + self, + args: TrainingArgs, + trainer_config: TrainerConfig, + train_dataset: Dataset, + valid_dataset: Optional[Dataset] = None, + rollout_controller: Optional["RolloutController"] = None, + ): + self.args = args + self.trainer_config = trainer_config + + self.train_dataset = train_dataset + self.valid_dataset = valid_dataset + + self.rollout_controller = rollout_controller + + self.train_dataloader = None + self.valid_dataloader = None + + def create_train_dataloader(self): + cfg = self.args.train_dataset + if dist.is_initialized(): + batch_size = cfg.batch_size // dist.get_world_size() + else: + batch_size = cfg.batch_size + self.train_dataloader = StatefulDataLoader( + dataset=self.train_dataset, + batch_size=batch_size, + shuffle=cfg.shuffle, + pin_memory=cfg.pin_memory, + num_workers=cfg.num_workers, + drop_last=True, + collate_fn=lambda x: x, + ) + + def create_valid_dataloader(self): + if self.args.valid_dataset is None: + return + cfg = self.args.valid_dataset + if dist.is_initialized(): + batch_size = cfg.batch_size // dist.get_world_size() + else: + batch_size = cfg.batch_size + self.valid_dataloader = StatefulDataLoader( + dataset=self.valid_dataset, + batch_size=batch_size, + shuffle=cfg.shuffle, + pin_memory=cfg.pin_memory, + num_workers=cfg.num_workers, + drop_last=True, + collate_fn=lambda x: x, + ) + + @property + def local_train_batch_size(self): + if not dist.is_initialized(): + return self.args.train_dataset.batch_size + return self.args.train_dataset.batch_size // dist.get_world_size() + + # TODO: check HF trainer signature + def train(self, resume_from_checkpoint: Optional[Union[str, bool]] = None): + raise NotImplementedError() + + def get_save_checkpoint_path( + self, epoch: int, step: int, globalstep: int, name: str = "model" + ): + path = os.path.join( + constants.get_save_path(self.args), + name, + f"epoch{epoch}epochstep{step}globalstep{globalstep}", + ) + os.makedirs(path, exist_ok=True) + return path + + +@dataclass +class TrainerFactory: + args: TrainingArgs + + def make_trainer( + self, + config: TrainerConfig, + train_dataset: Dataset, + valid_dataset: Optional[Dataset] = None, + rollout_controller: Optional["RolloutController"] = None, + ) -> Trainer: + if config.type == "grpo": + from arealite.impl.trainer.grpo import SpmdGRPOTrainer + + return SpmdGRPOTrainer( + self.args, + config, + train_dataset=train_dataset, + valid_dataset=valid_dataset, + rollout_controller=rollout_controller, + ) + elif config.type == "sft": + from arealite.impl.trainer.sft import SFTTrainer + + return SFTTrainer( + self.args, + config, + train_dataset=train_dataset, + valid_dataset=valid_dataset, + rollout_controller=rollout_controller, + ) + else: + raise NotImplementedError(f"Unknown trainer type: {config.type}") diff --git a/arealite/cli/launch_server.py b/arealite/cli/launch_server.py new file mode 100644 index 000000000..2da84620b --- /dev/null +++ b/arealite/cli/launch_server.py @@ -0,0 +1,20 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import sys + +from arealite.api.cli_args import TrainingArgs, prepare_training_args +from arealite.api.llm_server_api import LLMServerFactory +from realhf.base import seeding + + +def main(): + """Main entry point for launching the LLM server.""" + cfg: TrainingArgs = prepare_training_args(sys.argv[1:])[0] + seeding.set_random_seed(cfg.seed, "llm_server") + server = LLMServerFactory(cfg).make_server(cfg.rollout.llm_service) + server.start() + + +if __name__ == "__main__": + main() diff --git a/arealite/cli/launch_trainer.py b/arealite/cli/launch_trainer.py new file mode 100644 index 000000000..982ecac55 --- /dev/null +++ b/arealite/cli/launch_trainer.py @@ -0,0 +1,58 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import os +import sys + +import torch.distributed as dist +from torch.distributed.elastic.multiprocessing.errors import record + +from arealite.api.cli_args import TrainingArgs, prepare_training_args +from arealite.api.dataset_api import DatasetFactory +from arealite.api.rollout_api import RolloutCollectorFactory +from arealite.api.trainer_api import TrainerFactory +from arealite.system.rollout_controller import RolloutController +from realhf.base import seeding + + +@record +def main(): + """Main entry point for launching the trainer.""" + cfg: TrainingArgs = prepare_training_args(sys.argv[1:])[0] + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + seeding.set_random_seed(cfg.seed, f"trainer{rank}") + + # Initialize the global pytorch distributed communication group. + dist.init_process_group("nccl") + + # Load and split dataset + dataset_factory = DatasetFactory(cfg) + train_dataset = dataset_factory.make_dataset(cfg.train_dataset, rank, world_size) + valid_dataset = None + if cfg.valid_dataset is not None: + valid_dataset = dataset_factory.make_dataset( + cfg.valid_dataset, rank, world_size + ) + + # Create rollout controller for online training and evaluation. + rollout_controller = None + if cfg.rollout is not None: + rollout_factory = RolloutCollectorFactory(cfg) + collector = rollout_factory.make_collector(cfg.rollout.collector) + rollout_controller = RolloutController(cfg, cfg.rollout, collector=collector) + + # If trainer is given, run RL or offline training. + if cfg.trainer is not None: + trainer_factory = TrainerFactory(cfg) + trainer = trainer_factory.make_trainer( + cfg.trainer, + train_dataset=train_dataset, + valid_dataset=valid_dataset, + rollout_controller=rollout_controller, + ) + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/arealite/config/async_grpo.yaml b/arealite/config/async_grpo.yaml new file mode 100644 index 000000000..8f5e094e4 --- /dev/null +++ b/arealite/config/async_grpo.yaml @@ -0,0 +1,171 @@ +# Basic experiment info +experiment_name: gsm8k-test +trial_name: my-trial-3 +seed: 1 +mode: local +wandb: + mode: disabled + entity: null + project: null + name: null + job_type: null + group: null + notes: null + tags: null + config: null +tensorboard: + path: null + +exp_ctrl: + total_train_epochs: 5 + save_freq_epochs: 1 + save_freq_steps: null + save_freq_secs: null + ckpt_freq_epochs: null + ckpt_freq_steps: null + ckpt_freq_secs: 600 + eval_freq_epochs: null + eval_freq_steps: null + eval_freq_secs: null + benchmark_steps: null + benchmark_n_seqs: null + +# whether to allow persistent servers +shutdown_server_on_exit: true + +# Allocation and parallelism +allocation_mode: sglang.d4p1t1+d4p1t1 +n_nodes: 1 +n_gpus_per_node: 8 + +# Cluster configuration +ray_temp_path: /tmp/ray +cluster: + cluster_name: local + fileroot: /tmp/arealite/ + n_nodes: 1 + n_gpus_per_node: 8 + name_resolve: + type: nfs + nfs_record_root: /tmp/arealite/name_resolve/ + +# Datasets +train_dataset: + path: json + name: null + split: train + data_files: /storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl + batch_size: 32 + shuffle: True + preprocessor: + type: areal + +valid_dataset: null + +# Rollout config +rollout: + collector: + type: rlvr + rlvr: + reward_type: areal-math + solution_path: /storage/openpsi/users/xushusheng.xss/training_data/boba_106k_0319.jsonl + num_workers: 1 + max_concurrent_rollouts: null + max_head_offpolicyness: 0 + filter_reward_lb: -10000 + filter_reward_ub: 10000 + server_backend: sglang + model_path: /storage/openpsi/models/Qwen__Qwen3-1.7B/ + gconfig: + n_samples: 16 + max_new_tokens: 512 + min_new_tokens: 0 + top_p: 1.0 + top_k: 1000000 + temperature: 1.0 + llm_client: + schedule_policy: round_robin + request_timeout: 3600 + request_retries: 3 + llm_service: + served_model_name: null + health_check_interval: 5 + startup_timeout: 300 + max_unhealth_count: 3 + graceful_shutdown_on_unhealthy: true + sglang: + dtype: "bfloat16" + enable_mixed_chunk: false + enable_torch_compile: false + torch_compile_max_bs: 32 + cuda_graph_max_bs: null + cuda_graph_bs: null + triton_attention_reduce_in_fp32: false + triton_attention_num_kv_splits: 8 + num_continuous_decode_steps: 1 + attention_backend: "flashinfer" + sampling_backend: null + context_length: 32768 + mem_fraction_static: 0.9 + max_running_requests: null + chunked_prefill_size: -1 + max_prefill_tokens: 32768 + schedule_policy: "lpm" + schedule_conservativeness: 1.0 + cpu_offload_gb: 0 + kv_cache_dtype: "auto" + log_level: "warning" + log_level_http: "warning" + log_requests: false + log_requests_level: 0 + show_time_cost: false + enable_metrics: true + decode_log_interval: 1 + +# Trainer +trainer: + type: grpo + grpo: + async_training: true + actor: + path: /storage/openpsi/models/Qwen__Qwen3-1.7B/ + init_from_scratch: false + gradient_checkpointing: false + bf16: true + optimizer: + type: adam + lr: 1.0e-6 + weight_decay: 0.05 + beta1: 0.9 + beta2: 0.999 + eps: 1.0e-08 + min_lr_ratio: 0.0 + lr_scheduler_type: constant + warmup_steps_proportion: 0.001 + initial_loss_scale: 4294967296.0 + min_loss_scale: 1.0 + loss_scale_window: 5.0 + hysteresis: 2 + gradient_clipping: 1.0 + backend: + type: fsdp + ref: null + mb_spec: + max_tokens_per_mb: 10240 + # Algorithm + group_adv_norm: False + ppo_n_minibatches: 4 + eps_clip: 0.2 + c_clip: null + reward_scaling: 10.0 + reward_bias: -0.5 + max_reward_clip: 20.0 + mask_no_eos_with_zero: false + discount: 1.0 + gae_lambda: 1.0 + adv_norm: true + kl_ctl: 0.0 + recompute_logprob: true + use_decoupled_loss: true + behav_imp_weight_cap: null + diff --git a/arealite/config/sft.yaml b/arealite/config/sft.yaml new file mode 100644 index 000000000..2fd531e10 --- /dev/null +++ b/arealite/config/sft.yaml @@ -0,0 +1,104 @@ +# Basic experiment info +experiment_name: test-sft +trial_name: test-trial +seed: 1 +mode: ray +n_nodes: 1 +n_gpus_per_node: 8 + +wandb: + mode: disabled + entity: null + project: null + name: null + job_type: null + group: null + notes: null + tags: null + config: null + +tensorboard: + path: null + +exp_ctrl: + total_train_epochs: 2 + save_freq_epochs: 1 + save_freq_steps: null + save_freq_secs: null + ckpt_freq_epochs: null + ckpt_freq_steps: null + ckpt_freq_secs: 600 + eval_freq_epochs: 1 + eval_freq_steps: null + eval_freq_secs: null + benchmark_steps: null + benchmark_n_seqs: null + +ray_temp_path: /tmp/ray +cluster: + cluster_name: local + fileroot: /tmp/arealite/ + n_nodes: 32 + n_gpus_per_node: 8 + name_resolve: + type: nfs + nfs_record_root: /tmp/arealite/nfs_record_root/ + +train_dataset: + path: openai/gsm8k + preprocessor: + type: gsm8k_sft + name: main + split: train + data_files: null + batch_size: 128 + shuffle: true + pin_memory: true + num_workers: 4 + +valid_dataset: + path: openai/gsm8k + preprocessor: + type: gsm8k_sft + name: main + split: test + data_files: null + batch_size: 128 + shuffle: true + pin_memory: true + num_workers: 4 + +trainer: + type: sft + sft: + model: + path: /storage/openpsi/models/Qwen__Qwen3-1.7B/ + init_from_scratch: false + gradient_checkpointing: false + bf16: false + optimizer: + type: adam + lr: 2.0e-05 + weight_decay: 0.05 + beta1: 0.9 + beta2: 0.95 + eps: 1.0e-05 + min_lr_ratio: 0.0 + lr_scheduler_type: constant + warmup_steps_proportion: 0.001 + initial_loss_scale: 4294967296.0 + min_loss_scale: 1.0 + loss_scale_window: 5.0 + hysteresis: 2 + gradient_clipping: 1.0 + backend: + type: fsdp + fsdp: + wrap_policy: + transformer_layer_cls_to_wrap: null + offload_params: false + mb_spec: + n_mbs: 1 + max_tokens_per_mb: 10240 + +rollout: null \ No newline at end of file diff --git a/arealite/impl/agentic/math_code_single_step.py b/arealite/impl/agentic/math_code_single_step.py new file mode 100644 index 000000000..2be4c815a --- /dev/null +++ b/arealite/impl/agentic/math_code_single_step.py @@ -0,0 +1,232 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import os +import re +import uuid +from dataclasses import dataclass +from datetime import datetime +from functools import lru_cache +from typing import Any, List, Optional, Tuple + +import torch + +from arealite.api.cli_args import GenerationHyperparameters, TrainingArgs +from arealite.api.io_struct import ( + AgentInferInput, + AgentInferOutput, + LLMRequest, + Trajectory, + TrajStats, +) +from arealite.api.llm_client_api import LLMClient +from arealite.api.rollout_api import Agent, Environment, RolloutCollector +from arealite.utils import pad_sequences_to_tensors +from functioncall.code.local_verify import code_verify as local_code_verify +from functioncall.code.verify import code_verify +from functioncall.math.verify import math_verify +from realhf.impl.dataset.math_code_dataset import load_metadata +from realhf.impl.dataset.math_parser import parse_lines_in_parallel + +ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False +math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel +code_verify_call = code_verify if ENABLE_FUNCTION_CALL else local_code_verify + + +@lru_cache(maxsize=128) +def _load_metadata_cached(dataset_path: str): + """Cached version of load_metadata to avoid reloading metadata each time.""" + return load_metadata(dataset_path) + + +def extract_code(text, min_length=20): + """Extract code blocks from text.""" + code_pattern = r"(?i)```(?:python|py|cpp|CPP)?\s*\n?(.*?)\n?```" + code_blocks = re.findall(code_pattern, text, re.DOTALL) + valid_blocks = [] + for block in code_blocks: + clean_block = block.strip() + if len(clean_block) < min_length: + continue + valid_blocks.append(clean_block) + + if not valid_blocks: + return None + # return the last code block + return valid_blocks[-1] + + +@dataclass +class MathCodeAction: + query_id: str + answer: str + + +@dataclass +class MathCodeObs: + query_id: str + prompt_ids: List[int] + + +class MathCodeSingleStepEnv(Environment): + """Math and Code single-step verification environment.""" + + def __init__(self, args: TrainingArgs, solution_path: str): + super().__init__(args) + self.id2info, _ = _load_metadata_cached(solution_path) + + def reset( + self, seed: Optional[int] = None, options: Optional[dict] = None + ) -> Tuple[Any, dict]: + """Reset the environment.""" + super().reset(seed=seed) + try: + prompt_ids = options["input_ids"] + query_id = options["query_id"] + except KeyError: + raise RuntimeError("`input_ids` and `query_id` must be set in env options.") + # Return dummy observation and info + return MathCodeObs(query_id=query_id, prompt_ids=prompt_ids), {} + + def step( + self, action: MathCodeAction + ) -> Tuple[MathCodeObs, float, bool, bool, dict]: + """Execute one step in the environment.""" + query_id = action.query_id + answer = action.answer + + query_id = query_id.split("@")[0] + cur_task = self.id2info[query_id]["task"] + + if cur_task == "math": + # Run math verification + format_reward = math_verify_call(self.id2info, [answer], [query_id])[0] + elif cur_task == "code": + # Extract code blocks and run code verification + extracted_answer = extract_code(answer) + format_reward = code_verify_call( + self.id2info, [extracted_answer], [query_id] + )[0] + else: + raise NotImplementedError(f"Task type '{cur_task}' not implemented") + + # Return: observation, reward, terminated, truncated, info + terminated = True # Single step environment always terminates + truncated = False + info = {"task": cur_task, "query_id": query_id} + + return ( + None, + format_reward, + terminated, + truncated, + info, + ) + + +class MathCodeAgent(Agent): + + async def aact(self, inp: AgentInferInput) -> AgentInferOutput: + """Async version of act. Given an observation, return an action.""" + # Extract information from observation + obs: MathCodeObs = inp.obs + query_id = obs.query_id + prompt_ids = obs.prompt_ids + + # Create LLM request + llm_req = LLMRequest( + rid=str(query_id) + "-" + str(uuid.uuid4()), + input_ids=prompt_ids, + gconfig=inp.gconfig, + ) + + # Generate response using async LLM client + llm_resp = await inp.llm_client.agenerate(llm_req) + + # Extract answers from completion + answer = llm_resp.completion + + return AgentInferOutput( + action=MathCodeAction(query_id=query_id, answer=answer), + llm_req=llm_req, + llm_resp=llm_resp, + ) + + def reset(self): + """Resets the agent's memory.""" + pass # Stateless agent, no memory to reset + + async def areset(self): + """Async version of reset. Resets the agent's memory.""" + pass # Stateless agent, no memory to reset + + +class MathCodeSingleStepCollector(RolloutCollector): + + async def arun_episode( + self, + llm_client: LLMClient, + gconfig: GenerationHyperparameters, + env_option: Optional[Any] = None, + seed: Optional[int] = None, + ) -> Trajectory: + """Async version of run_episode. Run a single episode and return the trajectory.""" + # Reset the environment and the agent's memory. + obs, _ = self.env.reset(options=env_option, seed=seed) + await self.agent.areset() + + data = [] + rewards = [] + tik = datetime.now().timestamp() + ret = 0.0 + ep_len = 0 + + done = False + # Episode loop. + while not done: + # Take an action by sending a request to generation server. + agent_infer_in = AgentInferInput( + obs=obs, gconfig=gconfig, llm_client=llm_client + ) + agent_infer_out = await self.agent.aact(agent_infer_in) + action = agent_infer_out.action + + # Advance one step in the environment. + nex_obs, reward, terminated, truncated, _ = self.env.step(action) + + # Collect the step data. + resp = agent_infer_out.llm_resp + input_len = len(resp.input_tokens) + output_len = len(resp.output_tokens) + + input_ids = resp.input_tokens + resp.output_tokens + prompt_mask = [1] * input_len + [0] * output_len + logprobs = [0.0] * input_len + resp.output_logprobs + versions = [-1] * input_len + resp.output_versions + + d = dict( + input_ids=torch.tensor(input_ids, dtype=torch.long), + prompt_mask=torch.tensor(prompt_mask, dtype=torch.bool), + logprobs=torch.tensor(logprobs, dtype=torch.float32), + versions=torch.tensor(versions, dtype=torch.long), + ) + data.append(d) + rewards.append(reward) + + ret += float(reward) + ep_len += 1 + + # Prepare information for the next step. + done = terminated or truncated + obs = nex_obs + + return Trajectory( + prompt=env_option, + data=dict(rewards=torch.tensor(rewards), **pad_sequences_to_tensors(data)), + stats=TrajStats( + start_time=tik, + total_reward=ret, + episode_length=ep_len, + info={}, + ), + ) diff --git a/arealite/impl/dataset/areal.py b/arealite/impl/dataset/areal.py new file mode 100644 index 000000000..543a7848a --- /dev/null +++ b/arealite/impl/dataset/areal.py @@ -0,0 +1,7 @@ +from datasets import Dataset + + +def process_areal_dataset(dataset: Dataset, tokenizer): + return dataset.map( + lambda x: tokenizer(x["prompt"], return_attention_mask=False), batched=True + ) diff --git a/arealite/impl/dataset/gsm8k.py b/arealite/impl/dataset/gsm8k.py new file mode 100644 index 000000000..0c69c156f --- /dev/null +++ b/arealite/impl/dataset/gsm8k.py @@ -0,0 +1,46 @@ +from datasets import Dataset + + +def process_gsm8k_rl_dataset(dataset: Dataset, tokenizer, reward_mode): + def process_example(example, idx): + # Add query_id column + example["query_id"] = str(idx) + example["prompt"] = example["question"] + + # used by the reward function + example["method"] = reward_mode + return example + + dataset = dataset.map( + lambda example, idx: process_example(example, idx), + with_indices=True, + ) + return dataset.map( + lambda x: tokenizer(x["question"], return_attention_mask=False), batched=True + ) + + +def process_gsm8k_sft_dataset(dataset: Dataset, tokenizer): + def process_example(example, idx): + # Add query_id column + example["query_id"] = str(idx) + example["prompt"] = example["question"] + example["seq"] = example["prompt"] + example["answer"] + tokenizer.eos_token + return example + + dataset = dataset.map( + lambda example, idx: process_example(example, idx), + with_indices=True, + ) + + def _tokenize(example): + example["prompt"] = tokenizer(example["prompt"], return_attention_mask=False)[ + "input_ids" + ] + example["seq"] = tokenizer(example["seq"], return_attention_mask=False)[ + "input_ids" + ] + return example + + dataset = dataset.map(lambda x: _tokenize(x), batched=True) + return dataset diff --git a/arealite/impl/engine/fsdp_wrapper.py b/arealite/impl/engine/fsdp_wrapper.py new file mode 100644 index 000000000..557e5312e --- /dev/null +++ b/arealite/impl/engine/fsdp_wrapper.py @@ -0,0 +1,553 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import asyncio +import functools +import math +import os +from typing import Any, Callable, Dict, List, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import transformers +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import StateDictType +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + PreTrainedModel, + get_constant_schedule_with_warmup, + get_linear_schedule_with_warmup, +) + +from arealite.api.cli_args import EngineConfig, FSDPConfig, MicroBatchSpec, TrainingArgs +from arealite.api.engine_api import SPMDWrapper +from arealite.api.io_struct import FinetuneSpec +from arealite.api.llm_client_api import LLMClient +from arealite.utils import ( + get_state_dict_from_repo_id_or_path, + recorder_list, + split_dict_tensor_with_cu_seqlens, + unpack_sequence, +) +from realhf.api.cli_args import ParallelismConfig +from realhf.base import constants +from realhf.base.pkg_version import is_version_greater_or_equal + +if is_version_greater_or_equal("torch", "2.6.0"): + from torch.distributed.fsdp import ( + CPUOffloadPolicy, + FSDPModule, + MixedPrecisionPolicy, + fully_shard, + ) +elif is_version_greater_or_equal("torch", "2.4.0"): + from torch.distributed._composable.fsdp import ( + CPUOffloadPolicy, + FSDPModule, + MixedPrecisionPolicy, + fully_shard, + ) +else: + fully_shard, MixedPrecisionPolicy, FSDPModule, CPUOffloadPolicy = ( + None, + None, + None, + None, + ) + +from torch.distributed.device_mesh import init_device_mesh + + +def fsdp2_clip_grad_norm_( + parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None +): + """torch.nn.utils.clip_grad_norm_ cann't run on cpu parameter DTensor""" + from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + else: + # prevent generators from being exhausted + parameters = list(parameters) + grads = [p.grad for p in parameters if p.grad is not None] + total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach) + total_norm = total_norm.to(torch.cuda.current_device(), non_blocking=True) + _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) + return total_norm + + +def create_fsdp_device_mesh(shard_size, world_size): + if shard_size < 0 or shard_size >= world_size: + device_mesh = init_device_mesh( + "cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",) + ) + else: + device_mesh = init_device_mesh( + "cuda", + mesh_shape=(world_size // shard_size, shard_size), + mesh_dim_names=("ddp", "fsdp"), + ) + return device_mesh + + +def apply_fsdp2(model, fsdp_kwargs, wrap_policy): + """model: AutoModelForCausalLM""" + assert ( + CPUOffloadPolicy is not None + ), "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)" + + default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", list()) + fsdp_transformer_layer_cls_to_wrap = ( + wrap_policy.transformer_layer_cls_to_wrap if wrap_policy is not None else list() + ) + if not fsdp_transformer_layer_cls_to_wrap: + fsdp_transformer_layer_cls_to_wrap = default_transformer_cls_names_to_wrap + + if isinstance(fsdp_transformer_layer_cls_to_wrap, str): + fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap] + + assert ( + len(fsdp_transformer_layer_cls_to_wrap) > 0 + and fsdp_transformer_layer_cls_to_wrap[0] is not None + ) + + modules = [] + for name, module in model.named_modules(): + if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or ( + isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings + ): + modules.append(module) + + for idx, module in enumerate(modules): + fully_shard(module, **fsdp_kwargs) + fully_shard( + model, **fsdp_kwargs + ) # fsdp2 will not reshard_after_forward for root module + + +def fsdp2_load_full_state_dict( + model: PreTrainedModel, + full_state: dict, + cpu_offload=None, + tie_word_embeddings=False, +): + """ + Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the + parameters from rank 0 to all other ranks. This function modifies the model in-place. + + Args: + model (`torch.nn.Module`): The model to load the state dict into + full_state (`dict`): The full state dict to load, can only be on rank 0 + """ + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + set_model_state_dict, + ) + + device = torch.cuda.current_device() + model = model.to(device=device, non_blocking=True) + cpu_offload = cpu_offload is not None + options = StateDictOptions( + full_state_dict=True, + cpu_offload=cpu_offload, + broadcast_from_rank0=True, + strict=not tie_word_embeddings, + ) + set_model_state_dict(model, full_state, options=options) + + if tie_word_embeddings: + model.tie_weights() + + # rotary_emb is not in state_dict, so we need to broadcast it manually + for name, buf in model.named_buffers(): + dist.broadcast(buf, src=0) + + if cpu_offload: + model.to("cpu", non_blocking=True) + for buf in model.buffers(): + buf.data = buf.data.to(device) + + +def get_cosine_schedule_with_warmup( + optimizer: torch.optim.Optimizer, + num_warmup_steps: int, + num_training_steps: int, + min_lr_ratio: float = 0.0, + num_cycles: float = 0.5, + last_epoch: int = -1, +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + Args: + optimizer (:class:`~torch.optim.Optimizer`): + The optimizer for which to schedule the learning rate. + num_warmup_steps (:obj:`int`): + The number of steps for the warmup phase. + num_training_steps (:obj:`int`): + The total number of training steps. + min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0): + The minimum lr ratio w.r.t the maximum. + num_cycles (:obj:`float`, `optional`, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (:obj:`int`, `optional`, defaults to -1): + The index of the last epoch when resuming training. + Return: + :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0 + coef = (1 - min_lr_ratio) * 0.5 + intercept = (1 + min_lr_ratio) * 0.5 + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return min_lr_ratio + (1.0 - min_lr_ratio) * ( + float(current_step) / float(max(1, num_warmup_steps)) + ) + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + x = math.cos(math.pi * float(num_cycles) * 2.0 * progress) + return max(min_lr_ratio, x * coef + intercept) + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch) + + +class FSDPEngine(SPMDWrapper): + """Simplified FSDP engine for transformer models.""" + + def __init__(self, args: TrainingArgs, engine_config: EngineConfig): + super().__init__(args, engine_config) + assert is_version_greater_or_equal( + "torch", "2.4.0" + ), f"arealite only supports FSDP2, which requires torch>=2.4.0" + + self.fsdp_config = engine_config.backend.fsdp + if self.fsdp_config is None: + self.fsdp_config = FSDPConfig() + self.optimizer_config = engine_config.optimizer + + self.model = None + self.optimizer = None + self.model_config = None + self.device_mesh = None + self.cpu_offload = None + + self.world_size = int(os.environ["WORLD_SIZE"]) + + def train(self, mode: bool = True): + """Set the module in training mode.""" + assert self.model is not None + self.model.train(mode=mode) + return self + + def init_distributed(self, config: ParallelismConfig, ft_spec: FinetuneSpec): + """Initialize distributed communication and model.""" + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + dtype = torch.bfloat16 if self.engine_config.bf16 else torch.float16 + self.model_config = AutoConfig.from_pretrained( + pretrained_model_name_or_path=self.engine_config.path, + trust_remote_code=True, + ) + with torch.device("cuda"): + # initialize scratch model from config + model = AutoModelForCausalLM.from_config( + self.model_config, + torch_dtype=dtype, + attn_implementation="flash_attention_2", + ) + + # Simple auto wrap policy + # TODO: fix wrap policy + mixed_precision_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + cast_forward_inputs=True, + ) + device_mesh = create_fsdp_device_mesh(self.world_size, self.world_size) + self.device_mesh = device_mesh + # sharding_strategy = ShardingStrategy.FULL_SHARD + self.cpu_offload = ( + CPUOffloadPolicy() if self.fsdp_config.offload_params else None + ) + + fsdp_kwargs = { + "mesh": device_mesh, + "mp_policy": mixed_precision_policy, + "offload_policy": self.cpu_offload, + "reshard_after_forward": True, + } + + # Wrap with FSDP2 + apply_fsdp2(model, fsdp_kwargs, self.fsdp_config.wrap_policy) + + self.model = model + + # Set up optimizer + if self.optimizer_config is not None: + assert ( + self.optimizer_config.type == "adam" + ), "Only AdamW optimizer is supported in this engine." + lr = self.optimizer_config.lr + weight_decay = self.optimizer_config.weight_decay + beta1 = self.optimizer_config.beta1 + beta2 = self.optimizer_config.beta2 + eps = self.optimizer_config.eps + + self.optimizer = torch.optim.AdamW( + self.model.parameters(), + lr=lr, + weight_decay=weight_decay, + betas=(beta1, beta2), + eps=eps, + ) + total_train_steps = ft_spec.total_train_steps + num_warmup_steps = int( + self.optimizer_config.warmup_steps_proportion * total_train_steps + ) + + if self.optimizer_config.lr_scheduler_type == "cosine": + self.lr_scheduler = get_cosine_schedule_with_warmup( + self.optimizer, + num_warmup_steps, + total_train_steps, + min_lr_ratio=self.optimizer_config.min_lr_ratio, + ) + elif self.optimizer_config.lr_scheduler_type == "linear": + self.lr_scheduler = get_linear_schedule_with_warmup( + self.optimizer, + num_warmup_steps, + total_train_steps, + ) + elif self.optimizer_config.lr_scheduler_type == "constant": + self.lr_scheduler = get_constant_schedule_with_warmup( + self.optimizer, + num_warmup_steps, + ) + else: + raise ValueError( + f"Unknown lr scheduler type {self.optimizer_config.lr_scheduler_type}" + ) + + def train(self, mode: bool = True): + self.model.train(mode) + return self + + def train_batch( + self, + input_: Dict, + mb_spec: MicroBatchSpec, + loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor], + loss_weight_fn: Callable[[Dict], float], + ) -> Dict: + """Train on a batch using gradient accumulation.""" + # self._initialize_fsdp_train() + assert self.optimizer is not None + assert self.optimizer_config is not None + assert self.lr_scheduler is not None + + self.optimizer.zero_grad() + + mb_inputs = split_dict_tensor_with_cu_seqlens(input_, mb_spec).mbs + + total_loss_weight = torch.tensor( + sum([loss_weight_fn(mb) for mb in mb_inputs]), dtype=torch.float32 + ) + assert total_loss_weight != 0 + dist.all_reduce(total_loss_weight) + + # Process microbatches with gradient accumulation + for i, mb_input in enumerate(mb_inputs): + outputs = self.model(**mb_input) + + loss = loss_fn(outputs.logits, mb_input) + loss_scale = loss_weight_fn(mb_input) / total_loss_weight + + # Scale loss for accumulation + # Revert gradient averaging across dp ranks + loss_scale *= self.world_size + + loss *= loss_scale + loss.backward() + + grad_norm = fsdp2_clip_grad_norm_( + self.model.parameters(), max_norm=self.optimizer_config.gradient_clipping + ) + if not torch.isfinite(grad_norm): + self.optimizer.zero_grad() + update_successful = False + else: + self.optimizer.step() + update_successful = True + + current_lr = self.lr_scheduler.get_last_lr()[0] + # Optimizer step + self.optimizer.step() + return dict( + update_successful=float(update_successful), + grad_norm=float(grad_norm) if grad_norm is not None else float("nan"), + lr=current_lr, + ) + + def step_lr_scheduler(self): + assert self.lr_scheduler is not None + self.lr_scheduler.step() + + @torch.no_grad() + def eval_batch( + self, + input_: Dict, + mb_spec: MicroBatchSpec, + loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor], + loss_weight_fn: Callable[[Dict], float], + ) -> torch.Tensor | None: + """Evaluate on a batch.""" + mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec) + total_loss_weight = torch.tensor( + sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32 + ) + assert total_loss_weight != 0 + + total_loss = 0.0 + total_weight = 0.0 + + for mb_input in mb_splits.mbs: + outputs = self.model(**mb_input) + loss = loss_fn(outputs.logits, mb_input) + + # Simple weight calculation (could be improved) + loss_scale = loss_weight_fn(mb_input) / total_loss_weight + total_loss += loss.item() * loss_scale + total_weight += loss_scale + + return torch.tensor(total_loss / total_weight) + + @torch.no_grad() + def forward( + self, + input_: Dict, + mb_spec: MicroBatchSpec, + output_seqlens: List[int] | None = None, + post_hook: Callable[[torch.Tensor, Dict], Any] | None = None, + aggregate_fn: Callable[[List[Any]], Any] = functools.partial(torch.cat, dim=1), + ) -> Any | None: + """Forward pass with optional post-processing.""" + mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec) + if output_seqlens is None: + cu_seqlens = input_["cu_seqlens"] + output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist() + + results = [] + for mb_input in mb_splits.mbs: + outputs = self.model(**mb_input) + if post_hook: + result = post_hook(outputs.logits, mb_input) + results.append(result) + else: + results.append(outputs.logits) + + res = aggregate_fn(results) + output_seqlens = [output_seqlens[i] for i in mb_splits.forward_indices] + unpacked = unpack_sequence(res, lens=output_seqlens, dim=1) + return aggregate_fn(recorder_list(unpacked, mb_splits.backward_indices)) + + def get_hf_model_state_dict(self) -> Dict[str, torch.Tensor]: + """Get model state dict for saving.""" + if self.model is None: + raise RuntimeError("Model not initialized") + + with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT): + return self.model.state_dict() + + def save_model_to_hf( + self, + path: str, + tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None, + base_model_path: Optional[str] = None, + ): + """Save model in HuggingFace format.""" + if self.model is None: + raise RuntimeError("Model not initialized") + + os.makedirs(path, exist_ok=True) + + # FSDP2 checkpoint saving + from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, + get_model_state_dict, + ) + + # Get full state dict with FSDP2 + options = StateDictOptions(full_state_dict=True, cpu_offload=True) + state_dict = get_model_state_dict(self.model, options=options) + + # save huggingface model + if dist.get_rank() == 0: + os.makedirs(path, exist_ok=True) + self.model.save_pretrained(path, state_dict=state_dict) + self.model_config.save_pretrained(path) + if tokenizer is not None: + tokenizer.save_pretrained(path) + + dist.barrier() + + def load_model_from_hf(self, path: str): + """Load model from HuggingFace format.""" + if dist.get_rank() == 0: + full_state = get_state_dict_from_repo_id_or_path(path) + else: + full_state = {} + + fsdp2_load_full_state_dict( + self.model, + full_state, + self.cpu_offload, + tie_word_embeddings=self.model_config.tie_word_embeddings, + ) + + def save_optimizer_state(self, path: str): + """Save optimizer state.""" + if self.optimizer is None: + raise RuntimeError("Optimizer not initialized") + + os.makedirs(path, exist_ok=True) + torch.save(self.optimizer.state_dict(), os.path.join(path, "optimizer.pt")) + + def load_optimizer_state(self, path: str): + """Load optimizer state.""" + if self.optimizer is None: + raise RuntimeError("Optimizer not initialized") + + optimizer_path = os.path.join(path, "optimizer.pt") + if os.path.exists(optimizer_path): + self.optimizer.load_state_dict( + torch.load(optimizer_path, map_location="cpu") + ) + else: + raise RuntimeError(f"Optimizer state file not found: {optimizer_path}") + + async def aupdate_weights_to(self, llm_client: LLMClient): + """Async method to update weights to all healthy servers.""" + path = constants.get_param_realloc_path(self.args) + self.save_model_to_hf(path) + tasks = [ + llm_client.aupdate_weights_from_disk(server_info=server_info, path=path) + for server_info in llm_client.get_healthy_servers() + ] + await asyncio.gather(*tasks) + + def update_weights_to(self, llm_client: LLMClient): + """Update the weights to the server by sending requests to the client.""" + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(self.aupdate_weights_to(llm_client)) + finally: + loop.close() diff --git a/arealite/impl/engine/hf_wrapper.py b/arealite/impl/engine/hf_wrapper.py new file mode 100644 index 000000000..135f21176 --- /dev/null +++ b/arealite/impl/engine/hf_wrapper.py @@ -0,0 +1,315 @@ +import asyncio +import functools +import math +import os +from typing import Any, Callable, Dict, List, Optional + +import torch +import torch.distributed as dist +import transformers +from transformers import AutoConfig, AutoModelForCausalLM + +from arealite.api.cli_args import ( + EngineConfig, + MicroBatchSpec, + ParallelismConfig, + TrainingArgs, +) +from arealite.api.engine_api import SPMDWrapper +from arealite.api.io_struct import FinetuneSpec +from arealite.api.llm_client_api import LLMClient +from arealite.utils import ( + get_state_dict_from_repo_id_or_path, + recorder_list, + split_dict_tensor_with_cu_seqlens, + unpack_sequence, +) +from realhf.base import constants + + +def get_cosine_schedule_with_warmup( + optimizer: torch.optim.Optimizer, + num_warmup_steps: int, + num_training_steps: int, + min_lr_ratio: float = 0.0, + num_cycles: float = 0.5, + last_epoch: int = -1, +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. + Args: + optimizer (:class:`~torch.optim.Optimizer`): + The optimizer for which to schedule the learning rate. + num_warmup_steps (:obj:`int`): + The number of steps for the warmup phase. + num_training_steps (:obj:`int`): + The total number of training steps. + min_lr_ratio (:obj:`float`, `optional`, defaults to 0.0): + The minimum lr ratio w.r.t the maximum. + num_cycles (:obj:`float`, `optional`, defaults to 0.5): + The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (:obj:`int`, `optional`, defaults to -1): + The index of the last epoch when resuming training. + Return: + :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + assert min_lr_ratio >= 0 and min_lr_ratio <= 1.0 + coef = (1 - min_lr_ratio) * 0.5 + intercept = (1 + min_lr_ratio) * 0.5 + + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + x = math.cos(math.pi * float(num_cycles) * 2.0 * progress) + return max(0.0, x * coef + intercept) + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch) + + +class HFEngine(SPMDWrapper): + """Simplified HF engine for transformer models.""" + + def __init__(self, args: TrainingArgs, engine_config: EngineConfig): + super().__init__(args, engine_config) + + self.model = None + self.optimizer = None + self.model_config = None + + self.weight_update_group_initialized = False + + def init_distributed(self, config: ParallelismConfig, ft_spec: FinetuneSpec): + """Initialize model in single node.""" + if not dist.is_initialized(): + dist.init_process_group(backend="nccl") + if dist.get_world_size() > 1: + raise RuntimeError( + "Distributed training is not supported in this engine. " + "Please use FSDP for distributed training." + ) + torch.cuda.set_device("cuda:0") + + dtype = torch.bfloat16 if self.engine_config.bf16 else torch.float16 + self.model_config = AutoConfig.from_pretrained( + pretrained_model_name_or_path=self.engine_config.path, + trust_remote_code=True, + ) + with torch.device("cuda"): + # initialize scratch model from config + model = AutoModelForCausalLM.from_config( + self.model_config, + torch_dtype=dtype, + attn_implementation="flash_attention_2", + ) + + model = model.cuda() + + self.model = model + + # Set up optimizer + optimizer_config = self.engine_config.optimizer + if optimizer_config is not None: + assert ( + optimizer_config.type == "adam" + ), "Only AdamW optimizer is supported in this engine." + lr = optimizer_config.lr + weight_decay = optimizer_config.weight_decay + beta1 = optimizer_config.beta1 + beta2 = optimizer_config.beta2 + eps = optimizer_config.eps + + self.optimizer = torch.optim.AdamW( + self.model.parameters(), + lr=lr, + weight_decay=weight_decay, + betas=(beta1, beta2), + eps=eps, + ) + total_train_steps = ft_spec.total_train_steps + num_warmup_steps = int( + optimizer_config.warmup_steps_proportion * total_train_steps + ) + + self.lr_scheduler = get_cosine_schedule_with_warmup( + self.optimizer, + num_warmup_steps, + total_train_steps, + min_lr_ratio=optimizer_config.min_lr_ratio, + ) + + def train(self, mode: bool = True): + """Set the module in training mode.""" + return self.model.train(mode) + + def train_batch( + self, + input_: Dict, + mb_spec: MicroBatchSpec, + loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor], + loss_weight_fn: Callable[[Dict], float], + ) -> Dict: + """Train on a batch using gradient accumulation.""" + assert self.optimizer is not None + assert self.lr_scheduler is not None + + self.optimizer.zero_grad() + mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec) + total_loss_weight = torch.tensor( + sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32 + ) + assert total_loss_weight != 0 + + for mb_input in mb_splits.mbs: + outputs = self.model(**mb_input) + loss = loss_fn(outputs.logits, mb_input) + loss_scale = loss_weight_fn(mb_input) / total_loss_weight + loss *= loss_scale + loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), + self.engine_config.optimizer.gradient_clipping, + norm_type=2.0, + error_if_nonfinite=False, + foreach=None, + ) + current_lr = self.lr_scheduler.get_last_lr()[0] + # Optimizer step + self.optimizer.step() + + return { + "grad_norm": grad_norm, + "lr": current_lr, + } + + @torch.no_grad() + def eval_batch( + self, + input_: Dict, + mb_spec: MicroBatchSpec, + loss_fn: Callable[[torch.Tensor, Dict], torch.Tensor], + loss_weight_fn: Callable[[Dict], float], + ) -> torch.Tensor | None: + """Evaluate on a batch.""" + mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec) + total_loss_weight = torch.tensor( + sum([loss_weight_fn(mb) for mb in mb_splits.mbs]), dtype=torch.float32 + ) + assert total_loss_weight != 0 + + total_loss = 0.0 + total_weight = 0.0 + + for mb_input in mb_splits.mbs: + outputs = self.model(**mb_input) + loss = loss_fn(outputs.logits, mb_input) + + # Simple weight calculation (could be improved) + loss_scale = loss_weight_fn(mb_input) / total_loss_weight + total_loss += loss.item() * loss_scale + total_weight += loss_scale + + return torch.tensor(total_loss / total_weight) + + @torch.no_grad() + def forward( + self, + input_: Dict, + mb_spec: MicroBatchSpec, + output_seqlens: List[int] | None = None, + post_hook: Callable[[torch.Tensor, Dict], Any] | None = None, + aggregate_fn: Callable[[List[Any]], Any] = functools.partial(torch.cat, dim=1), + ) -> Any | None: + """Forward pass with optional post-processing.""" + mb_splits = split_dict_tensor_with_cu_seqlens(input_, mb_spec) + if output_seqlens is None: + cu_seqlens = input_["cu_seqlens"] + output_seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist() + + results = [] + for mb_input in mb_splits.mbs: + outputs = self.model(**mb_input) + if post_hook: + result = post_hook(outputs.logits, mb_input) + results.append(result) + else: + results.append(outputs.logits) + + res = aggregate_fn(results) + output_seqlens = [output_seqlens[i] for i in mb_splits.forward_indices] + unpacked = unpack_sequence(res, lens=output_seqlens, dim=1) + return aggregate_fn(recorder_list(unpacked, mb_splits.backward_indices)) + + def step_lr_scheduler(self): + """Step the learning rate scheduler.""" + return self.lr_scheduler.step() + + def save_model_to_hf( + self, + path: str, + tokenizer: Optional[transformers.PreTrainedTokenizerFast] = None, + base_model_path: Optional[str] = None, + ): + """Save model in HuggingFace format.""" + if self.model is None: + raise RuntimeError("Model not initialized") + + os.makedirs(path, exist_ok=True) + + state_dict = {k: v.cpu() for k, v in self.model.state_dict().items()} + self.model.save_pretrained(path, state_dict=state_dict) + self.model_config.save_pretrained(path) + if tokenizer is not None: + tokenizer.save_pretrained(path) + + def load_model_from_hf(self, path: str): + """Load model from HuggingFace format.""" + full_state = get_state_dict_from_repo_id_or_path(path) + self.model.load_state_dict( + full_state, strict=not self.model_config.tie_word_embeddings + ) + if self.model_config.tie_word_embeddings: + self.model.tie_weights() + + def save_optimizer_state(self, path: str): + """Save optimizer state.""" + if self.optimizer is None: + raise RuntimeError("Optimizer not initialized") + + os.makedirs(path, exist_ok=True) + torch.save(self.optimizer.state_dict(), os.path.join(path, "optimizer.pt")) + + def load_optimizer_state(self, path: str): + """Load optimizer state.""" + if self.optimizer is None: + raise RuntimeError("Optimizer not initialized") + + optimizer_path = os.path.join(path, "optimizer.pt") + if os.path.exists(optimizer_path): + self.optimizer.load_state_dict( + torch.load(optimizer_path, map_location="cpu") + ) + else: + raise RuntimeError(f"Optimizer state file not found: {optimizer_path}") + + async def aupdate_weights_to(self, llm_client: LLMClient): + path = constants.get_param_realloc_path(self.args) + self.save_model_to_hf(path) + tasks = [ + llm_client.aupdate_weights_from_disk(server_info=server_info, path=path) + for server_info in llm_client.get_healthy_servers() + ] + await asyncio.gather(*tasks) + + def update_weights_to(self, llm_client: LLMClient): + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(self.aupdate_weights_to(llm_client)) + finally: + loop.close() diff --git a/arealite/impl/rlvr/rewards/areal_code.py b/arealite/impl/rlvr/rewards/areal_code.py new file mode 100644 index 000000000..52405e609 --- /dev/null +++ b/arealite/impl/rlvr/rewards/areal_code.py @@ -0,0 +1,47 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import re +from functools import lru_cache +from typing import List + +from functioncall.code.local_verify import code_verify +from realhf.impl.dataset.math_code_dataset import load_metadata + + +@lru_cache(maxsize=1) +def _load_metadata(dataset_path: str): + """Cached version of load_metadata to avoid reloading metadata each time.""" + return load_metadata(dataset_path) + + +def extract_code(text, min_length=20): + """Extract code blocks from text.""" + code_pattern = r"(?i)```(?:python|py|cpp|CPP)?\s*\n?(.*?)\n?```" + code_blocks = re.findall(code_pattern, text, re.DOTALL) + valid_blocks = [] + for block in code_blocks: + clean_block = block.strip() + if len(clean_block) < min_length: + continue + valid_blocks.append(clean_block) + + if not valid_blocks: + return None + # return the last code block + return valid_blocks[-1] + + +def code_reward( + query_id: str, + prompt: str, + completion: str, + prompt_ids: List[int], + completion_ids: List[int], + dataset_path: str, + **kwargs, +) -> float: + id2info, _ = _load_metadata(dataset_path) + return code_verify( + id2info=id2info, generateds=[extract_code(completion)], query_ids=[query_id] + )[0] diff --git a/arealite/impl/rlvr/rewards/areal_math.py b/arealite/impl/rlvr/rewards/areal_math.py new file mode 100644 index 000000000..e745c1398 --- /dev/null +++ b/arealite/impl/rlvr/rewards/areal_math.py @@ -0,0 +1,27 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +from functools import lru_cache +from typing import List + +from realhf.impl.dataset.math_code_dataset import load_metadata +from realhf.impl.dataset.math_parser import parse_line + + +@lru_cache(maxsize=1) +def _load_metadata(dataset_path: str): + """Cached version of load_metadata to avoid reloading metadata each time.""" + return load_metadata(dataset_path) + + +def math_reward( + query_id: str, + prompt: str, + completion: str, + prompt_ids: List[int], + completion_ids: List[int], + dataset_path: str, + **kwargs, +) -> float: + id2info, _ = _load_metadata(dataset_path) + return parse_line(id2info=id2info, generated=completion, query_id=query_id) diff --git a/arealite/impl/rlvr/rewards/gsm8k.py b/arealite/impl/rlvr/rewards/gsm8k.py new file mode 100644 index 000000000..5675b9d0e --- /dev/null +++ b/arealite/impl/rlvr/rewards/gsm8k.py @@ -0,0 +1,83 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +# Modified from verl. +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import List + + +def extract_solution(solution_str, method="strict"): + assert method in ["strict", "flexible"] + + if method == "strict": + # this also tests the formatting of the model + solutions = re.findall("#### (\\-?[0-9\\.\\,]+)", solution_str) + if len(solutions) == 0: + final_answer = None + else: + # take the last solution + final_answer = solutions[-1].replace(",", "").replace("$", "") + elif method == "flexible": + answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str) + final_answer = None + if len(answer) == 0: + # no reward is there is no answer + pass + else: + invalid_str = ["", "."] + # find the last number that is not '.' + for final_answer in reversed(answer): + if final_answer not in invalid_str: + break + return final_answer + + +def compute_score( + solution_str, ground_truth, method="strict", format_score=0.0, score=1.0 +): + """The scoring function for GSM8k. + + Reference: Trung, Luong, et al. "Reft: Reasoning with reinforced fine-tuning." Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2024. + + Args: + solution_str: the solution text + ground_truth: the ground truth + method: the method to extract the solution, choices are 'strict' and 'flexible' + format_score: the score for the format + score: the score for the correct answer + """ + answer = extract_solution(solution_str=solution_str, method=method) + if answer is None: + return 0 + else: + if answer == ground_truth: + return score + else: + return format_score + + +def gsm8k_reward_fn( + query_id: str, + prompt: str, + completion: str, + prompt_ids: List[int], + completion_ids: List[int], + answer: str, + method: str, + **kwargs, +) -> float: + return compute_score(completion, extract_solution(answer), method=method) diff --git a/arealite/impl/rlvr/rlvr_collector.py b/arealite/impl/rlvr/rlvr_collector.py new file mode 100644 index 000000000..26581c35b --- /dev/null +++ b/arealite/impl/rlvr/rlvr_collector.py @@ -0,0 +1,91 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +from datetime import datetime +from typing import Any, Callable, Dict, Optional + +import torch + +from arealite.api.cli_args import ( + GenerationHyperparameters, + RolloutCollectorConfig, + TrainingArgs, +) +from arealite.api.io_struct import LLMRequest, Trajectory, TrajStats +from arealite.api.llm_client_api import LLMClient +from arealite.api.rollout_api import RolloutCollector +from realhf.base import logging + +logger = logging.getLogger(__file__) + + +class RlvrCollector(RolloutCollector): + def __init__( + self, + args: TrainingArgs, + config: RolloutCollectorConfig, + reward_fn: Callable, + ): + super().__init__(args, config, None, None) + self.reward_fn = reward_fn + + async def arun_episode( + self, + llm_client: LLMClient, + gconfig: GenerationHyperparameters, + env_option: Optional[Dict[str, Any]] = None, + seed: Optional[int] = None, + ) -> Trajectory: + """Async version of run_episode. Run a single episode and return the trajectory.""" + tik = datetime.now().timestamp() + + prompt_ids = env_option["input_ids"] + query_id = env_option["query_id"] + req = LLMRequest(input_ids=prompt_ids, gconfig=gconfig) + + # Use async LLM client + resp = await llm_client.agenerate(req) + + # Run reward computation in executor to avoid blocking + reward_kwargs = env_option.copy() + reward_kwargs.pop("query_id") + reward_kwargs.pop("prompt") + reward = self.reward_fn( + query_id=query_id, + prompt=req.text, + completion=resp.completion, + prompt_ids=prompt_ids, + completion_ids=resp.output_tokens, + **reward_kwargs, + ) + + input_len = len(resp.input_tokens) + output_len = len(resp.output_tokens) + + input_ids = resp.input_tokens + resp.output_tokens + prompt_mask = [1] * input_len + [0] * output_len + logprobs = [0.0] * input_len + resp.output_logprobs + versions = [-1] * input_len + resp.output_versions + + # logger.info( + # f"Prompt: {req.text}, reward: {reward}\nCompletion: {resp.completion}" + # ) + + return Trajectory( + prompt=env_option, + data=dict( + # unsqueeze to add an additional batch dimension + input_ids=torch.tensor(input_ids).unsqueeze(0), + prompt_mask=torch.tensor(prompt_mask).unsqueeze(0), + logprobs=torch.tensor(logprobs).unsqueeze(0), + versions=torch.tensor(versions).unsqueeze(0), + # reward + rewards=torch.tensor([reward]), + ), + stats=TrajStats( + start_time=tik, + total_reward=reward, + episode_length=1, + info={}, + ), + ) diff --git a/arealite/impl/trainer/grpo.py b/arealite/impl/trainer/grpo.py new file mode 100644 index 000000000..7922fea4a --- /dev/null +++ b/arealite/impl/trainer/grpo.py @@ -0,0 +1,530 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import functools +import os +import time +from typing import Dict, List, Optional + +import torch +import torch.distributed as dist +from datasets import Dataset + +from arealite import ppo_functional +from arealite.api.cli_args import ( + GRPOTrainerConfig, + MicroBatchSpec, + TrainerConfig, + TrainingArgs, +) +from arealite.api.engine_api import EngineFactory +from arealite.api.io_struct import FinetuneSpec, Trajectory +from arealite.api.llm_client_api import LLMClientFactory +from arealite.api.trainer_api import Trainer +from arealite.system.rollout_controller import RolloutController +from arealite.utils import ( + calc_entropy, + close_wandb_tensorboard, + compute_varlen_position_indices, + concat_padded_tensors, + gather_logprobs, + init_stats_logging, + log_wandb_tensorboard, + masked_normalization, + record_timing, + split_dict_tensor_with_cu_seqlens, + to_device, + unpad_input, +) +from realhf.api.core.data_api import load_hf_tokenizer, tabulate_stats +from realhf.base import constants, logging, name_resolve, names, stats_tracker, timeutil + +logger = logging.getLogger("GRPO Trainer", "system") + + +class SpmdGRPOTrainer(Trainer): + + def __init__( + self, + args: TrainingArgs, + trainer_config: TrainerConfig, + train_dataset: Dataset, + valid_dataset: Optional[Dataset] = None, + rollout_controller: Optional[RolloutController] = None, + ): + super().__init__( + args, + trainer_config, + train_dataset, + valid_dataset, + rollout_controller, + ) + if self.rollout_controller is None: + raise ValueError("GRPO Trainer requires a rollout controller.") + + assert trainer_config.grpo is not None + self.config: GRPOTrainerConfig = trainer_config.grpo + assert args.rollout is not None + assert self.config.actor is not None + + # Create actor model + engine_factory = EngineFactory(args) + self.actor = engine_factory.make_engine(self.config.actor) + + self.actor_tokenizer = load_hf_tokenizer(self.config.actor.path) + self.gconfig = args.rollout.gconfig + + # Create reference model is specified + self.ref = None + if self.config.ref is not None: + self.ref = engine_factory.make_engine(self.config.ref) + + # Create a client to generate responses and update weights + client_factory = LLMClientFactory(args) + self.llm_client = client_factory.make_client(args.rollout.llm_client) + + # Algorithm related attributes + self.kl_ctl = self.config.kl_ctl + self.discount = self.config.discount + self.gae_lambda = self.config.gae_lambda + self.adv_norm = self.config.adv_norm + self.max_reward_clip = self.config.max_reward_clip + self.group_adv_norm = self.config.group_adv_norm + self.group_size = args.rollout.gconfig.n_samples + self.max_head_offpolicyness = args.rollout.max_head_offpolicyness + self.reward_bias = self.config.reward_bias + self.reward_scaling = self.config.reward_scaling + self.max_reward_clip = self.config.max_reward_clip + + self.save_ctl = timeutil.EpochStepTimeFreqCtl( + freq_epoch=self.args.exp_ctrl.save_freq_epochs, + freq_step=self.args.exp_ctrl.save_freq_steps, + freq_sec=self.args.exp_ctrl.save_freq_secs, + ) + self.eval_ctl = timeutil.EpochStepTimeFreqCtl( + freq_epoch=self.args.exp_ctrl.eval_freq_epochs, + freq_step=self.args.exp_ctrl.eval_freq_steps, + freq_sec=self.args.exp_ctrl.eval_freq_steps, + ) + self.summary_writer = init_stats_logging(args) + + def train(self, resume_from_checkpoint=None): + # TODO: handle recover + self.create_train_dataloader() + assert self.rollout_controller is not None + assert self.train_dataloader is not None + + total_epochs = self.args.exp_ctrl.total_train_epochs + steps_per_epoch = len(self.train_dataloader) + ft_spec = FinetuneSpec( + total_train_epochs=total_epochs, + dataset_size=len(self.train_dataset), + train_batch_size=self.args.train_dataset.batch_size, + ) + + # Setting up models. + self.actor.init_distributed(None, ft_spec) + self.actor.load_model_from_hf(self.config.actor.path) + self.actor.eval() + if self.ref is not None: + self.ref.init_distributed(None, ft_spec) + self.ref.load_model_from_hf(self.config.ref.path) + self.ref.eval() + self.llm_client.wait_until_servers_ready() + self.actor.update_weights_to(self.llm_client) + + # Start rollout for asynchronous RL. + if self.config.async_training: + self.rollout_controller.start_generate_loop() + + # Main RL training loop. + total_epochs = self.args.exp_ctrl.total_train_epochs + if dist.is_initialized(): + world_size = dist.get_world_size() + else: + world_size = 1 + global_step = 0 + warmup_steps = self.max_head_offpolicyness + 1 + assert steps_per_epoch >= warmup_steps + start_time = time.monotonic() + for epoch in range(total_epochs): + for step, data in enumerate(self.train_dataloader): + timing_stats = {} + with record_timing("timeperf/rollout", timing_stats): + if self.config.async_training: + self.rollout_controller.submit(data) + # Submitted data will not actually be sent for rollout. + # The rollout controller over-subscribe the data to + # ensure that there are enough data being generated. + if epoch == 0 and step < warmup_steps: + continue + # Wait until enough trajectories has been collected. + trajs = self.rollout_controller.prepare_batch( + batch_size=self.args.train_dataset.batch_size // world_size + ) + else: + # Run batched rollout by submitting requests to LLM servers + trajs = self.rollout_controller.generate_batch( + batch_size=len(data), + env_options=data, + ) + + with record_timing("timeperf/train_step", timing_stats): + # Run RL training and update weights. + mb_stats = self._train_step(trajs) + self.actor.step_lr_scheduler() + + with record_timing("timeperf/sync_weights", timing_stats): + # Synchronize weights to the client. + self.actor.update_weights_to(self.llm_client) + # Update model version + name = names.model_version( + self.args.experiment_name, self.args.trial_name, "actor" + ) + name_resolve.add(name, str(global_step + 1), replace=True) + + if self.save_ctl.check( + epochs=int(step == steps_per_epoch - 1), steps=1 + ): + if dist.get_rank() == 0: + logger.info("Saving model ...") + with record_timing("timeperf/save", timing_stats): + save_path = os.path.join( + constants.get_save_path(self.args), "actor" + ) + self.actor.save_model_to_hf( + save_path, + tokenizer=self.actor_tokenizer, + base_model_path=self.config.actor.path, + ) + + assert len(mb_stats) == self.config.ppo_n_minibatches + log_step = self.config.ppo_n_minibatches * global_step + for i, stats in enumerate(mb_stats): + log_wandb_tensorboard(log_step + i, stats, self.summary_writer) + log_wandb_tensorboard(log_step, timing_stats, self.summary_writer) + + if dist.get_rank() == 0: + logger.info( + f"Epoch {epoch+1}/{total_epochs} " + f"Step {step+1}/{steps_per_epoch} " + f"Train step {global_step + 1}/{total_epochs * steps_per_epoch - warmup_steps} done." + ) + logger.info( + f"Detailed time stats: \n{tabulate_stats(timing_stats, floatfmt='.2f')}" + ) + for i, stats in enumerate(mb_stats): + logger.info( + f"GRPO training stats ({i + 1}/{len(mb_stats)}):\n{tabulate_stats(stats)}" + ) + + global_step += 1 + + if dist.get_rank() == 0: + logger.info( + f"Training completes! Total time elapsed {time.monotonic() - start_time:.2f}." + ) + if self.config.async_training: + self.rollout_controller.stop_generate_loop() + + close_wandb_tensorboard(self.summary_writer) + + def _train_step(self, trajs: List[Trajectory]): + rollout = concat_padded_tensors([traj.data for traj in trajs]) + rollout = to_device(rollout, torch.cuda.current_device()) + + # Marks which sequence does not has an EOS token, i.e., + # generation is truncated by the configured maximum generation length + batch_tokens = rollout["input_ids"] + seq_no_eos_mask = ( + batch_tokens[:, -1] != self.actor_tokenizer.eos_token_id + ).logical_and(batch_tokens[:, -1] != self.actor_tokenizer.pad_token_id) + + # Remove padding to use flash-attn + attn_mask = rollout["attention_mask"] + input_ids, _, cu_seqlens, max_seqlen = unpad_input( + rollout["input_ids"], attn_mask + ) + position_ids = compute_varlen_position_indices(input_ids.shape[0], cu_seqlens) + + # Transformer forward input data + model_inputs = dict( + input_ids=input_ids.unsqueeze(0), + attention_mask=None, + position_ids=position_ids.unsqueeze(0), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + use_cache=False, + ) + old_logp, *_ = unpad_input(rollout["logprobs"], attn_mask) + prompt_mask, *_ = unpad_input(rollout["prompt_mask"], attn_mask) + # Shift logprobs and mask for computing loss. + loss_mask = prompt_mask.logical_not() + loss_mask = torch.roll(loss_mask, shifts=-1) + old_logp = torch.roll(old_logp, shifts=-1) + + input_ids = model_inputs["input_ids"].squeeze(0) + n_seqs = seq_no_eos_mask.shape[0] + assert n_seqs == self.local_train_batch_size * self.group_size, ( + n_seqs, + self.group_size, + self.local_train_batch_size, + ) + + # Run reference model forward + def calc_logprobs(logits, input_data): + logits = logits.squeeze(0).float() + labels = torch.roll(input_data["input_ids"].squeeze(0), shifts=-1) + logits /= self.gconfig.temperature + logprobs = gather_logprobs(logits, labels) + return logprobs.unsqueeze(0) + + if self.ref is not None and self.config.kl_ctl != 0.0: + ref_logp = self.ref.forward( + model_inputs, + mb_spec=self.config.mb_spec, + post_hook=calc_logprobs, + ).squeeze(0) + else: + ref_logp = torch.zeros_like(input_ids, dtype=torch.float32) + + # Recompute logprobs using the current actor model. + prox_logp = None + if self.config.recompute_logprob: + _logp = self.actor.forward( + model_inputs, + mb_spec=self.config.mb_spec, + post_hook=calc_logprobs, + ).squeeze(0) + if self.config.use_decoupled_loss: + prox_logp = _logp + else: + # Overwrite the logp returned by the inference engine + old_logp = _logp + + # Compute rewards using the reward function in synchronous RLVR pipeline. + reward_score = rollout["rewards"] + reward_score = (reward_score + self.reward_bias) * self.reward_scaling + reward_score = torch.clip(reward_score, max=self.max_reward_clip) + if self.config.group_reward_norm: + for i in range(n_seqs // self.group_size): + s = slice(i * self.group_size, (i + 1) * self.group_size) + r = reward_score[s] + reward_score[s] = (r - r.mean()) / (r.std() + 1e-9) + + # Apply the mask to log probabilities. + ref_logp *= loss_mask + old_logp *= loss_mask + + # Compute KL-regularized rewards and GAEs. + cu_seqlens = model_inputs["cu_seqlens"] + seq_no_eos_mask = seq_no_eos_mask + kl_rewards, rewards = ppo_functional.get_packed_rewards( + kl_ctl=self.kl_ctl, + clip_reward_value=self.max_reward_clip, + log_probs=old_logp, + ref_log_probs=ref_logp, + reward_score=reward_score, + cu_seqlens=cu_seqlens, + seq_no_eos_mask=seq_no_eos_mask, + mask_no_eos_with_zero=self.config.mask_no_eos_with_zero, + ) + advantages, _ = ppo_functional.get_packed_advantages_and_returns( + gamma=self.discount, + lam=self.gae_lambda, + values=torch.zeros( + input_ids.shape[0] + n_seqs, + device=input_ids.device, + dtype=torch.float32, + ), + rewards=rewards, + short1cu_seqlens=cu_seqlens, + seq_no_eos_mask=seq_no_eos_mask, + ) + + # Optionally perform advantage normalization. + if self.adv_norm: + if self.group_adv_norm: + n_samples = len(cu_seqlens) - 1 + assert n_samples % self.group_size == 0 + adv_list = [] + for i in range(0, n_samples, self.group_size): + adv_list.append( + masked_normalization( + advantages[cu_seqlens[i] : cu_seqlens[i + self.group_size]], + loss_mask[cu_seqlens[i] : cu_seqlens[i + self.group_size]], + all_reduce=False, + ) + ) + advantages = torch.cat(adv_list, 0) + else: + advantages = masked_normalization(advantages, loss_mask) + + # Prepare data to be splitted into mini-batches. + global_batch = dict( + **model_inputs, + old_logp=old_logp, + advantages=advantages, + loss_mask=loss_mask, + prox_logp=prox_logp, + ) + input_lens = model_inputs["cu_seqlens"][1:] - model_inputs["cu_seqlens"][:-1] + + all_stats = [] + with stats_tracker.scope("actor"): + ########## Logging code starts ########## + result_denominators = { + "correct_n_seqs": (reward_score > 0).bool(), + "incorrect_n_seqs": (reward_score <= 0).bool(), + } + global_denominators = dict( + n_seqs=torch.ones_like(reward_score, dtype=torch.bool), + n_tokens=torch.ones_like(loss_mask, dtype=torch.bool), + n_valid_tokens=loss_mask.bool(), + **result_denominators, + ) + stats_tracker.denominator(**global_denominators) + stats_tracker.stat( + correct_seq_len=input_lens.float(), denominator="correct_n_seqs" + ) + stats_tracker.stat( + incorrect_seq_len=input_lens.float(), denominator="incorrect_n_seqs" + ) + + stats = dict( + advantages=advantages, + kl_rewards=kl_rewards, + final_reward=rewards, + ) + stats_tracker.stat(**stats, denominator="n_valid_tokens") + + prompt_lens = [] + for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): + prompt_lens.append(prompt_mask[s:e].sum()) + prompt_lens = torch.tensor(prompt_lens, device=reward_score.device) + seq_stats = dict( + no_eos_ratios=seq_no_eos_mask.float(), + task_reward=reward_score, + prompt_len=prompt_lens.float(), + seq_len=input_lens.float(), + ) + stats_tracker.stat(**seq_stats, denominator="n_seqs") + scalars = dict( + mask_no_eos_with_zero=self.config.mask_no_eos_with_zero, + eps_clip=self.config.eps_clip, + use_prox_logp=prox_logp is not None, + ) + if self.config.c_clip is not None: + scalars["c_clip"] = self.config.c_clip + scalars["use_dual_clip"] = 1 + else: + scalars["use_dual_clip"] = 0 + if self.config.behav_imp_weight_cap is not None: + scalars["behav_imp_weight_cap"] = self.config.behav_imp_weight_cap + stats_tracker.scalar(**scalars) + + global_stats = stats_tracker.export() + for k in global_denominators: + global_stats.pop(f"actor/{k}") + ########## Logging code ends ########## + + mb_inputs = split_dict_tensor_with_cu_seqlens( + global_batch, + mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches), + ) + for mb in mb_inputs.mbs: + model_inputs = {k: mb[k] for k in model_inputs} + train_stat = self.actor.train_batch( + mb, + loss_fn=functools.partial( + grpo_loss_fn, + temperature=self.gconfig.temperature, + eps_clip=self.config.eps_clip, + c_clip=self.config.c_clip, + behav_imp_weight_cap=self.config.behav_imp_weight_cap, + ), + mb_spec=self.config.mb_spec, + loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(), + ) + stats_tracker.scalar(**train_stat) + all_stats.append(stats_tracker.export()) + all_stats[0].update(global_stats) + return all_stats + + +def grpo_loss_fn( + logits: torch.Tensor, + input_data: Dict, + temperature: float, + eps_clip: float, + c_clip: float | None, + behav_imp_weight_cap: float | None, +): + """Loss function for actor step, all inputs should be splitted into + pipeline micro batches, returns loss and logging stats.""" + input_ids = input_data["input_ids"].squeeze(0) + cu_seqlens = input_data["cu_seqlens"] + old_logp = input_data["old_logp"] + advantages = input_data["advantages"] + loss_mask = input_data["loss_mask"] + prox_logp = input_data["prox_logp"] + + logits = logits.squeeze(0).float() + logits /= temperature + logprobs = gather_logprobs(logits, torch.roll(input_ids, shifts=-1)) + loss, stat = ppo_functional.actor_loss_fn( + logprobs=logprobs, + old_logprobs=old_logp, + advantages=advantages, + eps_clip=eps_clip, + loss_mask=loss_mask, + c_clip=c_clip, + proximal_logprobs=prox_logp, + behav_imp_weight_cap=behav_imp_weight_cap, + ) + + entropy = calc_entropy(logits=logits, cu_seqlens=cu_seqlens) + + # Log training statistics + stats_tracker.denominator( + n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device), + n_valid_tokens=loss_mask.bool(), + clipped_tokens=stat["clip_mask"], + dual_clipped_tokens=stat["dual_clip_mask"], + ) + + stats_tracker.stat( + importance_weight=stat["importance_weight"], + approx_kl=stat["approx_kl"], + new_logp=logprobs.detach(), + old_logp=old_logp, + entropy=entropy.float(), + actor_loss=stat["loss"], + clip_ratio=stat["clip_mask"].float(), + dual_clip_ratio=stat["dual_clip_mask"].float(), + denominator="n_valid_tokens", + ) + if "behave_imp_weight" in stat: + stats_tracker.denominator(unclipped_behave_tokens=stat["behave_mask"]) + stats_tracker.stat( + behave_imp_weight=stat["behave_imp_weight"], + behave_approx_kl=stat["behave_approx_kl"], + denominator="unclipped_behave_tokens", + ) + vocab_min_logits = logits.detach().min(-1).values.float() + vocab_max_logits = logits.detach().max(-1).values.float() + stats_tracker.stat( + vocab_min_logits=vocab_min_logits, + vocab_max_logits=vocab_max_logits, + denominator="n_tokens", + ) + + clip_mask = stat["clip_mask"] + clipped_new_logp = torch.where(clip_mask, logprobs.detach(), 0.0) + clipped_old_logp = torch.where(clip_mask, old_logp, 0.0) + stats_tracker.stat( + clipped_new_logp=clipped_new_logp, + clipped_old_logp=clipped_old_logp, + denominator="clipped_tokens", + ) + return loss diff --git a/arealite/impl/trainer/sft.py b/arealite/impl/trainer/sft.py new file mode 100644 index 000000000..44bb0a546 --- /dev/null +++ b/arealite/impl/trainer/sft.py @@ -0,0 +1,269 @@ +import time +from typing import Any, Dict, List, Optional + +import torch +import torch.distributed as dist +import torch.utils.data +from datasets import Dataset + +from arealite.api.cli_args import TrainerConfig, TrainingArgs +from arealite.api.engine_api import EngineFactory +from arealite.api.trainer_api import Trainer +from arealite.system.rollout_controller import RolloutController +from arealite.utils import ( + close_wandb_tensorboard, + compute_varlen_position_indices, + gather_logprobs, + init_stats_logging, + list_of_dict2dict_of_list, + log_wandb_tensorboard, + record_timing, +) +from realhf.api.core.data_api import load_hf_tokenizer, tabulate_stats +from realhf.api.core.model_api import FinetuneSpec +from realhf.base import logging, stats_tracker, timeutil + +logger = logging.getLogger("SFT Trainer") + + +def compute_packed_sft_loss( + logits: torch.Tensor, + input_: Dict[str, torch.Tensor], +) -> torch.Tensor: + packed_input_ids: torch.Tensor = input_["input_ids"].squeeze(dim=0) + cu_seqlens: torch.Tensor = input_["cu_seqlens"] + input_lens: torch.Tensor = cu_seqlens[1:] - cu_seqlens[:-1] + cu_seqlens = torch.nn.functional.pad(input_lens.cumsum(0), (1, 0)).int() + prompt_mask = input_["prompt_mask"].squeeze(dim=0) + logits = logits.squeeze(dim=0).float() + + logprobs = gather_logprobs(logits, torch.roll(packed_input_ids, shifts=-1)) + logprobs = torch.where(prompt_mask, 0, logprobs) + + loss = -logprobs.sum() / prompt_mask.logical_not().count_nonzero() + + with torch.no_grad(): + seqlogp = torch.zeros( + cu_seqlens.shape[0] - 1, device=logits.device, dtype=torch.float64 + ) + for i in range(cu_seqlens.shape[0] - 1): + m = prompt_mask[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1] + logp = logprobs[cu_seqlens[i] - i : cu_seqlens[i + 1] - i - 1] + assert cu_seqlens[i + 1] - i - 1 <= logprobs.shape[0], ( + cu_seqlens, + logprobs.shape, + ) + seqlogp[i] = torch.where(m, 0.0, logp.detach()).sum() / ( + m.numel() - m.count_nonzero() + ) + + ## Loggin stats + stats_tracker.denominator( + n_seqs=torch.ones( + cu_seqlens.shape[0] - 1, dtype=torch.bool, device=logprobs.device + ), + n_tokens=torch.ones(logits.shape[0], dtype=torch.bool, device=logits.device), + n_valid_tokens=prompt_mask.logical_not(), + prompt_tokens=prompt_mask, + ) + stats_tracker.stat(ppl=(-seqlogp).exp().float(), denominator="n_seqs") + stats_tracker.stat(loss=-logprobs.detach(), denominator="n_valid_tokens") + vocab_min_logits = logits.detach().min(-1).values.float() + vocab_max_logits = logits.detach().max(-1).values.float() + stats_tracker.stat( + vocab_min_logits=vocab_min_logits, + vocab_max_logits=vocab_max_logits, + denominator="n_tokens", + ) + + return loss + + +class SFTTrainer(Trainer): + + def __init__( + self, + args: TrainingArgs, + trainer_config: TrainerConfig, + train_dataset: Dataset, + valid_dataset: Optional[Dataset] = None, + rollout_controller: Optional[RolloutController] = None, + ): + super().__init__( + args, trainer_config, train_dataset, valid_dataset, rollout_controller + ) + + self.config = config = trainer_config.sft + assert config is not None + + engine_factory = EngineFactory(args) + self.model = engine_factory.make_engine(config.model) + self.tokenizer = load_hf_tokenizer(config.model.path) + + self.mb_spec = config.mb_spec + + self.save_ctl = timeutil.EpochStepTimeFreqCtl( + freq_epoch=self.args.exp_ctrl.save_freq_epochs, + freq_step=self.args.exp_ctrl.save_freq_steps, + freq_sec=self.args.exp_ctrl.save_freq_secs, + ) + self.eval_ctl = timeutil.EpochStepTimeFreqCtl( + freq_epoch=self.args.exp_ctrl.eval_freq_epochs, + freq_step=self.args.exp_ctrl.eval_freq_steps, + freq_sec=self.args.exp_ctrl.eval_freq_steps, + ) + self.summary_writer = init_stats_logging(args) + + def _tokenize(self, strs: List[str]): + # tokenize strings into unpadded tokens with lengths. + return self.tokenizer( + strs, + padding=False, + truncation=True, + return_length=True, + max_length=self.mb_spec.max_tokens_per_mb, + return_attention_mask=False, + ) + + def _get_packed_input(self, data: List[Dict[str, Any]]): + data: Dict[str, List[Any]] = list_of_dict2dict_of_list(data) + + tokenized_seqs = data["seq"] + tokenized_prompts = data["prompt"] + prompt_lens = [len(prompt) for prompt in tokenized_prompts] + input_lens = [len(prompt) for prompt in tokenized_seqs] + + input_lens = torch.tensor(input_lens, dtype=torch.int) + input_ids = [torch.tensor(seq, dtype=torch.long) for seq in tokenized_seqs] + + prompt_mask = [] + for input_len, prompt_len in zip(input_lens, prompt_lens): + assert input_len >= prompt_len, (input_len, prompt_len) + pm = [1] * prompt_len + [0] * (input_len - prompt_len) + prompt_mask.append(torch.tensor(pm, dtype=torch.bool)) + + cu_seqlens = torch.nn.functional.pad( + input_lens.cumsum(0, dtype=torch.int), (1, 0) + ) + max_seqlen = int(torch.max(input_lens).item()) + packed_input_ids = torch.cat(input_ids, dim=0) + prompt_mask = torch.cat(prompt_mask, dim=0) + total_seqlen = int(cu_seqlens[-1].item()) + position_ids = compute_varlen_position_indices(total_seqlen, cu_seqlens) + + return dict( + input_ids=packed_input_ids.unsqueeze(0).cuda(), + attention_mask=None, + position_ids=position_ids.unsqueeze(0).cuda(), + prompt_mask=prompt_mask.unsqueeze(0).cuda(), + cu_seqlens=cu_seqlens.cuda(), + max_seqlen=max_seqlen, + use_cache=False, + ) + + def train(self, resume_from_checkpoint=None): + self.create_train_dataloader() + + total_epochs = self.args.exp_ctrl.total_train_epochs + steps_per_epoch = len(self.train_dataloader) + ft_spec = FinetuneSpec( + total_train_epochs=steps_per_epoch, + dataset_size=len(self.train_dataset), + train_batch_size=self.args.train_dataset.batch_size, + ) + + self.model.init_distributed(None, ft_spec) + self.model.load_model_from_hf(self.config.model.path) + self.model.train() + + if dist.get_rank() == 0: + logger.info(f"total_epochs={total_epochs} step_per_epoch={steps_per_epoch}") + global_step = 0 + start_time = time.monotonic() + for epoch in range(total_epochs): + for step, data in enumerate(self.train_dataloader): + timing_stats = {} + with record_timing("timeperf/data_processing", timing_stats): + packed_input_data = self._get_packed_input(data) + + with record_timing("timeperf/train_step", timing_stats): + with stats_tracker.scope("sft"): + stats = self.model.train_batch( + input_=packed_input_data, + loss_fn=compute_packed_sft_loss, + loss_weight_fn=lambda x: x["prompt_mask"] + .logical_not() + .count_nonzero(), + mb_spec=self.mb_spec, + ) + self.model.step_lr_scheduler() + stats_tracker.scalar(**stats) + + if self.save_ctl.check( + epochs=int(step == steps_per_epoch - 1), steps=1 + ): + if dist.get_rank() == 0: + logger.info("Saving model ...") + + with record_timing("timeperf/save", timing_stats): + save_path = self.get_save_checkpoint_path( + epoch, step, global_step + ) + self.model.save_model_to_hf(save_path, self.tokenizer) + + if self.eval_ctl.check( + epochs=int(step == steps_per_epoch - 1), steps=1 + ): + if dist.get_rank() == 0: + logger.info("Running evaluation ...") + with record_timing("timeperf/eval", timing_stats): + self._eval(global_step) + + training_stats = stats_tracker.export() + training_stats.update(timing_stats) + log_wandb_tensorboard(global_step, training_stats, self.summary_writer) + + if dist.get_rank() == 0: + logger.info( + f"Epoch {epoch} Step {step} GlobalStep {global_step} done. Detailed time stats:" + f"\n{tabulate_stats(timing_stats, floatfmt='.2f')}" + ) + global_step += 1 + + if dist.get_rank() == 0: + logger.info( + f"Training completes! Total time elapsed {time.monotonic() - start_time:.2f}." + ) + + close_wandb_tensorboard(self.summary_writer) + + def _eval(self, global_step): + self.create_valid_dataloader() + if self.valid_dataloader is None: + return + + self.eval_data_generator = iter(self.valid_dataloader) + n_steps = len(self.valid_dataloader) + + losses = [] + + start_time = time.monotonic() + for step in range(n_steps): + data = next(self.eval_data_generator) + packed_input_data = self._get_packed_input(data) + with stats_tracker.scope("sft-eval"): + avg_loss = self.model.eval_batch( + input_=packed_input_data, + loss_fn=compute_packed_sft_loss, + loss_weight_fn=lambda x: x["prompt_mask"] + .logical_not() + .count_nonzero(), + mb_spec=self.mb_spec, + ) + losses.append(avg_loss) + val_loss = torch.mean(torch.stack(losses)) + + logger.info( + f"Global step: {global_step} evaluation time cost {time.monotonic() - start_time:.2f} " + f"val_loss={val_loss:.4f}" + ) diff --git a/arealite/launcher/with_ray.py b/arealite/launcher/with_ray.py new file mode 100644 index 000000000..e69de29bb diff --git a/arealite/launcher/with_scheduler.py b/arealite/launcher/with_scheduler.py new file mode 100644 index 000000000..f114bffb1 --- /dev/null +++ b/arealite/launcher/with_scheduler.py @@ -0,0 +1,87 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import os +import sys + +from arealite.api.cli_args import prepare_training_args +from arealite.api.io_struct import AllocationMode +from arealite.api.llm_server_api import LLMServiceRegistry +from realhf.base import constants, name_resolve, names +from realhf.scheduler.client import JobException, JobState +from realhf.scheduler.client import make as make_scheduler + + +def main(): + cfg, config_file = prepare_training_args(sys.argv[1:]) + if cfg.shutdown_server_on_exit: + name_resolve.clear_subtree( + names.trial_root( + experiment_name=cfg.experiment_name, trial_name=cfg.trial_name + ) + ) + + # Launch inference and training jobs + alloc_mode = AllocationMode.from_str(cfg.allocation_mode) + assert cfg.mode == "local" + scheduler = make_scheduler(cfg) + BASE_ENVIRONS = constants.get_env_vars(cfg) + for k, v in BASE_ENVIRONS.items(): + os.environ[k] = v + + # discover existing servers + existing_servers = LLMServiceRegistry( + cfg.experiment_name, cfg.trial_name + ).get_healthy_servers() + # Launch LLM servers. + if len(existing_servers) == 0: + n_gpus_per_instance = alloc_mode.gen_pp_size * alloc_mode.gen_tp_size + servers_to_launch = alloc_mode.gen_dp_size - len(existing_servers) + scheduler.submit_array( + worker_type="llm_server", + cmd=f"python3 arealite/cli/launch_server.py --config {str(config_file)}", + count=servers_to_launch, + cpu=cfg.cpu_per_inf_proc * n_gpus_per_instance, + gpu=n_gpus_per_instance, + mem=cfg.mem_per_inf_proc * n_gpus_per_instance, + env_vars=BASE_ENVIRONS, + container_image=cfg.cluster.gpu_infer_image, + ) + # Launch trainers. + scheduler.submit( + worker_type="trainer", + cmd=f"torchrun --nnodes 1 --nproc-per-node {alloc_mode.train_world_size} arealite/cli/launch_trainer.py --config {str(config_file)}", + cpu=cfg.cpu_per_train_proc * alloc_mode.train_world_size, + gpu=alloc_mode.train_world_size, + mem=cfg.cpu_per_train_proc * cfg.mem_per_train_proc, + container_image=cfg.cluster.gpu_image, + nodelist=cfg.nodelist, + exclude=cfg.exclude, + env_vars=BASE_ENVIRONS, + hostfile=False, + multiprog=False, + ) + + # Waiting for the job. + try: + scheduler.wait( + check_status=( + JobState.CANCELLED, + JobState.FAILED, + JobState.NOT_FOUND, + JobState.COMPLETED, + ), + remove_status=(), + ) + except (KeyboardInterrupt, JobException, TimeoutError): + kill_signal = ( + "SIGKILL" if cfg.mode == "slurm" else "SIGTERM" + ) # use sigkill to terminate slurm jobs + if cfg.shutdown_server_on_exit: + scheduler.stop_all(kill_signal) + else: + scheduler.stop("trainer") + + +if __name__ == "__main__": + main() diff --git a/arealite/ppo_functional.py b/arealite/ppo_functional.py new file mode 100644 index 000000000..446d6b71f --- /dev/null +++ b/arealite/ppo_functional.py @@ -0,0 +1,195 @@ +import functools +from typing import Dict, Optional, Tuple + +import torch +import torch.distributed + +from realhf.base import pkg_version + + +def actor_loss_fn( + logprobs: torch.Tensor, + old_logprobs: torch.Tensor, + advantages: torch.Tensor, + eps_clip: float, + loss_mask: torch.Tensor, + c_clip: Optional[float] = None, + proximal_logprobs: Optional[torch.Tensor] = None, + behav_imp_weight_cap: Optional[float] = None, +) -> Tuple[torch.Tensor, Dict]: + denorm_logprobs = ( + proximal_logprobs if proximal_logprobs is not None else old_logprobs + ) + loss_mask_count = loss_mask.count_nonzero() or 1 + ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0) + clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) + pg_loss1 = -advantages * ratio + pg_loss2 = -advantages * clipped_ratio + clip_mask = pg_loss1.detach() < pg_loss2.detach() + pg_loss = torch.max(pg_loss1, pg_loss2) + if c_clip is not None: + assert c_clip > 1.0, c_clip + pg_loss3 = torch.sign(advantages) * c_clip * advantages + dual_clip_mask = pg_loss3.detach() < pg_loss.detach() + pg_loss = torch.min(pg_loss, pg_loss3) + else: + dual_clip_mask = torch.zeros_like(clip_mask) + if proximal_logprobs is not None: + behav_kl = proximal_logprobs - old_logprobs + behav_imp_weight = behav_kl.exp() + behav_mask = ( + (behav_imp_weight <= behav_imp_weight_cap).logical_and(loss_mask) + if behav_imp_weight_cap is not None + else loss_mask + ) + behav_kl = torch.where(behav_mask, behav_kl, 0.0) + behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0) + pg_loss = pg_loss * behav_imp_weight + logging_loss = pg_loss.detach() + pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count + clip_mask.logical_and_(loss_mask) + dual_clip_mask.logical_and_(loss_mask) + stat = dict( + loss=logging_loss, + importance_weight=ratio.detach(), + approx_kl=(logprobs - denorm_logprobs).detach(), + clip_mask=clip_mask, + dual_clip_mask=dual_clip_mask, + ) + if proximal_logprobs is not None: + stat["behave_imp_weight"] = behav_imp_weight + stat["behave_approx_kl"] = behav_kl + stat["behave_mask"] = behav_mask + return pg_loss, stat + + +def _huber_loss(x: torch.Tensor, y: torch.Tensor, delta: float): + diff = torch.abs(x - y) + return torch.where(diff < delta, 0.5 * diff**2, delta * (diff - 0.5 * delta)) + + +def _mse_loss(x: torch.Tensor, y: torch.Tensor): + return 0.5 * (x - y) ** 2 + + +def critic_loss_fn( + value: torch.Tensor, + old_value: torch.Tensor, + target_value: torch.Tensor, + value_eps_clip: float, + loss_mask: torch.Tensor, + loss_fn_type: str = "mse", +) -> Tuple[torch.Tensor, Dict]: + if loss_fn_type == "huber": + loss_fn = functools.partial(_huber_loss, delta=10.0) + elif loss_fn_type == "mse": + loss_fn = _mse_loss + else: + raise NotImplementedError(f"Unknown loss fn type: {loss_fn_type}") + value_loss_original = loss_fn(value, target_value) + value_clipped = old_value + (value - old_value).clamp( + -value_eps_clip, value_eps_clip + ) + value_loss_clipped = loss_fn(value_clipped, target_value) + value_loss = torch.max(value_loss_original, value_loss_clipped) + with torch.no_grad(): + clip_mask = value_loss_clipped.detach() > value_loss_original.detach() + clip_mask.logical_and_(loss_mask) + stat = dict(clip_mask=clip_mask, loss=value_loss.detach()) + value_loss = torch.where(loss_mask, value_loss, 0).sum() / loss_mask.count_nonzero() + return value_loss, stat + + +@torch.no_grad() +def get_packed_rewards( + kl_ctl: float, + clip_reward_value: float, + log_probs: torch.Tensor, + ref_log_probs: torch.Tensor, + reward_score: torch.Tensor, + cu_seqlens: torch.Tensor, + seq_no_eos_mask: torch.Tensor, + mask_no_eos_with_zero: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + tot_rewards = -kl_ctl * (log_probs - ref_log_probs) + tot_rewards[cu_seqlens[1:] - 1] = 0 + kl_rewards = tot_rewards.clone() + reward_score = reward_score.clip(-clip_reward_value, clip_reward_value) + indices = torch.clip(cu_seqlens[1:] - 2, min=0) + if mask_no_eos_with_zero: + tot_rewards[indices] += torch.where(seq_no_eos_mask, 0, reward_score) + else: + tot_rewards[indices] += reward_score + return kl_rewards, tot_rewards + + +def pygae1d_nolp_misalign( + rewards: torch.Tensor, + values: torch.Tensor, + cu_seqlens_: torch.Tensor, + bootstrap: torch.Tensor, + gamma: float, + lam: float, +) -> Tuple[torch.Tensor, torch.Tensor]: + cu_seqlens = cu_seqlens_.clone() + cu_seqlens[1:] += torch.ones_like(cu_seqlens_[1:]).cumsum(0) + bs = cu_seqlens_.shape[0] - 1 + assert values.shape[0] == rewards.shape[0] + bs + advantages_reversed = [] + returns_reversed = [] + for i in reversed(range(bs)): + v_offset = cu_seqlens[i] + r_offset, r_end = cu_seqlens_[i], cu_seqlens_[i + 1] + assert cu_seqlens[i + 1] - v_offset - 1 == r_end - r_offset + lastgaelam = 0 + for t in reversed(range(r_end - r_offset)): + nextvalues = values[v_offset + t + 1] + if t == r_end - r_offset - 1: + nextvalues *= bootstrap[i] + delta = rewards[r_offset + t] + gamma * nextvalues - values[v_offset + t] + lastgaelam = delta + gamma * lam * lastgaelam + advantages_reversed.append(lastgaelam) + returns_reversed.append(lastgaelam + values[v_offset + t]) + advantages = torch.stack(advantages_reversed[::-1]) + returns = torch.stack(returns_reversed[::-1]) + return advantages, returns + + +def cugae1d_nolp_misalign_func( + rewards: torch.Tensor, + values: torch.Tensor, + cu_seqlens: torch.Tensor, + truncate: torch.Tensor, + gamma: float, + lam: float, +) -> Tuple[torch.Tensor, torch.Tensor]: + if pkg_version.is_available("cugae"): + from cugae import cugae1d_nolp_misalign_func as gae_1d_nolp_misalign + else: + from realhf._C.cugae import gae_1d_nolp_misalign + assert len(rewards.shape) == len(values.shape) == len(cu_seqlens.shape) == 1 + assert cu_seqlens[0] == 0 and cu_seqlens[-1] == rewards.shape[0] + return gae_1d_nolp_misalign(rewards, values, cu_seqlens, truncate, gamma, lam) + + +@torch.no_grad() +def get_packed_advantages_and_returns( + gamma: float, + lam: float, + values: torch.Tensor, + rewards: torch.Tensor, + short1cu_seqlens: torch.Tensor, + seq_no_eos_mask: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + if rewards.get_device() == -1: + return pygae1d_nolp_misalign( + rewards, values, short1cu_seqlens, seq_no_eos_mask, gamma, lam + ) + try: + return cugae1d_nolp_misalign_func( + rewards, values, short1cu_seqlens.int(), seq_no_eos_mask.bool(), gamma, lam + ) + except ModuleNotFoundError: + return pygae1d_nolp_misalign( + rewards, values, short1cu_seqlens, seq_no_eos_mask, gamma, lam + ) diff --git a/arealite/system/rollout_controller.py b/arealite/system/rollout_controller.py new file mode 100644 index 000000000..1f496e9ca --- /dev/null +++ b/arealite/system/rollout_controller.py @@ -0,0 +1,228 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import asyncio +import threading +import time +import traceback +from queue import Empty as QueueEmpty +from typing import Any, List, Optional + +import numpy as np + +# NOTE: the start method of mp should be fork rather than spawn +import torch.multiprocessing as mp + +from arealite.api.cli_args import RolloutConfig, TrainingArgs +from arealite.api.io_struct import Trajectory +from arealite.api.llm_client_api import LLMClientFactory +from arealite.api.rollout_api import RolloutCollector +from arealite.system.rollout_worker import RolloutWorker +from realhf.base import datapack, logging, network +from realhf.system.push_pull_stream import ZMQJsonPuller, ZMQJsonPusher + +logger = logging.getLogger("Rollout Controller") + + +class RolloutController: + def __init__( + self, + args: TrainingArgs, + config: RolloutConfig, + collector: RolloutCollector, + ): + self.args = args + self.config = config + self.gconfig = config.gconfig + self.collector = collector + + # Process-based execution + self._exiting = mp.Event() + self._lock = mp.Lock() + self._buffer: List[List[Trajectory]] = [] + self._version = 0 + + # Worker processes for asynchronous rollout + self._worker_processes: List[mp.Process] = [] + + self.llm_client = LLMClientFactory(args).make_client(config.llm_client) + + # PushPull communication for data to workers + self._data_pusher = None + self._data_pusher_port = None + self._puller = None + self._puller_port = None + self._collector_thread = None + + ################### User Interfaces Start ################# + + def generate_batch( + self, + batch_size: int, + env_options: Optional[List[Any]] = None, + seeds: Optional[List[int]] = None, + ) -> List[Trajectory]: + """Run episodes in batch using the collector directly (for compatibility).""" + if env_options is None: + env_options = [None] * batch_size + else: + assert len(env_options) == batch_size + if seeds is None: + seeds = [None] * batch_size + else: + assert len(seeds) == batch_size + + async def run_parallel_gen(): + worker = RolloutWorker( + worker_id=0, + args=self.args, + config=self.config, + llm_client=self.llm_client, + ) + tasks = [ + worker._run_grouped_episode_async(None, env_option, seed) + for env_option, seed in zip(env_options, seeds) + ] + results = await asyncio.gather(*tasks) + return sum([r[1] for r in results], []) + + return asyncio.run(run_parallel_gen()) + + def start_generate_loop(self): + """Start worker processes that run generation loops.""" + logger.info("Starting worker processes...") + + # Start background thread to collect data from workers + self._puller_port = network.find_free_port( + experiment_name=self.args.experiment_name, trial_name=self.args.trial_name + ) + self._collector_thread = threading.Thread( + target=self._collect_from_workers, daemon=True + ) + self._collector_thread.start() + + # Start worker processes + self._data_pusher_port = network.find_free_port( + experiment_name=self.args.experiment_name, trial_name=self.args.trial_name + ) + self._data_pusher = ZMQJsonPusher( + host="localhost", port=self._data_pusher_port, bind=True + ) + logger.info(f"RolloutController sending data on port {self._data_pusher_port}") + + num_workers = self.config.num_workers + for worker_id in range(num_workers): + process = mp.Process( + target=_run_worker_process, + args=( + worker_id, + self.args, + self.config, + self._puller_port, + self._data_pusher_port, + ), + ) + process.start() + self._worker_processes.append(process) + logger.info(f"Started worker process {worker_id}") + + def submit(self, data): + """Submit data to worker processes for processing.""" + if self._data_pusher is None: + raise RuntimeError( + "Data pusher not initialized. Call start_generate_loop() first." + ) + + # Convert data to JSON-compatible format + assert isinstance(data, list) + for d in data: + self._data_pusher.push(d) + logger.debug(f"Submitted {len(data)} data to workers") + + def prepare_batch(self, batch_size: int) -> List[Trajectory]: + """Prepare and wait for a batch of trajectories.""" + buf_size = -1 + while buf_size < batch_size: + with self._lock: + buf_size = len(self._buffer) + time.sleep(0.1) + with self._lock: + self._buffer = sorted( + self._buffer, key=lambda x: np.mean([xx.stats.start_time for xx in x]) + ) + data, self._buffer = self._buffer[:batch_size], self._buffer[batch_size:] + return datapack.flat2d(data) + + def stop_generate_loop(self): + """Stop worker processes and cleanup.""" + logger.info("Stopping worker processes...") + self._exiting.set() + + # Stop worker processes gracefully first, then forcefully if needed + for i, process in enumerate(self._worker_processes): + if process.is_alive(): + logger.info(f"Terminating worker process {i}...") + try: + process.terminate() + process.join(timeout=1.0) + except Exception: + process.kill() + self._worker_processes.clear() + + if self._collector_thread is not None: + # Wait for the thread to finish (with optional timeout) + self._collector_thread.join(timeout=1.0) + + # Close communication channels + if self._puller: + self._puller.close() + if self._data_pusher: + self._data_pusher.close() + logger.info("Cleanup completed") + + ################## User Interfaces End ################## + + def _collect_from_workers(self): + """Background thread to collect trajectories from workers.""" + # Find a free port + self._puller = ZMQJsonPuller(host="localhost", port=self._puller_port) + logger.info(f"RolloutController listening on port {self._puller_port}") + + while not self._exiting.is_set(): + try: + # Pull data from workers + data = self._puller.pull(timeout_ms=100) + # Convert back to Trajectory objects + trajs = [ + Trajectory.from_json_compatible(traj_data) + for traj_data in data["trajs"] + ] + # Add to buffer + with self._lock: + self._buffer.append(trajs) + logger.debug( + f"Received {len(trajs)} trajectories from worker {data['worker_id']}" + ) + except QueueEmpty: + # No data available, continue + time.sleep(0.1) + continue + except Exception as e: + if not self._exiting.is_set(): + logger.error(f"Error in collector thread: {e}") + logger.error(traceback.format_exc()) + break + + +def _run_worker_process(worker_id: int, args, config, puller_port, data_pusher_port): + worker = RolloutWorker( + worker_id=worker_id, + args=args, + config=config, + pusher_host="localhost", + pusher_port=puller_port, + data_puller_host="localhost", + data_puller_port=data_pusher_port, + ) + logger.info(f"Worker {worker_id} starting generation loop...") + worker.run_generation_loop() diff --git a/arealite/system/rollout_worker.py b/arealite/system/rollout_worker.py new file mode 100644 index 000000000..e982ce182 --- /dev/null +++ b/arealite/system/rollout_worker.py @@ -0,0 +1,234 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import asyncio +import queue +from typing import Any, Dict, List, Optional + +import numpy as np +import torch.distributed as dist + +from arealite.api.cli_args import RolloutConfig, TrainingArgs +from arealite.api.io_struct import Trajectory +from arealite.api.llm_client_api import LLMClient, LLMClientFactory +from arealite.api.rollout_api import RolloutCollectorFactory +from realhf.base import logging, name_resolve, names +from realhf.base.monitor import RolloutStat +from realhf.system.push_pull_stream import ZMQJsonPuller, ZMQJsonPusher + +logger = logging.getLogger("RolloutWorker") + +ROLLOUT_POLL_WAIT_TIME = 0.4 + + +class RolloutWorker: + """Standalone rollout worker that runs continuous generation loop.""" + + def __init__( + self, + worker_id: int, + args: TrainingArgs, + config: RolloutConfig, + llm_client: LLMClient | None = None, + pusher_host: Optional[str] = "localhost", + pusher_port: Optional[int] = 5555, + data_puller_host: Optional[str] = "localhost", + data_puller_port: Optional[int] = 5556, + ): + self.worker_id = worker_id + self.args = args + self.config = config + self.gconfig = config.gconfig + + # For staleness control + self.train_batch_size = args.train_dataset.batch_size + self.max_concurrent_rollouts = ( + config.max_concurrent_rollouts or self.train_batch_size + ) + + self.pusher_host = pusher_host + self.pusher_port = pusher_port + self.data_puller_host = data_puller_host + self.data_puller_port = data_puller_port + + self._shutdown = False + self.pusher = None + self.data_puller = None + + if llm_client is None: + llm_client = LLMClientFactory(args).make_client(config.llm_client) + self.llm_client = llm_client + + def _cleanup(self): + """Clean up resources.""" + if self.pusher: + self.pusher.close() + if self.data_puller: + self.data_puller.close() + + def run_generation_loop(self): + """Run the continuous generation loop like the original _generate_loop.""" + try: + asyncio.run(self._generate_loop()) + finally: + self._cleanup() + + async def _run_grouped_episode_async( + self, rid: int, data: Any, seed: Optional[int] = None + ): + """Run grouped episode asynchronously.""" + tasks = [] + for _ in range(self.gconfig.n_samples): + # Create collector + factory = RolloutCollectorFactory(self.args) + collector = factory.make_collector(self.config.collector) + tasks += [ + collector.arun_episode( + llm_client=self.llm_client, + gconfig=self.gconfig.new(n_samples=1), + env_option=data, + seed=seed, + ) + ] + trajs = await asyncio.gather(*tasks) + return rid, trajs + + def _get_model_version(self) -> int: + name = names.model_version( + self.args.experiment_name, + self.args.trial_name, + "actor", + ) + try: + return int(name_resolve.get(name)) + except name_resolve.NameEntryNotFoundError: + return 0 + + async def _generate_loop(self): + """Main generation loop - similar to original RolloutController._generate_loop.""" + data = None + + # Communication with main process + self.pusher = ZMQJsonPusher(host=self.pusher_host, port=self.pusher_port) + self.data_puller = ZMQJsonPuller( + host=self.data_puller_host, + port=self.data_puller_port, + bind=False, + ) + + rollout_stat = RolloutStat() + rollout_tasks: Dict[int, asyncio.Task] = {} + rid = 0 + + try: + while not self._shutdown: + # Load next data from controller + if data is None: + try: + data = self.data_puller.pull(timeout_ms=50) + logger.debug(f"Get data from puller: {data}") + except queue.Empty: + logger.debug(f"No data from puller stream.") + + # Check capacity + if dist.is_initialized(): + world_size = dist.get_world_size() + else: + world_size = 1 + + cannot_rollout_reason = [] + capacity = max(1, self.max_concurrent_rollouts // world_size) + can_rollout = len(rollout_tasks) < capacity + if not can_rollout: + cannot_rollout_reason.append( + f"Exceeding capacity: # running tasks {len(rollout_tasks)} >= capacity {capacity}" + ) + + # Staleness control + version = self._get_model_version() + ofp = self.config.max_head_offpolicyness + sample_cnt = rollout_stat.accepted + rollout_stat.running + expected_version = sample_cnt // self.train_batch_size + not_staled = expected_version <= ofp + version + can_rollout &= not_staled + if not not_staled: + cannot_rollout_reason.append( + f"Staled: expected version ({expected_version}) = " + f"global sample cnt ({sample_cnt}) // batch size ({self.train_batch_size}), " + f"current latest version {version}, " + f"offpolicyness {self.config.max_head_offpolicyness}." + ) + + if not can_rollout: + logger.debug( + f"Worker {self.worker_id}: Cannot submit new rollouts. " + + "\n".join(cannot_rollout_reason) + ) + + # Create new rollout task + if can_rollout and data is not None: + task = asyncio.create_task( + self._run_grouped_episode_async(rid, data) + ) + rollout_tasks[rid] = task + + rollout_stat.submitted += 1 + rollout_stat.running += 1 + logger.debug( + f"Worker {self.worker_id}: Submit rollout rid {rid}. " + f"Submit: {rollout_stat.submitted}, " + f"running: {rollout_stat.running}, " + f"accepted: {rollout_stat.accepted}." + ) + + rid += 1 + data = None + + # Wait for rollout completion + tasks = list(rollout_tasks.values()) + done = [] + if tasks: + done, _ = await asyncio.wait( + tasks, + timeout=ROLLOUT_POLL_WAIT_TIME, + return_when=asyncio.FIRST_COMPLETED, + ) + else: + await asyncio.sleep(ROLLOUT_POLL_WAIT_TIME) + + # Collect done results + for task in done: + task_rid, trajs = await task + trajs: List[Trajectory] + rollout_tasks.pop(task_rid) + rollout_stat.running -= 1 + + # Filter data according to episodic return + ret = np.mean([traj.stats.total_reward for traj in trajs]) + accepted = ret >= self.config.filter_reward_lb + accepted &= ret <= self.config.filter_reward_ub + + if accepted: + # Push trajectories to main process + trajectory_data = { + "worker_id": self.worker_id, + "trajs": [traj.to_json_compatible() for traj in trajs], + } + self.pusher.push(trajectory_data) + rollout_stat.accepted += 1 + + logger.debug( + f"Worker {self.worker_id}: Finish rollout {task_rid}. " + f"Submit: {rollout_stat.submitted}, " + f"running: {rollout_stat.running}, " + f"accepted: {rollout_stat.accepted}." + ) + finally: + # Cancel remaining tasks + for task in rollout_tasks.values(): + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass diff --git a/arealite/system/sglang_client.py b/arealite/system/sglang_client.py new file mode 100644 index 000000000..3ff1e134d --- /dev/null +++ b/arealite/system/sglang_client.py @@ -0,0 +1,165 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import time + +from arealite.api.io_struct import LLMRequest, LLMResponse, LLMServerInfo +from arealite.api.llm_client_api import LLMClient +from realhf.base import logging, pkg_version + +logger = logging.getLogger(__name__) + +if pkg_version.is_available("sglang"): + if pkg_version.is_version_greater_or_equal("sglang", "0.4.4"): + SGLANG_TOKEN_OUTPUT_IDENTIFIER = "output_ids" + else: + SGLANG_TOKEN_OUTPUT_IDENTIFIER = "token_ids" + + +class SGLangClient(LLMClient): + """SGLang implementation of LLMClient.""" + + async def agenerate(self, req: LLMRequest) -> LLMResponse: + """Async version of generate using aiohttp.""" + + # Convert messages to prompt + if not req.text: + assert req.input_ids is not None + req.text = self.tokenizer.decode(req.input_ids) + + # Prepare request payload + gconfig = req.gconfig + stop_token_ids = gconfig.stop_token_ids + if self.tokenizer.eos_token_id not in stop_token_ids: + stop_token_ids.append(self.tokenizer.eos_token_id) + if self.tokenizer.pad_token_id not in stop_token_ids: + stop_token_ids.append(self.tokenizer.pad_token_id) + + assert gconfig.n_samples == 1 + sample_params = { + "top_p": gconfig.top_p, + "top_k": gconfig.top_k, + "max_new_tokens": gconfig.max_new_tokens, + "temperature": 0.0 if gconfig.greedy else gconfig.temperature, + "stop_token_ids": stop_token_ids, + } + + payload = { + "rid": req.rid, + "text": req.text, + "sampling_params": sample_params, + "return_logprob": True, + "stream": False, + } + + # Make request + start_time = time.perf_counter() + accumulated_output_tokens = [] + accumulated_output_logprobs = [] + accumulated_versions = [] + + # Deal with rollout interruption + completion = "" + stop_reason = "length" + + while ( + stop_reason != "stop" + and len(accumulated_output_tokens) < gconfig.max_new_tokens + ): + # loop until the generation is complete + response, server_info = await self.arequest_with_retry( + endpoint="/generate", + payload=payload, + method="POST", + max_retries=3, + timeout=self.client_config.request_timeout, + ) + result = await response.json() + + # Parse response + completion += result["text"] + meta_info = result["meta_info"] + output_tokens = [x[1] for x in meta_info["output_token_logprobs"]] + output_logprobs = [x[0] for x in meta_info["output_token_logprobs"]] + + # Update accumulated outputs + accumulated_output_tokens.extend(output_tokens) + accumulated_output_logprobs.extend(output_logprobs) + accumulated_versions.extend([server_info.version] * len(output_tokens)) + + # Check if generation is complete + finish_reason = meta_info["finish_reason"] + stop_reason = finish_reason["type"] + + payload["text"] += completion + + latency = time.perf_counter() - start_time + + return LLMResponse( + completion=completion, + input_tokens=req.input_ids, + output_tokens=accumulated_output_tokens, + output_logprobs=accumulated_output_logprobs, + output_versions=accumulated_versions, + stop_reason=stop_reason, + latency=latency, + ttft=latency, # Simplified for non-streaming + ) + + async def aupdate_weights_from_disk(self, server_info: LLMServerInfo, path: str): + server_url = f"http://{server_info.host}:{server_info.port}" + response, _ = await self.arequest_with_retry( + endpoint="/update_weights_from_disk", + payload=dict(model_path=path, allow_interrupt=True), + method="POST", + max_retries=3, + timeout=self.client_config.request_timeout, + target_server=server_info, + ) + res = await response.json() + assert res["success"] + if "num_paused_requests" in res: + logger.info( + f"{res['num_paused_requests']} requests are interrupted " + f"during updating weights for server {server_url}" + ) + self.registry.update_heartbeat( + server_info.server_id, "healthy", version=server_info.version + 1 + ) + + async def ainit_weight_update_group(self, server_info, group_meta): + payload = dict( + master_address=group_meta.master_address, + master_port=group_meta.master_port, + rank_offset=group_meta.rank_offset, + world_size=group_meta.world_size, + group_name=group_meta.group_name, + backend=group_meta.backend, + ) + response, _ = await self.arequest_with_retry( + endpoint="/init_weights_update_group", + payload=payload, + method="POST", + max_retries=3, + timeout=self.client_config.request_timeout, + target_server=server_info, + ) + res = await response.json() + assert res["success"], res["message"] + + async def aupdate_weights_from_distributed(self, server_info, weight_meta): + payload = dict( + name=weight_meta.param_name, + dtype=weight_meta.dtype, + shape=weight_meta.shape, + ) + response, _ = await self.arequest_with_retry( + endpoint="/update_weights_from_distributed", + payload=payload, + method="POST", + max_retries=3, + timeout=self.client_config.request_timeout, + target_server=server_info, + ) + res = await response.json() + assert res["success"], res["message"] diff --git a/arealite/system/sglang_server.py b/arealite/system/sglang_server.py new file mode 100644 index 000000000..a620bd267 --- /dev/null +++ b/arealite/system/sglang_server.py @@ -0,0 +1,166 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import os +import subprocess +import sys +from pathlib import Path + +import requests + +from arealite.api.cli_args import LLMServiceConfig, SGLangConfig +from arealite.api.io_struct import AllocationMode, LLMServerInfo +from arealite.api.llm_server_api import LLMServer +from realhf.base import gpu_utils, logging, network, pkg_version + +logger = logging.getLogger(__name__) + + +def apply_sglang_path(): + """Apply SGLang patch if available.""" + p = Path(os.path.dirname(__file__)) + patch_path = str( + p.parent.parent.parent + / "patch" + / "sglang" + / f"v{pkg_version.get_version('sglang')}.patch" + ) + + target_path = "" + try: + sglang_meta = subprocess.check_output( + "python3 -m pip show sglang", shell=True + ).decode("ascii") + for line in sglang_meta.split("\n"): + line = line.strip() + if line.startswith("Editable project location: "): + target_path = str(Path(line.split(": ")[1]).parent) + + if target_path and Path(patch_path).exists(): + proc = subprocess.Popen( + ["git", "apply", patch_path], + cwd=target_path, + stderr=sys.stdout, + stdout=sys.stdout, + ) + proc.wait() + logger.info(f"Applied SGLang patch at {target_path}") + except (subprocess.CalledProcessError, FileNotFoundError): + pass + + +class SGLangServer(LLMServer): + """SGLang implementation of LLMServer.""" + + def __init__(self, args, service_config: LLMServiceConfig): + super().__init__(args, service_config) + self.server_info: LLMServerInfo | None = None + self.base_gpu_id = 0 + self.config = args.rollout.sglang + + self.alloc_mode = AllocationMode.from_str(args.allocation_mode) + + def _resolve_base_gpu_id(self): + # Determine GPU configuration + import ray + + tp_size = self.alloc_mode.gen_tp_size + pp_size = self.alloc_mode.gen_pp_size + mp_size = tp_size * pp_size + if ray.is_initialized(): + self.base_gpu_id = 0 + elif "CUDA_VISIBLE_DEVICES" in os.environ: + if len(os.environ["CUDA_VISIBLE_DEVICES"]) == 1: + self.base_gpu_id = int(os.environ["CUDA_VISIBLE_DEVICES"]) + elif len(os.environ["CUDA_VISIBLE_DEVICES"]) == mp_size: + self.base_gpu_id = int(os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0]) + else: + logger.warning( + f"Unknown how to resolve cuda visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}, " + f"setting base_gpu_id to 0." + ) + self.base_gpu_id = 0 + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( + map(str, range(gpu_utils.gpu_count())) + ) + elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: + # torchrun + self.base_gpu_id = int(os.environ["RANK"]) % gpu_utils.gpu_count() + elif gpu_utils.gpu_count() == mp_size: + self.base_gpu_id = 0 + else: + logger.warning("Unknown GPU configuration, setting base_gpu_id to 0. ") + self.base_gpu_id = 0 + + def launch_server(self) -> LLMServerInfo | None: + # Apply SGLang patch + apply_sglang_path() + self._resolve_base_gpu_id() + # Get host and ports + host_ip = network.gethostip() + host = "localhost" if not self.config.enable_metrics else host_ip + ports = network.find_multiple_free_ports( + 2, + low=10000, + high=60000, + experiment_name=self.registry.expr_name, + trial_name=self.registry.trial_name, + ) + server_port = ports[0] + nccl_port = ports[1] + # Build command + tp_size = self.alloc_mode.gen_tp_size + cmd = SGLangConfig.build_cmd( + sglang_config=self.config, + model_path=self.args.rollout.model_path, + tp_size=tp_size, + base_gpu_id=self.base_gpu_id, + dist_init_addr=f"{host}:{nccl_port}", + served_model_name=self.service_config.served_model_name, + skip_tokenizer_init=False, + ) + # Launch process + full_command = f"{cmd} --port {server_port}" + full_command = full_command.replace("\\\n", " ").replace("\\", " ") + self.process = subprocess.Popen( + full_command.split(), + text=True, + stdout=sys.stdout, + stderr=sys.stdout, + ) + # Create server info + self.server_info = LLMServerInfo( + server_id=self.server_id, + host=host, + port=server_port, + status="starting", + version=0, + ) + return self.server_info + + def check_health(self) -> bool: + """Check if the SGLang server is healthy.""" + if not self.server_info or not self.process: + return False + + # Check if process is still running + if self.process.poll() is not None: + return False + + try: + # Check server endpoint + base_url = f"http://{self.server_info.host}:{self.server_info.port}" + response = requests.get( + f"{base_url}/metrics", + timeout=30, + ) + if response.status_code != 200: + return False + # Update server load + for line in response.text.split("\n"): + if line.startswith("sglang:num_running_reqs"): + self.load = float(line.split(" ")[1]) + break + return True + except requests.exceptions.RequestException: + return False diff --git a/arealite/tests/data/rlvr_code_dataset.jsonl b/arealite/tests/data/rlvr_code_dataset.jsonl new file mode 100644 index 000000000..a6d1ed9a7 --- /dev/null +++ b/arealite/tests/data/rlvr_code_dataset.jsonl @@ -0,0 +1,5 @@ +{"prompt": "<\uff5cUser\uff5c>Takahashi has A cookies, and Aoki has B cookies.\nTakahashi will do the following action K times:\n - If Takahashi has one or more cookies, eat one of his cookies.\n - Otherwise, if Aoki has one or more cookies, eat one of Aoki's cookies.\n - If they both have no cookies, do nothing.\nIn the end, how many cookies will Takahashi and Aoki have, respectively?\n\n-----Constraints-----\n - 0 \\leq A \\leq 10^{12}\n - 0 \\leq B \\leq 10^{12}\n - 0 \\leq K \\leq 10^{12}\n - All values in input are integers.\n\n-----Input-----\nInput is given from Standard Input in the following format:\nA B K\n\n-----Output-----\nPrint the numbers of Takahashi's and Aoki's cookies after K actions.\n\n-----Sample Input-----\n2 3 3\n\n-----Sample Output-----\n0 2\n\nTakahashi will do the following:\n - He has two cookies, so he eats one of them.\n - Now he has one cookie left, and he eats it.\n - Now he has no cookies left, but Aoki has three, so Takahashi eats one of them.\nThus, in the end, Takahashi will have 0 cookies, and Aoki will have 2.\n<\uff5cAssistant\uff5c>\n", "question": "Takahashi has A cookies, and Aoki has B cookies.\nTakahashi will do the following action K times:\n - If Takahashi has one or more cookies, eat one of his cookies.\n - Otherwise, if Aoki has one or more cookies, eat one of Aoki's cookies.\n - If they both have no cookies, do nothing.\nIn the end, how many cookies will Takahashi and Aoki have, respectively?\n\n-----Constraints-----\n - 0 \\leq A \\leq 10^{12}\n - 0 \\leq B \\leq 10^{12}\n - 0 \\leq K \\leq 10^{12}\n - All values in input are integers.\n\n-----Input-----\nInput is given from Standard Input in the following format:\nA B K\n\n-----Output-----\nPrint the numbers of Takahashi's and Aoki's cookies after K actions.\n\n-----Sample Input-----\n2 3 3\n\n-----Sample Output-----\n0 2\n\nTakahashi will do the following:\n - He has two cookies, so he eats one of them.\n - Now he has one cookie left, and he eats it.\n - Now he has no cookies left, but Aoki has three, so Takahashi eats one of them.\nThus, in the end, Takahashi will have 0 cookies, and Aoki will have 2.", "query_id": "deepcoder_taco_754161b853b1e0b1c02ee9667cd50f8c", "id": "deepcoder_taco_754161b853b1e0b1c02ee9667cd50f8c", "starter_code": null, "input_output": "{\"inputs\": [\"2 3 3\\n\", \"500000000000 500000000000 1000000000000\\n\", \"500000000000 500000000001 1000000000000\\n\", \"500000000000 499999999999 1000000000000\\n\", \"0 0 0\\n\", \"0 0 1000000000000\\n\", \"0 1000000000000 0\\n\", \"0 1000000000000 1000000000000\\n\", \"1000000000000 0 0\\n\", \"1000000000000 0 1000000000000\\n\", \"1000000000000 1000000000000 0\\n\", \"1000000000000 1000000000000 1000000000000\\n\", \"999664720736 99150401673 9177110689\\n\", \"999553244087 9473760141 99451169880\\n\", \"99818575601 999284522381 9657141929\\n\", \"99374288514 9551434405 999519154734\\n\", \"9352360840 999532174388 99550343731\\n\", \"9839285289 99663130322 999454076321\\n\", \"500000000000 274588616851 1000000000000\", \"2 3 1\", \"2 3 2\", \"2 3 0\", \"2 6 0\", \"4 6 0\", \"8 6 0\", \"10 6 0\", \"498881727339 49011419508 0000000100000\", \"0 6 0\", \"498881727339 49011419508 0000001100000\", \"0 9 0\", \"498881727339 50858735181 0000001100000\", \"498881727339 64133142411 0000001100000\", \"498881727339 86551688904 0000001100000\", \"498881727339 86551688904 0000000100000\", \"498881727339 88952939337 0000000100000\", \"498881727339 88952939337 0000010100000\", \"623031139973 88952939337 0000010100000\", \"623031139973 53371298929 0000010100000\", \"488984103352 53371298929 0000010100000\", \"488984103352 28853843669 0000010100000\", \"488984103352 44514766879 0000010100000\", \"488984103352 24648370780 0000010100000\", \"488984103352 11555398095 0000010100000\", \"551061056476 11555398095 0000010100000\", \"551061056476 2744039490 0000010100000\", \"551061056476 2744039490 0000010100100\", \"551061056476 2744039490 0010010100100\", \"690068676694 2744039490 0010010100100\", \"810189303696 2744039490 0010010100100\", \"810189303696 304656088 0010010100100\", \"810189303696 304656088 0010010101100\", \"810189303696 559952381 0010010101100\", \"878339967763 559952381 0010010101100\", \"878339967763 559952381 0010010101101\", \"1138880556104 559952381 0010010101101\", \"1138880556104 516416634 0010010101101\", \"637172716773 516416634 0010010101101\", \"1127076073437 516416634 0010010101101\", \"1127076073437 516416634 0010010100101\", \"1127076073437 516416634 0010010100001\", \"1485644197907 516416634 0010010100001\", \"40795666163 516416634 0010010100001\", \"40795666163 581054487 0010010100001\", \"40795666163 581054487 0010010110001\", \"22565061889 581054487 0010010110001\", \"22565061889 581054487 0010110110001\", \"18064166595 581054487 0010110110001\", \"9875757521 933260677 0010110110000\", \"1547869 73650225 0000000111100\", \"204065 73650225 0000000111100\", \"362104 73650225 0000000111100\", \"362104 24404642 0000000111100\", \"362104 16911655 0000000111100\", \"283244830608 274588616851 1000000000000\", \"498881727339 274588616851 1000000000000\", \"498881727339 92988479429 1000000000000\", \"498881727339 92988479429 1010000000000\", \"498881727339 92988479429 1000000100000\", \"498881727339 49011419508 1000000100000\", \"0 0 0\", \"-1 0 0\", \"-1 0 1\", \"-2 0 1\", \"0 0 1\", \"0 -1 1\", \"0 -1 2\", \"4188594464 581054487 0010110110001\", \"7578113331 581054487 0010110110001\", \"7578113331 581054487 0010110110000\", \"7578113331 933260677 0010110110000\", \"9875757521 933260677 1010110110000\", \"9875757521 933260677 1010110100000\", \"9875757521 933260677 1110110100000\", \"9875757521 933260677 1110110100001\", \"8901212864 933260677 1110110100001\", \"8901212864 933260677 1110010100001\", \"9957043226 933260677 1110010100001\", \"9957043226 933260677 1110110100001\", \"13071060683 933260677 1110110100001\", \"13071060683 933260677 1010110100001\", \"24844196728 933260677 1010110100001\", \"24844196728 933260677 1010010100001\", \"46282745761 933260677 1010010100001\", \"46282745761 933260677 1010010100011\", \"46282745761 1063903834 1010010100011\", \"46282745761 1953749145 1010010100011\", \"46282745761 3894888655 1010010100011\", \"82916368507 3894888655 1010010100011\", \"125508515642 3894888655 1010010100011\", \"250272688906 3894888655 1010010100011\", \"250272688906 5812721198 1010010100011\", \"250272688906 9468746842 1010010100011\", \"143395449423 9468746842 1010010100011\", \"223869749260 9468746842 1010010100011\", \"223869749260 13514692621 1010010100011\", \"223869749260 13514692621 1010010101011\", \"223869749260 13514692621 1011010101011\", \"223869749260 13514692621 1011010101111\", \"500000000000 500000000000 1000000000000\", \"2 3 3\"], \"outputs\": [\"0 2\\n\", \"0 0\\n\", \"0 1\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 1000000000000\\n\", \"0 0\\n\", \"1000000000000 0\\n\", \"0 0\\n\", \"1000000000000 1000000000000\\n\", \"0 1000000000000\\n\", \"990487610047 99150401673\\n\", \"900102074207 9473760141\\n\", \"90161433672 999284522381\\n\", \"0 0\\n\", \"0 909334191497\\n\", \"0 0\\n\", \"0 0\\n\", \"1 3\\n\", \"0 3\\n\", \"2 3\\n\", \"2 6\\n\", \"4 6\\n\", \"8 6\\n\", \"10 6\\n\", \"498881627339 49011419508\\n\", \"0 6\\n\", \"498880627339 49011419508\\n\", \"0 9\\n\", \"498880627339 50858735181\\n\", \"498880627339 64133142411\\n\", \"498880627339 86551688904\\n\", \"498881627339 86551688904\\n\", \"498881627339 88952939337\\n\", \"498871627339 88952939337\\n\", \"623021039973 88952939337\\n\", \"623021039973 53371298929\\n\", \"488974003352 53371298929\\n\", \"488974003352 28853843669\\n\", \"488974003352 44514766879\\n\", \"488974003352 24648370780\\n\", \"488974003352 11555398095\\n\", \"551050956476 11555398095\\n\", \"551050956476 2744039490\\n\", \"551050956376 2744039490\\n\", \"541050956376 2744039490\\n\", \"680058576594 2744039490\\n\", \"800179203596 2744039490\\n\", \"800179203596 304656088\\n\", \"800179202596 304656088\\n\", \"800179202596 559952381\\n\", \"868329866663 559952381\\n\", \"868329866662 559952381\\n\", \"1128870455003 559952381\\n\", \"1128870455003 516416634\\n\", \"627162615672 516416634\\n\", \"1117065972336 516416634\\n\", \"1117065973336 516416634\\n\", \"1117065973436 516416634\\n\", \"1475634097906 516416634\\n\", \"30785566162 516416634\\n\", \"30785566162 581054487\\n\", \"30785556162 581054487\\n\", \"12554951888 581054487\\n\", \"12454951888 581054487\\n\", \"7954056594 581054487\\n\", \"0 698908198\\n\", \"1436769 73650225\\n\", \"92965 73650225\\n\", \"251004 73650225\\n\", \"251004 24404642\\n\", \"251004 16911655\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\\n\", \"0 0\", \"0 2\"], \"remote\": false}", "task": "code", "language": "PYTHON", "solutions": ["(a, b, k) = map(int, input().split())\nprint(max(a - k, 0), max(b - max(k - a, 0), 0))\n", "(a, b, k) = map(int, input().split())\nnum = min(a, k)\na -= num\nk -= num\nprint(a, max(b - k, 0))\n", "(A, B, K) = map(int, input().split())\nprint(max(0, A - K), max(0, min(B, A + B - K)))\n", "(A, B, K) = map(int, input().split())\nt = max(0, A - K)\na = max(0, A + B - K - t)\nprint(t, a)\n", "(A, B, K) = map(int, input().split())\na = max(A - K, 0)\nK -= A - a\nb = max(B - K, 0)\nprint(a, b)\n", "(a, b, k) = map(int, input().split())\nif a >= k:\n\tprint(a - k, b)\nelse:\n\tprint(0, max(b - (k - a), 0))\n", "(a, b, k) = map(int, input().split())\nif a < k:\n\tb -= k - a\n\ta = 0\n\tif b < 0:\n\t\tb = 0\nelif a - k > 0:\n\ta -= k\nelse:\n\ta = 0\nprint(a, b)\n", "a = input()\na = a.split()\nk = int(a[2])\nb = int(a[1])\na = int(a[0])\nd = b\nc = a - k\nif c < 0:\n\tc = 0\nif c == 0:\n\td = b - k + a\nif d < 0:\n\td = 0\nprint(c, d)\n", "(a, b, k) = map(int, input().split())\nc = max(a - k, 0)\nk -= a - c\nd = max(b - k, 0)\nprint(c, d)\n", "(a, b, c) = map(int, input().split())\naa = max(0, a - c)\nprint(aa, max(0, a + b - aa - c))\n", "(a, b, n) = map(int, input().split())\nprint(max(a - n, 0), max(b - max(n - a, 0), 0))\n", "(a, b, k) = map(int, input().split())\nprint(str(max(a - k, 0)) + ' ' + str(max(min(b, a + b - k), 0)))\n", "(a, b, k) = map(int, input().split())\nta = a - min(a, k)\nk -= a - ta\nb -= min(b, k)\nprint(ta, b)\n", "(a, b, k) = map(int, input().split())\nn = min(a, k)\nk -= n\nprint(a - n, max(0, b - k))\n", "(t, a, k) = map(int, input().split())\nT = max(t - k, 0)\nk -= t - T\nA = max(a - k, 0)\nprint(T, A)\n", "(A, B, K) = map(int, input().split())\nk = max(K - A, 0)\nprint(max(A - K, 0), max(B - k, 0))\n", "(a, b, k) = map(int, input().split())\nremain = a - k\ntakahashi = remain\naoki = b\nif remain < 0:\n\ttakahashi = 0\n\taoki = b + remain\nprint(str(takahashi) + ' ' + str(aoki if aoki > 0 else 0))\n", "(a, b, k) = map(int, input().split())\nA = max(a - k, 0)\nB = max(b - max(k - a, 0), 0)\nprint(A, B)\n", "(a, b, c) = input().split()\na = int(a)\nb = int(b)\nc = int(c)\nif a >= c:\n\tprint(a - c, b)\nif a < c <= a + b:\n\tprint(0, a + b - c)\nif c > a + b:\n\tprint(0, 0)\n", "(a, b, k) = list(map(int, input().split()))\nprint(max(0, a - k), max(0, min(b, a + b - k)))\n", "def slove():\n\t(A, B, K) = map(int, input().split())\n\tif A > K:\n\t\tprint(A - K, B)\n\telse:\n\t\tprint(0, max(0, -K + (A + B)))\n\ndef __starting_point():\n\tslove()\n__starting_point()\n", "(a, b, k) = map(int, input().split())\neat1 = min(a, k)\nk -= eat1\neat2 = min(b, k)\nprint(a - eat1, b - eat2)\n", "(a, b, k) = (int(i) for i in input().split())\ne = min(a, k)\na -= e\nk -= e\nb -= min(b, k)\nprint(a, b)\n", "(a, b, k) = (int(x) for x in input().split())\nc = a\nif a + b <= k:\n\ta = 0\n\tb = 0\nelif a <= k:\n\ta = 0\n\tb = b + c - k\nelif a > k:\n\ta = a - k\nc = str(a) + ' ' + str(b)\nprint(c)\n", "(a, b, k) = map(int, input().split())\nta = max(0, a - k)\nk -= a - ta\ntb = max(0, b - k)\nprint(ta, tb)\n", "(t, a, k) = map(int, input().split())\nt = t - k\nif t < 0:\n\ta += t\n\tt = 0\nif a < 0:\n\ta = 0\nprint(t, a)\n", "(a, b, k) = map(int, input().split())\nm = min(a, k)\na -= m\nk -= m\nb -= k\nprint(a, max(b, 0))\n", "(A, B, K) = map(int, input().split())\nx = min(A, K)\ny = min(B, K - x)\nprint(A - x, B - y)\n", "(A, B, K) = list(map(int, input().split()))\nk = min(A, K)\nprint(A - k, max(0, B - (K - k)))\n", "(A, B, K) = map(int, input().split())\nprint(max(0, A - K), max(0, B - abs(max(0, K - A))))\n", "import numpy as np\n(A, B, K) = map(int, input().split())\nif K > A:\n\ta = 0\n\tB = max(B - (K - A), 0)\nelse:\n\ta = A - K\nprint(a, B)\n", "[A, B, K] = list(map(int, input().split()))\nX = max(0, A - K)\nY = max(0, min(B, B - (K - A)))\nprint(X, Y)\n", "(A, B, K) = map(int, input().split())\nprint(A - min(A, K), B - min(B, K - min(A, K)))\n", "(a, b, k) = map(int, input().split())\ny = min(a, k)\nz = min(b, k - y)\nprint(a - y, b - z)\n", "(a, b, k) = map(int, input().split())\nt = min(a, k)\na -= t\nk -= t\nu = min(b, k)\nb -= u\nprint(a, b)\n", "(A, B, K) = map(int, input().split())\nprint(max(0, A - K), end=' ')\nK = max(0, K - A)\nprint(max(0, B - K))\n", "(a, b, k) = map(int, input().split())\nra = max(a - k, 0)\nif k > a:\n\tb = max(b - (k - a), 0)\nprint(ra, b)\n", "(a, b, k) = map(int, input().split())\ntakahashi = [0, a - k]\naoki = [0, b + min(takahashi)]\nprint('{takahashi} {aoki}'.format(takahashi=max(takahashi), aoki=max(aoki)))\n", "(a, b, k) = map(int, input().split())\naa = max(0, a - k)\nb = b + min(0, a - k)\nprint(aa, max(b, 0))\n", "(a, b, k) = map(int, input().split())\nx = min(a, k)\na -= x\nk -= x\ny = min(b, k)\nb -= y\nprint(a, b)\n", "(a, b, c) = map(int, input().split())\nx = max(0, a - c)\nbb = max(0, c - a)\nprint(x, max(0, b - bb))\n", "(A, B, K) = map(int, input().split())\nT = max(0, A - K)\nprint(T, max(0, B - K + A - T))\n", "(A, B, K) = map(int, input().split())\nd = A - K\nif A + B >= K:\n\tif d < 0:\n\t\tprint(0, B + d)\n\telse:\n\t\tprint(d, B)\nelse:\n\tprint(0, 0)\n", "(a, b, k) = map(int, input().split())\naa = max(a - k, 0)\nrest = max(k - a, 0)\nbb = max(b - rest, 0)\nprint(f'{aa} {bb}')\n", "(a, b, c) = map(int, input().split())\nrem = min(a, c)\na -= rem\nc -= rem\nrem = min(b, c)\nb -= rem\nprint(a, b)\n", "(a, b, k) = [int(i) for i in input().split()]\nif a <= k:\n\tb = b - k + a if b - k + a >= 0 else 0\n\ta = 0\nelse:\n\ta -= k\nprint(a, b)\n", "(A, B, K) = map(int, input().split())\nrA = max(0, A - K)\nrB = max(0, B - (K - (A - rA)))\nprint(rA, rB)\n", "(A, B, K) = map(int, input().split())\ntemp = A\nif K >= A:\n\tK = K - A\n\tA = 0\nelse:\n\tA = A - K\n\tK = 0\nif B >= K:\n\tB = B - K\nelse:\n\tB = 0\nprint('{} {}'.format(A, B))\n", "(a, b, k) = [int(x) for x in input().split()]\nif a > k:\n\tprint(a - k, b)\nelse:\n\tprint(0, max(0, a + b - k))\n", "(a, b, k) = map(int, input().split())\ntaka = a - k\nif taka < 0:\n\taoki = b - abs(taka)\n\ttaka = 0\n\tif aoki < 0:\n\t\taoki = 0\nelse:\n\taoki = b\nprint(taka, aoki)\n", "(a, b, n) = map(int, input().split())\nif a >= n:\n\ta -= n\n\tprint(a, b)\nelif a < n:\n\tn -= a\n\tif n >= b:\n\t\tprint(0, 0)\n\telse:\n\t\tprint(0, b - n)\n", "(a, b, k) = map(int, input().split())\nprint(' '.join(map(str, [max(0, a - k), max(0, b - max(0, k - a))])))\n", "(T, A, K) = map(int, input().split())\nS = T\nT -= min(T, K)\nK = K - (S - T)\nA -= min(A, K)\nprint(T, A)\n", "import sys\n(a, b, k) = map(int, input().split())\nif a >= k:\n\tprint(a - k, b)\n\treturn\nif b >= k - a:\n\tprint(0, b - (k - a))\n\treturn\nprint(0, 0)\n", "(A, B, K) = map(int, input().split())\n\ndef sum():\n\tnonlocal A, B, K\n\tif A >= K:\n\t\treturn (A - K, B)\n\telse:\n\t\ta = K - A\n\t\tA = 0\n\t\tif B >= a:\n\t\t\treturn (A, B - a)\n\t\telse:\n\t\t\treturn (A, 0)\nprint(*sum())\n", "(A, B, K) = map(int, input().split())\n(A, rem) = (max(A - K, 0), max(K - A, 0))\nB = max(B - rem, 0)\nprint(A, B)\n"], "verify": 1, "wrong_type": null} +{"prompt": "<\uff5cUser\uff5c>Write a program which prints $n$-th fibonacci number for a given integer $n$. The $n$-th fibonacci number is defined by the following recursive formula:\n\n\\begin{equation*} fib(n)= \\left \\\\{ \\begin{array}{ll} 1 & (n = 0) \\\\\\ 1 & (n = 1) \\\\\\ fib(n - 1) + fib(n - 2) & \\\\\\ \\end{array} \\right. \\end{equation*}\n\nConstraints\n\n* $0 \\leq n \\leq 44$\n\nInput\n\nAn integer $n$ is given.\n\nExample\n\nInput\n\n3\n\n\nOutput\n\n3\n<\uff5cAssistant\uff5c>\n", "question": "Write a program which prints $n$-th fibonacci number for a given integer $n$. The $n$-th fibonacci number is defined by the following recursive formula:\n\n\\begin{equation*} fib(n)= \\left \\\\{ \\begin{array}{ll} 1 & (n = 0) \\\\\\ 1 & (n = 1) \\\\\\ fib(n - 1) + fib(n - 2) & \\\\\\ \\end{array} \\right. \\end{equation*}\n\nConstraints\n\n* $0 \\leq n \\leq 44$\n\nInput\n\nAn integer $n$ is given.\n\nExample\n\nInput\n\n3\n\n\nOutput\n\n3", "query_id": "deepcoder_taco_d46ca8a56178d77918f5a864bdf4de7c", "id": "deepcoder_taco_d46ca8a56178d77918f5a864bdf4de7c", "starter_code": null, "input_output": "{\"inputs\": [\"1\", \"2\", \"4\", \"8\", \"6\", \"15\", \"7\", \"24\", \"14\", \"5\", \"13\", \"11\", \"22\", \"9\", \"29\", \"12\", \"10\", \"41\", \"16\", \"30\", \"21\", \"28\", \"17\", \"34\", \"18\", \"38\", \"31\", \"44\", \"36\", \"32\", \"33\", \"23\", \"19\", \"27\", \"20\", \"37\", \"25\", \"26\", \"42\", \"43\", \"35\", \"39\", \"40\", \"001\", \"6\", \"12\", \"17\", \"24\", \"31\", \"5\", \"2\", \"4\", \"15\", \"30\", \"8\", \"1\", \"10\", \"11\", \"16\", \"9\", \"7\", \"13\", \"27\", \"25\", \"14\", \"19\", \"44\", \"22\", \"23\", \"43\", \"21\", \"36\", \"20\", \"18\", \"29\", \"38\", \"32\", \"35\", \"33\", \"40\", \"28\", \"41\", \"26\", \"34\", \"39\", \"42\", \"37\", \"001\", \"3\"], \"outputs\": [\"1\\n\", \"2\\n\", \"5\\n\", \"34\\n\", \"13\\n\", \"987\\n\", \"21\\n\", \"75025\\n\", \"610\\n\", \"8\\n\", \"377\\n\", \"144\\n\", \"28657\\n\", \"55\\n\", \"832040\\n\", \"233\\n\", \"89\\n\", \"267914296\\n\", \"1597\\n\", \"1346269\\n\", \"17711\\n\", \"514229\\n\", \"2584\\n\", \"9227465\\n\", \"4181\\n\", \"63245986\\n\", \"2178309\\n\", \"1134903170\\n\", \"24157817\\n\", \"3524578\\n\", \"5702887\\n\", \"46368\\n\", \"6765\\n\", \"317811\\n\", \"10946\\n\", \"39088169\\n\", \"121393\\n\", \"196418\\n\", \"433494437\\n\", \"701408733\\n\", \"14930352\\n\", \"102334155\\n\", \"165580141\\n\", \"1\\n\", \"13\\n\", \"233\\n\", \"2584\\n\", \"75025\\n\", \"2178309\\n\", \"8\\n\", \"2\\n\", \"5\\n\", \"987\\n\", \"1346269\\n\", \"34\\n\", \"1\\n\", \"89\\n\", \"144\\n\", \"1597\\n\", \"55\\n\", \"21\\n\", \"377\\n\", \"317811\\n\", \"121393\\n\", \"610\\n\", \"6765\\n\", \"1134903170\\n\", \"28657\\n\", \"46368\\n\", \"701408733\\n\", \"17711\\n\", \"24157817\\n\", \"10946\\n\", \"4181\\n\", \"832040\\n\", \"63245986\\n\", \"3524578\\n\", \"14930352\\n\", \"5702887\\n\", \"165580141\\n\", \"514229\\n\", \"267914296\\n\", \"196418\\n\", \"9227465\\n\", \"102334155\\n\", \"433494437\\n\", \"39088169\\n\", \"1\\n\", \"3\"], \"remote\": false}", "task": "code", "language": "PYTHON", "solutions": ["n = int(input())\nnum = [1, 1]\nfor i in range(2, 45):\n\tf = num[i - 1] + num[i - 2]\n\tnum.append(f)\nprint(num[n])\n", "n = int(input())\nfib_b = [1, 1]\nif n > 1:\n\tfor i in range(2, n + 1):\n\t\tfib_b.append(fib_b[i - 1] + fib_b[i - 2])\nprint(fib_b[n])\n", "n = int(input())\na = b = 1\nwhile n:\n\t(a, b) = (b, a + b)\n\tn -= 1\nprint(a)\n", "n = int(input())\nfib = [1, 1]\nfor i in range(2, n + 1):\n\tfib.append(fib[-1] + fib[-2])\nprint(fib[-1])\n", "n = int(input())\na1 = 1\na = 1\ni = 1\nwhile i < n:\n\t(a1, a) = (a, a1 + a)\n\ti += 1\nprint(a)\n", "n = int(input())\ndp = [0] * (n + 1)\ndp[0] = 1\ndp[1] = 1\nfor i in range(2, n + 1):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[-1])\n", "n = int(input())\nfib = [1, 1, 2]\nfor i in range(3, n + 1):\n\ttmp = fib[i - 1] + fib[-2]\n\tfib.append(tmp)\nprint(fib[n])\n", "n = int(input())\ndp = [1, 1]\nif n > 1:\n\tfor i in range(2, n + 1):\n\t\tdp.append(dp[i - 1] + dp[i - 2])\nprint(dp[n])\n", "a = [1, 1]\n\ndef fib(n):\n\ttry:\n\t\treturn a[n]\n\texcept:\n\t\ta.append(fib(n - 2) + fib(n - 1))\n\t\treturn a[n]\nn = int(input())\nprint(fib(n))\n", "N = int(input())\ndp = [1] * (N + 1)\nfor n in range(2, N + 1):\n\tdp[n] = dp[n - 1] + dp[n - 2]\nprint(dp[N])\n", "n = int(input())\nfib = [-1] * (n + 1)\n(fib[0], fib[1]) = (1, 1)\nfor i in range(2, n + 1):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nprint(fib[-1])\n", "n = int(input())\narr = [1, 1]\nfor i in range(2, 45):\n\tarr.append(arr[i - 1] + arr[i - 2])\nprint(arr[n])\n", "n = int(input())\n\ndef fib(n):\n\ta = 1\n\tb = 1\n\tfor _ in range(n):\n\t\t(a, b) = (b, a + b)\n\treturn a\nprint(fib(n))\n", "def fb(num):\n\t(a, b) = (1, 0)\n\tfor _ in range(num):\n\t\t(a, b) = (a + b, a)\n\treturn b\nn = int(input())\nprint(fb(n + 1))\n", "def fib(n):\n\tfib = [1, 1]\n\tfor i in range(2, n + 1):\n\t\tfib.append(fib[i - 2] + fib[i - 1])\n\treturn fib[n]\nn = int(input())\nprint(fib(n))\n", "l = [1, 1]\nn = int(input())\nfor i in range(n):\n\tfib = [l[i] + l[i + 1]]\n\tl = l + fib\nprint(l[n])\n", "(a, b) = (1, 1)\nfor i in range(int(input())):\n\t(a, b) = (b, a + b)\nprint(a)\n", "n = int(input()) + 1\ndp = [1] * n\nfor i in range(n):\n\tif i < 2:\n\t\tcontinue\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[n - 1])\n", "n = int(input())\nif n <= 1:\n\tprint(1)\n\texit()\nx = [1, 1]\nfor i in range(n - 1):\n\tx.append(x[-1] + x[-2])\nprint(x[-1])\n", "a = 1\nb = 1\nlist = []\nn = int(input())\nlist.append(a)\nlist.append(b)\nfor i in range(n):\n\tlist.append(a + b)\n\td = b\n\tb += a\n\ta = d\nprint(list[n])\n", "N = int(input())\ndp = [0] * 50\nfor i in range(N + 1):\n\tif i < 2:\n\t\tdp[i] = 1\n\telse:\n\t\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[N])\n", "n = int(input())\nn1 = 1\nn2 = 1\nfor i in range(n):\n\t(n1, n2) = (n2, n1 + n2)\nprint(n1)\n", "n = int(input())\nl = [1, 1]\nfor i in range(n - 1):\n\tl.append(l[-1] + l[-2])\nprint(l[n])\n", "fib = [0] * 45\nfib[0] = 1\nfib[1] = 1\nn = int(input())\nfor i in range(2, 45):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nprint(fib[n])\n", "n = int(input())\nDP = [0] * 45\n(DP[0], DP[1]) = (1, 1)\nfor i in range(2, 45):\n\tDP[i] = DP[i - 1] + DP[i - 2]\nprint(DP[n])\n", "n = int(input())\nfib = [0] * (n + 1)\n(fib[0], fib[1]) = (1, 1)\nfor i in range(2, n + 1):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nprint(fib[n])\n", "fib = [0] * 45\nfib[0] = fib[1] = 1\nfor i in range(2, 45):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nprint(fib[int(input())])\n", "N = int(input())\nD = [0] * (N + 1)\nD[0] = 1\nD[1] = 1\nfor i in range(2, N + 1):\n\tD[i] = D[i - 1] + D[i - 2]\nprint(D[N])\n", "n = int(input())\nF = [0] * (n + 1)\nF[0] = 1\nF[1] = 1\nfor i in range(2, len(F)):\n\tF[i] = F[i - 1] + F[i - 2]\nprint(F[n])\n", "N = int(input())\ndp = [0] * (N + 1)\ndp[0] = 1\ndp[1] = 1\nfor i in range(2, N + 1):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nans = dp[N]\nprint(ans)\n", "n = int(input())\ndp = [0] * (n + 1)\nfor i in range(n + 1):\n\tif i <= 1:\n\t\tdp[i] = 1\n\t\tcontinue\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[-1])\n", "f = [1, 1] + [0] * (int(input()) - 1)\nfor i in range(2, len(f)):\n\tf[i] = f[i - 2] + f[i - 1]\nprint(f[-1])\n", "n = int(input())\nf = [0] * 45\nf[0] = 1\nf[1] = 1\nfor i in range(2, 45):\n\tf[i] = f[i - 1] + f[i - 2]\nprint(f[n])\n", "n = int(input())\nfi = [1] * 50\nfor i in range(2, n + 1):\n\tfi[i] = fi[i - 1] + fi[i - 2]\nprint(fi[n])\n", "n = int(input())\ndp = [0 for i in range(1000)]\ndp[0] = 1\ndp[1] = 1\nfor i in range(2, n + 1):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[n])\n", "n = int(input())\nfirst = 1\nsecond = 1\nfor i in range(n - 1):\n\t(second, first) = (first + second, second)\nprint(second)\n", "N = int(input())\ndp = [0] * (N + 1)\ndp[0] = 1\ndp[1] = 1\nfor i in range(2, N + 1):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[N])\n", "n = int(input())\nf = [1, 1]\nfor i in range(2, n + 1):\n\tnf = f[i - 2] + f[i - 1]\n\tf.append(nf)\nprint(f[n])\n", "n = int(input())\ndp = [0] * (n + 10)\ndp[0] = 1\ndp[1] = 1\nfor i in range(2, n + 1):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[n])\n", "n = int(input())\nfib = [1] * 45\nfor i in range(2, n + 1):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nprint(fib[n])\n", "n = int(input())\ndp = [1, 1]\nfor i in range(n):\n\tdp.append(dp[i] + dp[i + 1])\nprint(dp[n])\n", "fib = [0 for i in range(45)]\nfib[0] = 1\nfib[1] = 1\nfor i in range(2, 45):\n\tfib[i] = fib[i - 1] + fib[i - 2]\na = int(input())\nprint(fib[a])\n", "def Fibonacci(n, a1, a2):\n\tif n < 1:\n\t\treturn a1\n\treturn Fibonacci(n - 1, a1 + a2, a1)\nprint(Fibonacci(int(input()), 1, 0))\n", "def fibonacci(n):\n\t(a, b) = (1, 0)\n\tfor _ in range(0, n):\n\t\t(a, b) = (b, a + b)\n\treturn b\nn = int(input())\nn += 1\nprint(fibonacci(n))\n", "n = int(input())\nf = [0] * 45\nf[0] = f[1] = 1\nfor i in range(2, 45):\n\tf[i] = f[i - 2] + f[i - 1]\nprint(f[n])\n", "(a, b) = (1, 1)\nn = int(input())\nfor i in range(n):\n\t(a, b) = (b, a + b)\nprint(a)\n", "n = int(input())\nx = 0\ny = 1\nif n == 0 or n == 1:\n\tprint(1)\nelse:\n\tfor i in range(n):\n\t\t(x, y) = (y, x + y)\n\t\ti += 1\n\tprint(y)\n", "n = int(input())\na = b = 1\nfor i in range(n):\n\t(a, b) = (b, a + b)\n\tn -= 1\nprint(a)\n", "n = int(input())\nfib = [0 for i in range(n + 1)]\nfib[0] = 1\nfib[1] = 1\nfor i in range(2, n + 1):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nprint(fib[n])\n", "n = int(input())\nA = []\nA.append(1)\nA.append(1)\nif n > 1:\n\tfor i in range(n - 1):\n\t\tA.append(A[i] + A[i + 1])\nprint(A[n])\n", "def fib(n):\n\tn1 = n2 = tmp = 1\n\tfor _ in range(n - 1):\n\t\ttmp = n1 + n2\n\t\t(n1, n2) = (n2, tmp)\n\treturn tmp\nprint(fib(int(input())))\n", "n = int(input())\ndp = [1, 1]\nfor i in range(2, n + 1):\n\tdp = [dp[-1], dp[-2] + dp[-1]]\nprint(dp[-1])\n", "n = int(input())\na = 1\nb = 1\nif n == 0 or n == 1:\n\tprint(1)\nelse:\n\tfor i in range(n):\n\t\t(a, b) = (b, a + b)\n\tprint(a)\n", "N = int(input())\ndp = {0: 1, 1: 1}\nfor i in range(1, N):\n\tdp[i + 1] = dp[i] + dp[i - 1]\nprint(dp[N])\n", "N = int(input())\nL = [0 for i in range(N + 1)]\nL[0] = 1\nL[1] = 1\nfor i in range(2, N + 1):\n\tL[i] = L[i - 1] + L[i - 2]\nprint(L[N])\n", "n = int(input())\nL = [0] * (n + 1)\nL[0] = 1\nL[1] = 1\nfor i in range(n - 1):\n\tL[i + 2] = L[i + 1] + L[i]\nprint(L[n])\n", "F = [1, 1] + [0 for i in range(44)]\nn = int(input())\nfor i in range(2, n + 1):\n\tF[i] = F[i - 1] + F[i - 2]\nprint(F[n])\n", "n = int(input())\nmemo = [1] * (n + 1)\nfor i in range(2, n + 1):\n\tmemo[i] = memo[i - 1] + memo[i - 2]\nprint(memo[n])\n", "n = int(input())\ndp = [0] * 45\ndp[:2] = [1, 1]\nfor i in range(2, n + 1):\n\tdp[i] = dp[i - 2] + dp[i - 1]\nprint(dp[n])\n", "def fib(n):\n\tx = [1, 1]\n\tfor i in range(n + 1):\n\t\ty = x[i] + x[i + 1]\n\t\tx.append(y)\n\tprint(x[n])\na = int(input())\nfib(a)\n", "n = int(input())\nif n < 2:\n\tans = 1\nelse:\n\ta = [0] * (n + 1)\n\ta[0] = a[1] = 1\n\tfor i in range(2, n + 1):\n\t\ta[i] = a[i - 1] + a[i - 2]\n\tans = a[n]\nprint(ans)\n", "n = int(input())\ndp = [None] * (n + 1)\ndp[0] = 1\ndp[1] = 1\nfor i in range(n - 1):\n\tdp[i + 2] = dp[i] + dp[i + 1]\nprint(dp[n])\n", "y = [1, 1]\nn = int(input())\nfor i in range(0, n):\n\ta = y[i] + y[i + 1]\n\ty.append(a)\nprint(y[n])\n", "def fibonacci(n):\n\ta = [1] * 2\n\tfor i in range(2, n + 1):\n\t\t(a[0], a[1]) = (a[1], a[0] + a[1])\n\treturn a[1]\nn = int(input())\nprint(fibonacci(n))\n", "n = int(input())\npair = [1, 1]\nfor i in range(n - 1):\n\tpair[i % 2] = sum(pair)\nprint(pair[n % 2])\n", "n = int(input())\na = 1\nb = 1\nlist = []\nlist.append(a)\nlist.append(b)\nfor i in range(n):\n\tlist.append(a + b)\n\td = b\n\tb += a\n\ta = d\nprint(list[n])\n", "n = int(input())\nfib = [-1] * (n + 1)\nfib[0] = 1\nfib[1] = 1\nfor i in range(2, n + 1):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nprint(fib[n])\n", "n = int(input())\nfib = [1, 1]\nfor i in range(2, n + 1):\n\ta = fib[i - 1] + fib[i - 2]\n\tfib.append(a)\nprint(fib[n])\n", "Fib = [1, 1]\nfor i in range(2, 45):\n\tFib.append(Fib[i - 1] + Fib[i - 2])\nn = int(input())\nprint(Fib[n])\n", "n = int(input())\na = 1\nb = 1\ni = 0\nwhile i < n:\n\t(a, b) = (b, a + b)\n\ti += 1\nprint(a)\n", "fib = [0 for i in range(45)]\nfib[0] = 1\nfib[1] = 1\nfor i in range(2, 45):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nn = int(input())\nprint(fib[n])\n", "n = int(input())\nfib = [1, 1]\nfor i in range(2, n + 1):\n\tfib.append(fib[i - 2] + fib[i - 1])\nprint(fib[n])\n", "def resolve():\n\tn = int(input())\n\tA = [1, 1]\n\tfor i in range(1, 45):\n\t\tA.append(A[i - 1] + A[i])\n\tprint(A[n])\nresolve()\n", "n = int(input())\n\ndef Fib(n):\n\t(a, b) = (0, 1)\n\tfor i in range(n):\n\t\t(a, b) = (b, a + b)\n\treturn b\nprint(Fib(n))\n", "N = int(input())\nfib = [1, 1]\nfor i in range(50):\n\tfib.append(fib[-1] + fib[-2])\nprint(fib[N])\n", "n = int(input())\nn_L = [0] * 45\nn_L[0] = 1\nn_L[1] = 1\nfor i in range(2, n + 1):\n\tn_L[i] = n_L[i - 1] + n_L[i - 2]\nprint(n_L[n])\n", "N = int(input())\none = 1\ntwo = 1\nfor _ in range(1, N):\n\tthree = one + two\n\t(one, two) = (two, three)\nprint(two)\n", "n = int(input())\nfib = [1, 1]\nfor i in range(2, n + 1):\n\tfib.append(fib[i - 1] + fib[i - 2])\nprint(fib.pop())\n", "N = int(input())\nDP = []\nDP.append(1)\nDP.append(1)\nfor i in range(2, N + 1):\n\tDP.append(DP[i - 1] + DP[i - 2])\nprint(DP[N])\n", "(p, n) = ((1 + 5 ** 0.5) / 2, int(input()) + 1)\nprint(int((p ** n - (1 - p) ** n) / 5 ** 0.5))\n", "x = [1, 1]\nn = int(input())\nfor i in range(n):\n\ta = [x[i] + x[i + 1]]\n\tx = x + a\nprint(x[n])\n", "n = int(input())\nn += 1\ndp = [1 for i in range(n)]\nfor i in range(2, n):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[-1])\n", "n = int(input())\nf = [0] * (n + 3)\n(f[0], f[1]) = (1, 1)\nfor i in range(2, n + 1):\n\tf[i] = f[i - 1] + f[i - 2]\nprint(f[n])\n", "n = int(input())\na = [1, 1]\nif n < 2:\n\tpass\nelse:\n\tfor i in range(2, n + 1):\n\t\ta.append(a[i - 1] + a[i - 2])\nprint(a[n])\n", "n = int(input())\ndp = [0] * (n + 1)\ndp[0] = 1\ndp[1] = 1\nfor i in range(2, n + 1):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nans = dp[n]\nprint(ans)\n", "n = int(input())\nnum = [1, 1]\nfor i in range(43):\n\tb = num[-2] + num[-1]\n\tnum.append(b)\nprint(num[n])\n", "N = int(input())\na = b = 1\nwhile N:\n\t(a, b) = (b, a + b)\n\tN -= 1\nprint(a)\n", "n = int(input())\ndp = [0 for _ in range(46)]\ndp[0] = 1\ndp[1] = 1\nfor i in range(2, 46):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[n])\n", "lis = []\nlis.append(1)\nlis.append(1)\nfor i in range(2, 45):\n\tlis.append(lis[i - 1] + lis[i - 2])\nn = int(input())\nprint(lis[n])\n", "fib = [1 for i in range(45)]\nfor i in range(2, 45):\n\tfib[i] = fib[i - 1] + fib[i - 2]\ni = int(input())\nprint(fib[i])\n", "n = int(input())\nA = [1] * (n + 1)\nfor i in range(2, n + 1):\n\tA[i] = A[i - 1] + A[i - 2]\nprint(A[-1])\n", "n = int(input())\nmap = [0] * (n + 1)\nmap[0] = 1\nmap[1] = 1\nfor i in range(n - 1):\n\tmap[i + 2] = map[i] + map[i + 1]\nprint(map[-1])\n", "n = int(input())\nfib = [0] * (n + 1)\nfib[0] = 1\nif n > 0:\n\tfib[1] = 1\nfor i in range(2, n + 1):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nprint(fib[n])\n", "n = int(input())\ndp = [0] * (n + 1)\ndp[0] = 1\nfor i in range(1, n + 1):\n\tif i == 1:\n\t\tdp[1] = 1\n\telse:\n\t\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[n])\n", "ii = lambda : int(input())\nn = ii()\nw = [-1] * 50\nw[0] = 1\nw[1] = 1\nfor i in range(2, 45):\n\tw[i] = w[i - 2] + w[i - 1]\nprint(w[n])\n", "n = int(input())\ndp = [0] * (n + 1)\ndp[0] = 1\ndp[1] = 1\nfor i in range(1, n):\n\ti += 1\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[n])\n", "def fib2(n):\n\t(a1, a2) = (1, 0)\n\twhile n > 0:\n\t\t(a1, a2) = (a1 + a2, a1)\n\t\tn -= 1\n\treturn a1\nn = int(input())\nprint(fib2(n))\n", "n = int(input())\n\ndef fib(n):\n\t(a, b) = (1, 1)\n\tfor i in range(n):\n\t\t(a, b) = (b, a + b)\n\treturn a\nprint(fib(n))\n", "n = int(input())\nF = [None for i in range(n + 1)]\n(F[0], F[1]) = (1, 1)\nfor i in range(2, n + 1):\n\tF[i] = F[i - 2] + F[i - 1]\nanswer = F[n]\nprint(answer)\n", "a = int(input())\nn = [0] * (a + 1)\nn[0] = 1\nn[1] = 1\nfor i in range(2, a + 1):\n\tn[i] = n[i - 1] + n[i - 2]\nprint(n[a])\n", "n = int(input())\nif n < 2:\n\tprint(1)\nelse:\n\t(a, b) = (1, 1)\n\tfor k in range(n - 1):\n\t\t(a, b) = (b, a + b)\n\tprint(b)\n", "n = int(input())\ndp = [0] * (n + 1)\ndp[0] = 1\ndp[1] = 1\nfor i in range(2, n + 1):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[n])\n", "n = int(input())\nfib = []\nfib.append(1)\nfib.append(1)\nfor i in range(2, n + 1):\n\tfib.append(fib[i - 1] + fib[i - 2])\nprint(fib[n])\n", "n = int(input())\nif n == 0 or n == 1:\n\tprint(1)\nelse:\n\ta = [1, 1]\n\tfor i in range(n - 1):\n\t\ta.append(a[i] + a[i + 1])\n\tprint(a[-1])\n", "n = int(input()) + 1\na = [0] * n\na[0] = 1\na[1] = 1\nfor i in range(2, n):\n\ta[i] = a[i - 1] + a[i - 2]\nprint(a[-1])\n", "a = b = 1\nfor i in range(int(input())):\n\t(a, b) = (b, a + b)\nprint(a)\n", "n = int(input())\na = 1\nb = 1\nfor i in range(n - 1):\n\tc = b\n\tb += a\n\ta = c\nprint(b)\n", "n = int(input())\ndp = [0 for _ in range(n + 1)]\n(dp[0], dp[1]) = (1, 1)\nfor i in range(2, n + 1):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[n])\n", "F = [1, 1]\nfor i in range(2, int(input()) + 1):\n\tF.append(F[i - 2] + F[i - 1])\nprint(F[-1])\n", "n = int(input())\nF = [1, 1]\nfor i in range(2, n + 1):\n\tF.append(F[i - 1] + F[i - 2])\nprint(F[n])\n", "n = int(input())\ndp = [0 for _ in range(45)]\n(dp[0], dp[1]) = (1, 1)\nfor i in range(2, n + 1):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[n])\n", "n = int(input())\ndp = [1] * 50\nfor i in range(n):\n\tdp[i + 2] = dp[i + 1] + dp[i]\nprint(dp[n])\n", "fib = [0] * 45\nfib[0] = 1\nfib[1] = 1\nfor i in range(2, 45):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nn = int(input())\nprint(fib[n])\n", "n = int(input())\nF = [1] * 50\nfor i in range(2, n + 1):\n\tF[i] = F[i - 1] + F[i - 2]\nprint(F[n])\n", "N = int(input())\ndp = [1] * (N + 1)\nfor i in range(2, N + 1):\n\tdp[i] = dp[i - 2] + dp[i - 1]\nprint(dp[N])\n", "N = int(input())\nDP = [0 for _ in range(N + 1)]\nDP[0] = 1\nDP[1] = 1\nfor i in range(2, N + 1):\n\tDP[i] = DP[i - 1] + DP[i - 2]\nprint(DP[N])\n", "N = int(input())\nfib = [None] * (N + 1)\nfib[0] = 1\nfib[1] = 1\nfor i in range(2, N + 1):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nprint(fib[N])\n", "n = int(input())\ndp = [0] * (n + 1)\ndp[0] = 1\ndp[1] = 1\nfor i in range(2, n + 1, 1):\n\tdp[i] = dp[i - 2] + dp[i - 1]\nprint(dp[n])\n", "def fib(n):\n\t(f0, f1) = (1, 1)\n\tfor i in range(n - 1):\n\t\t(f1, f0) = (f1 + f0, f1)\n\treturn f1\nn = int(input())\nans = fib(n)\nprint(ans)\n", "n = int(input())\ndp = [-1] * (n + 1)\n(dp[0], dp[1]) = (1, 1)\nfor i in range(2, n + 1):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[-1])\n", "import sys\ninput = sys.stdin.readline\nn = int(input())\nfib = [1] * 50\nfor i in range(2, n + 1):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nprint(fib[n])\n", "def g(a, b, n):\n\tif n == 1:\n\t\treturn a\n\telse:\n\t\treturn g(a + b, a, n - 1)\nprint(g(1, 1, int(input())))\n", "n = int(input())\ndp = [0] * 50\ndp[0] = 1\ndp[1] = 1\nfor i in range(2, 46):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[n])\n", "a = [0] * 45\na[0] = 1\na[1] = 1\nfor i in range(2, 45):\n\ta[i] = a[i - 1] + a[i - 2]\nn = int(input())\nprint(a[n])\n", "n = int(input())\na = list(range(n + 1))\na[0] = 1\ni = 2\nwhile i <= n:\n\ta[i] = a[i - 1] + a[i - 2]\n\ti += 1\nprint(a[n])\n", "n = int(input())\na = b = 1\nwhile range(n):\n\t(a, b) = (b, a + b)\n\tn -= 1\nprint(a)\n", "n = int(input())\n\ndef func(n):\n\tfib = [1, 1]\n\tfor i in range(2, n):\n\t\tfib.append(fib[i - 2] + fib[i - 1])\n\treturn fib[n - 1]\nprint(func(n + 1))\n", "a = 1\nb = 1\nc = []\nn = int(input())\nc.append(a)\nc.append(b)\nfor i in range(n):\n\tc.append(a + b)\n\td = b\n\tb += a\n\ta = d\nprint(c[n])\n", "def g(a, b, n):\n\tif n == 1:\n\t\treturn a\n\telse:\n\t\treturn g(a + b, a, n - 1)\n(a, b) = (1, 1)\nprint(g(a, b, int(input())))\n", "n = int(input())\nx = 1\ny = 1\na = 0\nfor i in range(n):\n\ta = x\n\tx = y\n\ty = a + y\nprint(x)\n", "n = int(input())\na = 1\nb = 1\nwhile n:\n\t(a, b) = (b, a + b)\n\tn -= 1\nprint(a)\n", "n = int(input())\n(a, b) = (1, 1)\nif n <= 1:\n\tprint(1)\nelse:\n\tfor i in range(n - 1):\n\t\tc = a + b\n\t\ta = b\n\t\tb = c\n\tprint(c)\n", "n = int(input()) - 1\nfib = [1, 2]\nfor i in range(2, n + 1):\n\tfib.append(fib[i - 1] + fib[i - 2])\nprint(fib[n])\n", "c = int(input())\n(a, b) = (1, 1)\nwhile c - 1:\n\t(a, b) = (a + b, a)\n\tc -= 1\nprint(a)\n", "a = b = 1\nfor _ in range(int(input())):\n\t(a, b) = (b, a + b)\nprint(a)\n", "n = int(input())\nfib = [1, 1]\nfor i in range(2, n + 1):\n\tfib.append(fib[i - 1] + fib[i - 2])\nprint(fib[n])\n", "num = int(input())\narr = [0] * 45\narr[0] = 1\narr[1] = 1\nfor i in range(2, 45):\n\tarr[i] = arr[i - 1] + arr[i - 2]\nprint(arr[num])\n", "n = int(input())\ndp = [0] * (n + 1)\n(dp[0], dp[1]) = (1, 1)\nfor i in range(2, n + 1):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[n])\n", "n = int(input())\ndp = [1] * 45\nfor i in range(2, n + 1):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[n])\n", "N = int(input())\n(a, b) = (1, 1)\nn = 0\nwhile n <= N:\n\tif n > 1:\n\t\t(a, b) = (b, a + b)\n\tn += 1\nprint(b)\n", "n = int(input())\ntable = [0] * 45\ntable[0] = 1\ntable[1] = 1\nfor i in range(2, 45):\n\ttable[i] = table[i - 1] + table[i - 2]\nprint(table[n])\n", "def fib(a, b, c):\n\tif c == 0:\n\t\treturn a\n\treturn fib(a + b, a, c - 1)\nn = int(input().rstrip())\nprint(fib(0, 1, n + 1))\n", "n = int(input())\np = [1, 1]\nif n <= 1:\n\tprint(1)\n\texit(0)\nfor i in range(2, n + 1):\n\tp.append(p[i - 2] + p[i - 1])\nprint(p[n])\n", "N = int(input())\ndp = [0] * (N + 1)\n(dp[0], dp[1]) = (1, 1)\nfor i in range(2, N + 1):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[N])\n", "n = int(input())\na = [1, 1]\nfor i in range(2, n + 1):\n\ta.append(a[i - 1] + a[i - 2])\nprint(a[n])\n", "f = [1, 1]\nfor _ in [0] * 43:\n\tf += [f[-2] + f[-1]]\nprint(f[int(input())])\n", "n = int(input())\nfib = []\nfib.append(1)\nfib.append(1)\nfor _ in range(n - 1):\n\tfib.append(fib[-2] + fib[-1])\nprint(fib[-1])\n", "a = [1] * 45\nfor i in range(2, len(a)):\n\ta[i] = a[i - 1] + a[i - 2]\nprint(a[int(input())])\n", "f = [1, 1]\nfor _ in range(2, 45):\n\tf += [sum(f[-2:])]\nprint(f[int(input())])\n", "N = int(input())\ndp = [-1] * (N + 1)\ndp[0] = 1\ndp[1] = 1\nfor i in range(2, N + 1):\n\tdp[i] = dp[i - 2] + dp[i - 1]\nprint(dp[-1])\n", "n = int(input())\na = []\na.append(1)\na.append(1)\nmamo = [-1] * 100\nfor i in range(2, n + 1):\n\ta.append(a[i - 1] + a[i - 2])\n\ti += 1\nprint(a[n])\n", "n = int(input())\n\ndef fib(n):\n\ta = 1\n\tb = 1\n\tfor _ in range(n):\n\t\t(a, b) = (b, a + b)\n\treturn a\nprint(fib(n))\n", "n = int(input())\nfib = [0] * (n + 1)\nfib[0] = 1\nfib[1] = 1\nfor ni in range(2, n + 1):\n\tfib[ni] = fib[ni - 1] + fib[ni - 2]\nprint(fib[n])\n", "n = int(input())\ndp = [0] * 45\n(dp[0], dp[1]) = (1, 1)\nfor i in range(2, 44 + 1):\n\tdp[i] += dp[i - 1] + dp[i - 2]\nprint(dp[n])\n", "fib = [0] * 45\nfib[0] = 1\nfib[1] = 1\nfor i in range(2, 45):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nx = int(input())\nprint(fib[x])\n", "n = int(input())\nfib = [1, 1]\nif n >= 2:\n\tfor i in range(2, n + 1):\n\t\tfib.append(fib[i - 1] + fib[i - 2])\nprint(fib[n])\n", "import sys\ninput = sys.stdin.readline\nn = int(input())\nF = [1] * 50\nfor i in range(2, n + 1):\n\tF[i] = F[i - 1] + F[i - 2]\nprint(F[n])\n", "n = int(input())\na = 1\nb = 1\nmlist = []\nmlist.append(a)\nmlist.append(b)\nfor i in range(n):\n\tmlist.append(a + b)\n\td = b\n\tb += a\n\ta = d\nprint(mlist[n])\n", "import sys\nn = int(input())\nif n < 2:\n\tprint(1)\n\tsys.exit()\n(a, b) = (1, 1)\nfor i in range(n - 1):\n\tc = a + b\n\ta = b\n\tb = c\nprint(c)\n", "n = int(input())\nf = [1, 1]\nfor i in range(n - 1):\n\tfib = f[-1] + f[-2]\n\tf.append(fib)\nprint(f[n])\n", "n = int(input())\nA = [0 for _ in range(n + 1)]\nA[0] = A[1] = 1\nfor i in range(2, n + 1):\n\tA[i] = A[i - 1] + A[i - 2]\nprint(A[n])\n", "n = int(input())\nfib = [1] * 2 + [0] * (n - 1)\nfor i in range(2, n + 1):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nprint(fib[n])\n", "inf = 10 ** 9 + 7\nmod = 10 ** 9 + 7\nn = int(input())\ndp = [1 for i in range(n + 1)]\nfor i in range(2, n + 1):\n\tdp[i] = dp[i - 2] + dp[i - 1]\nprint(dp[n])\n", "n = int(input())\nfor i in range(n + 1):\n\tif i == 0 or i == 1:\n\t\ta = b = c = 1\n\telse:\n\t\tc = a + b\n\t\t(a, b) = (b, c)\nprint(c)\n", "def f(n):\n\tf = [1, 1]\n\tfor i in range(2, n):\n\t\tf.append(f[i - 2] + f[i - 1])\n\treturn f[n - 1]\nn = int(input())\nprint(f(n + 1))\n", "n = int(input())\nfib = [1] * (n + 1)\nfor i in range(2, n + 1):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nprint(fib[n])\n", "n = int(input())\nfib = [0] * 45\nfib[0] = 1\nfib[1] = 1\nfor i in range(2, 45):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nprint(fib[n])\n", "n = int(input())\n(x, y) = (1, 1)\nfor i in range(n):\n\t(x, y) = (y, x + y)\nprint(x)\n", "def g(a, b, n):\n\tif n == 0 or n == 1:\n\t\treturn a\n\telse:\n\t\treturn g(a + b, a, n - 1)\n(a, b) = (1, 1)\nprint(g(a, b, int(input())))\n", "N = int(input())\ndp = [0] * (N + 1)\ndp[0] = 1\ndp[1] = 1\nfor i in range(2, N + 1):\n\tdp[i] = dp[i - 1] + dp[i - 2]\nprint(dp[-1])\n", "n = int(input())\nif n == 0 or n == 1:\n\tprint(1)\nelse:\n\tN = 1\n\tm = 1\n\tfor i in range(n - 1):\n\t\tM = N + m\n\t\tm = N\n\t\tN = M\n\tprint(M)\n", "res = [1, 1]\nn = int(input())\nfor i in range(n - 1):\n\tres.append(res[i] + res[i + 1])\nprint(res[n])\n", "fib = [1] * 100\nfor i in range(2, 100):\n\tfib[i] = fib[i - 1] + fib[i - 2]\nprint(fib[int(input())])\n", "n = int(input())\nDP = [0 for _ in range(n + 1)]\nDP[0] = 1\nDP[1] = 1\nfor i in range(2, n + 1):\n\tDP[i] = DP[i - 1] + DP[i - 2]\nprint(DP[n])\n", "a = 1\nb = 1\nlist = []\nn = int(input())\nlist.append(a)\nlist.append(b)\nfor i in range(n):\n\tlist.append(a + b)\n\tc = b\n\tb += a\n\ta = c\nprint(list[n])\n", "def fib(n):\n\tglobal F\n\tif n not in F:\n\t\tF[n] = fib(n - 1) + fib(n - 2)\n\treturn F[n]\nF = {0: 1, 1: 1}\nprint(fib(int(input())))\n", "n = int(input())\nF = [1, 1]\nfor i in range(n - 1):\n\tF.append(F[-2] + F[-1])\nprint(F[n])\n", "n = int(input())\nfib_ls = [0] * (n + 1)\n(fib_ls[0], fib_ls[1]) = (1, 1)\nfor i in range(2, n + 1):\n\tfib_ls[i] = fib_ls[i - 1] + fib_ls[i - 2]\nprint(fib_ls[n])\n", "n = int(input())\nf = [0] * (n + 1)\nf[0] = 1\nf[1] = 1\nfor i in range(2, n + 1):\n\tf[i] = f[i - 1] + f[i - 2]\nprint(f[n])\n", "n = int(input())\n(x, y) = (0, 1)\nfor i in range(1, n + 1):\n\t(x, y) = (y, x + y)\nprint(y)\n"], "verify": 1, "wrong_type": null} +{"prompt": "<\uff5cUser\uff5c>There are N integers written on a blackboard. The i-th integer is A_i, and the greatest common divisor of these integers is 1.\n\nTakahashi and Aoki will play a game using these integers. In this game, starting from Takahashi the two player alternately perform the following operation:\n\n* Select one integer on the blackboard that is not less than 2, and subtract 1 from the integer.\n* Then, divide all the integers on the black board by g, where g is the greatest common divisor of the integers written on the blackboard.\n\n\n\nThe player who is left with only 1s on the blackboard and thus cannot perform the operation, loses the game. Assuming that both players play optimally, determine the winner of the game.\n\nConstraints\n\n* 1 \u2266 N \u2266 10^5\n* 1 \u2266 A_i \u2266 10^9\n* The greatest common divisor of the integers from A_1 through A_N is 1.\n\nInput\n\nThe input is given from Standard Input in the following format:\n\n\nN\nA_1 A_2 \u2026 A_N\n\n\nOutput\n\nIf Takahashi will win, print `First`. If Aoki will win, print `Second`.\n\nExamples\n\nInput\n\n3\n3 6 7\n\n\nOutput\n\nFirst\n\n\nInput\n\n4\n1 2 4 8\n\n\nOutput\n\nFirst\n\n\nInput\n\n5\n7 8 8 8 8\n\n\nOutput\n\nSecond\n<\uff5cAssistant\uff5c>\n", "question": "There are N integers written on a blackboard. The i-th integer is A_i, and the greatest common divisor of these integers is 1.\n\nTakahashi and Aoki will play a game using these integers. In this game, starting from Takahashi the two player alternately perform the following operation:\n\n* Select one integer on the blackboard that is not less than 2, and subtract 1 from the integer.\n* Then, divide all the integers on the black board by g, where g is the greatest common divisor of the integers written on the blackboard.\n\n\n\nThe player who is left with only 1s on the blackboard and thus cannot perform the operation, loses the game. Assuming that both players play optimally, determine the winner of the game.\n\nConstraints\n\n* 1 \u2266 N \u2266 10^5\n* 1 \u2266 A_i \u2266 10^9\n* The greatest common divisor of the integers from A_1 through A_N is 1.\n\nInput\n\nThe input is given from Standard Input in the following format:\n\n\nN\nA_1 A_2 \u2026 A_N\n\n\nOutput\n\nIf Takahashi will win, print `First`. If Aoki will win, print `Second`.\n\nExamples\n\nInput\n\n3\n3 6 7\n\n\nOutput\n\nFirst\n\n\nInput\n\n4\n1 2 4 8\n\n\nOutput\n\nFirst\n\n\nInput\n\n5\n7 8 8 8 8\n\n\nOutput\n\nSecond", "query_id": "deepcoder_taco_b5ba27229515b9362e4ef106dccc7486", "id": "deepcoder_taco_b5ba27229515b9362e4ef106dccc7486", "starter_code": null, "input_output": "{\"inputs\": [\"3\\n3 9 7\", \"4\\n1 2 1 9\", \"4\\n1 2 1 8\", \"5\\n7 2 8 8 8\", \"3\\n3 9 9\", \"3\\n3 1 9\", \"4\\n1 2 1 13\", \"3\\n1 1 9\", \"4\\n1 2 1 22\", \"3\\n0 1 9\", \"4\\n1 1 1 22\", \"3\\n0 1 5\", \"4\\n0 1 1 22\", \"4\\n0 0 1 22\", \"4\\n0 -1 1 22\", \"4\\n0 0 1 1\", \"4\\n0 0 1 2\", \"4\\n1 0 1 2\", \"4\\n1 1 1 2\", \"4\\n1 1 1 4\", \"4\\n2 1 1 4\", \"4\\n2 1 1 1\", \"4\\n2 0 1 1\", \"4\\n2 0 0 1\", \"4\\n2 1 0 1\", \"4\\n2 1 0 0\", \"4\\n2 1 0 -1\", \"4\\n2 1 -1 0\", \"3\\n3 6 4\", \"5\\n7 8 4 8 8\", \"3\\n3 4 7\", \"4\\n1 2 1 5\", \"5\\n7 2 8 8 2\", \"3\\n3 9 2\", \"4\\n1 2 1 11\", \"3\\n3 1 3\", \"4\\n1 1 1 13\", \"3\\n1 2 9\", \"4\\n2 2 1 22\", \"3\\n0 1 15\", \"4\\n1 0 1 22\", \"4\\n-1 0 1 22\", \"4\\n0 -1 2 22\", \"4\\n0 1 1 1\", \"4\\n0 -1 1 2\", \"4\\n1 0 0 2\", \"4\\n1 1 2 2\", \"4\\n1 0 1 4\", \"4\\n3 1 1 4\", \"4\\n-1 1 1 1\", \"4\\n1 1 1 1\", \"4\\n0 0 0 1\", \"4\\n2 1 0 2\", \"4\\n4 1 0 0\", \"4\\n2 1 1 -1\", \"4\\n2 2 -1 0\", \"3\\n2 6 4\", \"5\\n7 8 4 8 11\", \"4\\n1 2 0 5\", \"5\\n7 2 8 12 2\", \"3\\n5 9 2\", \"4\\n1 2 0 11\", \"3\\n5 1 3\", \"4\\n1 1 1 14\", \"3\\n1 4 9\", \"4\\n1 2 1 19\", \"4\\n-1 0 1 33\", \"4\\n-1 0 1 23\", \"4\\n1 -1 2 22\", \"4\\n-1 1 1 2\", \"4\\n-1 -1 1 2\", \"4\\n1 -1 0 2\", \"4\\n1 1 2 4\", \"4\\n1 0 0 4\", \"4\\n2 1 1 8\", \"4\\n-1 1 0 1\", \"4\\n0 2 1 1\", \"4\\n0 -1 0 1\", \"4\\n0 1 0 2\", \"4\\n3 1 0 0\", \"4\\n2 0 -1 0\", \"3\\n4 6 4\", \"5\\n7 8 4 5 11\", \"4\\n1 2 -1 5\", \"5\\n7 2 10 12 2\", \"3\\n5 10 2\", \"3\\n5 2 3\", \"4\\n1 1 1 3\", \"3\\n1 0 9\", \"4\\n1 0 1 19\", \"4\\n-1 1 1 33\", \"4\\n-1 0 1 28\", \"4\\n0 -1 2 3\", \"4\\n-1 1 2 2\", \"4\\n-1 -1 1 0\", \"4\\n1 -1 1 2\", \"4\\n2 1 2 4\", \"4\\n1 -1 0 4\", \"4\\n2 1 1 15\", \"4\\n-1 2 0 1\", \"3\\n3 6 7\", \"4\\n1 2 4 8\", \"5\\n7 8 8 8 8\"], \"outputs\": [\"Second\\n\", \"First\\n\", \"Second\\n\", \"Second\\n\", \"Second\\n\", \"Second\\n\", \"First\\n\", \"Second\\n\", \"Second\\n\", \"First\\n\", \"First\\n\", \"First\\n\", \"Second\\n\", \"First\\n\", \"Second\\n\", \"Second\\n\", \"First\\n\", \"Second\\n\", \"First\\n\", \"First\\n\", \"Second\\n\", \"First\\n\", \"Second\\n\", \"First\\n\", \"Second\\n\", \"First\\n\", \"Second\\n\", \"Second\\n\", \"Second\\n\", \"First\\n\", \"First\\n\", \"First\\n\", \"First\\n\", \"First\\n\", \"First\\n\", \"Second\\n\", \"Second\\n\", \"First\\n\", \"First\\n\", \"First\\n\", \"Second\\n\", \"Second\\n\", \"First\\n\", \"First\\n\", \"Second\\n\", \"First\\n\", \"Second\\n\", \"Second\\n\", \"First\\n\", \"Second\\n\", \"Second\\n\", \"First\\n\", \"First\\n\", \"First\\n\", \"First\\n\", \"First\\n\", \"First\\n\", \"First\\n\", \"Second\\n\", \"First\\n\", \"First\\n\", \"Second\\n\", \"Second\\n\", \"First\\n\", \"First\\n\", \"First\\n\", \"First\\n\", \"First\\n\", \"Second\\n\", \"First\\n\", \"First\\n\", \"Second\\n\", \"Second\\n\", \"First\\n\", \"Second\\n\", \"First\\n\", \"Second\\n\", \"Second\\n\", \"First\\n\", \"Second\\n\", \"First\\n\", \"First\\n\", \"Second\\n\", \"First\\n\", \"Second\\n\", \"Second\\n\", \"First\\n\", \"Second\\n\", \"First\\n\", \"First\\n\", \"Second\\n\", \"Second\\n\", \"Second\\n\", \"Second\\n\", \"First\\n\", \"First\\n\", \"First\\n\", \"Second\\n\", \"First\\n\", \"Second\\n\", \"First\", \"First\", \"Second\"], \"remote\": false}", "task": "code", "language": "PYTHON", "solutions": ["import sys\n\ndef I():\n\treturn int(sys.stdin.readline().rstrip())\n\ndef LI():\n\treturn list(map(int, sys.stdin.readline().rstrip().split()))\nN = I()\nA = LI()\nif N == 1:\n\tprint('Second')\n\texit()\nif N == 2:\n\tprint('First')\n\texit()\nfrom math import gcd\nr = 0\nwhile True:\n\tif (sum(A) - N) % 2 == 1:\n\t\tif r == 0:\n\t\t\tprint('First')\n\t\telse:\n\t\t\tprint('Second')\n\t\tbreak\n\telse:\n\t\ta = 0\n\t\tb = 0\n\t\tfor i in range(N):\n\t\t\tif A[i] % 2 == 1:\n\t\t\t\ta += 1\n\t\t\tif A[i] == 1:\n\t\t\t\tb += 1\n\t\tif a != 1 or b > 0:\n\t\t\tif r == 0:\n\t\t\t\tprint('Second')\n\t\t\telse:\n\t\t\t\tprint('First')\n\t\t\tbreak\n\t\telse:\n\t\t\tg = 0\n\t\t\tfor i in range(N):\n\t\t\t\tif A[i] % 2 == 1:\n\t\t\t\t\tg = gcd(g, A[i] - 1)\n\t\t\t\telse:\n\t\t\t\t\tg = gcd(g, A[i])\n\t\t\tfor i in range(N):\n\t\t\t\tA[i] = A[i] // g\n\t\t\tr = 1 - r\n", "import random\n\ndef tester(N=0):\n\tmaxno1 = 100000.0\n\tmaxno2 = 1000000000.0\n\ts = input()\n\tif s != '':\n\t\treturn s\n\tif N == 0:\n\t\treturn random.randint(2, maxno1)\n\telse:\n\t\tprint('Testing...')\n\t\tprint('N=', N)\n\t\tA = []\n\t\tfor i in range(N):\n\t\t\tA.extend([random.randint(1, maxno2)])\n\t\treturn ' '.join(list(map(str, A)))\nimport copy\n\ndef gcd(a, b):\n\twhile b:\n\t\t(a, b) = (b, a % b)\n\treturn a\n\ndef gcdm(x):\n\tg = x[0]\n\tfor i in range(1, len(x)):\n\t\tif g == 1:\n\t\t\treturn g\n\t\tg = gcd(g, x[i])\n\treturn g\n\ndef playmove(A, i):\n\tA[i] -= 1\n\tg = gcdm(A)\n\treturn [x // g for x in A]\n\ndef noofevens(A):\n\tr = 0\n\tfor i in A:\n\t\tif i % 2 == 0:\n\t\t\tr += 1\n\treturn r\nN = int(tester())\nA = [int(x) for x in tester(N).split()]\nisFirstmove = True\nwhile True:\n\te = noofevens(A)\n\tif e % 2 == 1:\n\t\tif isFirstmove:\n\t\t\tprint('First')\n\t\telse:\n\t\t\tprint('Second')\n\t\tbreak\n\telif N - e > 1:\n\t\tif isFirstmove:\n\t\t\tprint('Second')\n\t\telse:\n\t\t\tprint('First')\n\t\tbreak\n\telse:\n\t\tfor i in range(N):\n\t\t\tif A[i] % 2 == 1:\n\t\t\t\tbreak\n\t\tif A[i] == 1:\n\t\t\tif isFirstmove:\n\t\t\t\tprint('Second')\n\t\t\telse:\n\t\t\t\tprint('First')\n\t\t\tbreak\n\t\telse:\n\t\t\tA = playmove(A, i)\n\t\t\tisFirstmove = isFirstmove != True\n", "import sys\nreadline = sys.stdin.readline\nfrom functools import reduce\n\ndef gcd(a, b):\n\twhile b:\n\t\t(a, b) = (b, a % b)\n\treturn a\n\ndef calc(A):\n\tN = len(A)\n\tif N == 1:\n\t\treturn A[0] % 2 == 0\n\tK = sum((1 for a in A if a % 2 == 0))\n\tif K & 1:\n\t\treturn True\n\tif N - K != 1:\n\t\treturn False\n\tif min(A) == 1:\n\t\treturn False\n\tA = [a - a % 2 for a in A]\n\tg = reduce(gcd, A)\n\tA = [a // g for a in A]\n\treturn not calc(A)\nN = int(readline())\nA = list(map(int, readline().split()))\nprint('First' if calc(A) else 'Second')\n"], "verify": 1, "wrong_type": null} +{"prompt": "<\uff5cUser\uff5c>A game is played on a strip consisting of N cells consecutively numbered from 1 to N.\n\nAlice has her token on cell A. Borys has his token on a different cell B.\n\nPlayers take turns, Alice moves first. The moving player must shift his or her token from its current cell X to the neighboring cell on the left, cell X-1, or on the right, cell X+1. Note that it's disallowed to move the token outside the strip or to the cell with the other player's token. In one turn, the token of the moving player must be shifted exactly once.\n\nThe player who can't make a move loses, and the other player wins.\n\nBoth players want to win. Who wins if they play optimally?\n\nConstraints\n\n* 2 \\leq N \\leq 100\n* 1 \\leq A < B \\leq N\n* All input values are integers.\n\nInput\n\nInput is given from Standard Input in the following format:\n\n\nN A B\n\n\nOutput\n\nPrint `Alice` if Alice wins, `Borys` if Borys wins, and `Draw` if nobody wins.\n\nExamples\n\nInput\n\n5 2 4\n\n\nOutput\n\nAlice\n\n\nInput\n\n2 1 2\n\n\nOutput\n\nBorys\n\n\nInput\n\n58 23 42\n\n\nOutput\n\nBorys\n<\uff5cAssistant\uff5c>\n", "question": "A game is played on a strip consisting of N cells consecutively numbered from 1 to N.\n\nAlice has her token on cell A. Borys has his token on a different cell B.\n\nPlayers take turns, Alice moves first. The moving player must shift his or her token from its current cell X to the neighboring cell on the left, cell X-1, or on the right, cell X+1. Note that it's disallowed to move the token outside the strip or to the cell with the other player's token. In one turn, the token of the moving player must be shifted exactly once.\n\nThe player who can't make a move loses, and the other player wins.\n\nBoth players want to win. Who wins if they play optimally?\n\nConstraints\n\n* 2 \\leq N \\leq 100\n* 1 \\leq A < B \\leq N\n* All input values are integers.\n\nInput\n\nInput is given from Standard Input in the following format:\n\n\nN A B\n\n\nOutput\n\nPrint `Alice` if Alice wins, `Borys` if Borys wins, and `Draw` if nobody wins.\n\nExamples\n\nInput\n\n5 2 4\n\n\nOutput\n\nAlice\n\n\nInput\n\n2 1 2\n\n\nOutput\n\nBorys\n\n\nInput\n\n58 23 42\n\n\nOutput\n\nBorys", "query_id": "deepcoder_taco_b47de59fdd041aad05039e69af13d5e6", "id": "deepcoder_taco_b47de59fdd041aad05039e69af13d5e6", "starter_code": null, "input_output": "{\"inputs\": [\"9 23 42\", \"8 2 4\", \"3 1 2\", \"9 23 18\", \"3 2 2\", \"8 2 3\", \"9 23 23\", \"0 2 2\", \"7 2 3\", \"9 44 23\", \"0 2 1\", \"7 2 5\", \"9 44 0\", \"0 0 1\", \"7 1 5\", \"9 38 0\", \"0 1 1\", \"7 1 4\", \"1 38 0\", \"0 1 0\", \"8 1 4\", \"1 38 1\", \"0 2 0\", \"8 1 7\", \"1 31 0\", \"0 0 0\", \"8 0 7\", \"2 31 0\", \"1 0 0\", \"8 0 3\", \"2 31 -1\", \"2 0 0\", \"8 1 3\", \"2 53 -1\", \"0 0 -1\", \"3 1 3\", \"2 64 -1\", \"0 -1 -1\", \"0 1 3\", \"0 64 -1\", \"-1 -1 -1\", \"-1 1 3\", \"0 64 0\", \"0 -2 -1\", \"-1 1 4\", \"1 64 0\", \"0 1 -1\", \"-1 1 7\", \"1 103 0\", \"1 1 -1\", \"-1 1 13\", \"1 103 1\", \"2 1 -1\", \"-1 1 2\", \"2 103 1\", \"4 1 -1\", \"0 1 2\", \"2 79 1\", \"4 1 0\", \"0 1 4\", \"2 79 2\", \"7 1 0\", \"0 1 6\", \"2 56 2\", \"7 0 0\", \"0 0 6\", \"2 56 4\", \"11 0 0\", \"-1 0 6\", \"2 12 4\", \"11 -1 0\", \"-1 0 8\", \"2 21 4\", \"21 -1 0\", \"-1 0 10\", \"2 21 0\", \"21 -2 0\", \"0 0 10\", \"2 35 0\", \"21 -2 1\", \"1 0 -1\", \"2 35 1\", \"16 -2 1\", \"1 -1 -1\", \"2 30 1\", \"16 -3 1\", \"2 -1 -1\", \"2 11 1\", \"16 -3 2\", \"2 -1 -2\", \"0 11 1\", \"5 -3 2\", \"0 -1 -2\", \"-1 11 1\", \"7 -3 2\", \"0 -1 -3\", \"-1 22 1\", \"13 -3 2\", \"0 0 -3\", \"-1 22 2\", \"58 23 42\", \"2 1 2\", \"5 2 4\"], \"outputs\": [\"Borys\\n\", \"Alice\\n\", \"Borys\\n\", \"Borys\\n\", \"Alice\\n\", \"Borys\\n\", \"Alice\\n\", \"Alice\\n\", \"Borys\\n\", \"Borys\\n\", \"Borys\\n\", \"Borys\\n\", \"Alice\\n\", \"Borys\\n\", \"Alice\\n\", \"Alice\\n\", \"Alice\\n\", \"Borys\\n\", \"Alice\\n\", \"Borys\\n\", \"Borys\\n\", \"Borys\\n\", \"Alice\\n\", \"Alice\\n\", \"Borys\\n\", \"Alice\\n\", \"Borys\\n\", \"Borys\\n\", \"Alice\\n\", \"Borys\\n\", \"Alice\\n\", \"Alice\\n\", \"Alice\\n\", \"Alice\\n\", \"Borys\\n\", \"Alice\\n\", \"Borys\\n\", \"Alice\\n\", \"Alice\\n\", \"Borys\\n\", \"Alice\\n\", \"Alice\\n\", \"Alice\\n\", \"Borys\\n\", \"Borys\\n\", \"Alice\\n\", \"Alice\\n\", \"Alice\\n\", \"Borys\\n\", \"Alice\\n\", \"Alice\\n\", \"Alice\\n\", \"Alice\\n\", \"Borys\\n\", \"Alice\\n\", \"Alice\\n\", \"Borys\\n\", \"Alice\\n\", \"Borys\\n\", \"Borys\\n\", \"Borys\\n\", \"Borys\\n\", \"Borys\\n\", \"Alice\\n\", \"Alice\\n\", \"Alice\\n\", \"Alice\\n\", \"Alice\\n\", \"Alice\\n\", \"Alice\\n\", \"Borys\\n\", \"Alice\\n\", \"Borys\\n\", \"Borys\\n\", \"Alice\\n\", \"Borys\\n\", \"Alice\\n\", \"Alice\\n\", \"Borys\\n\", \"Borys\\n\", \"Borys\\n\", \"Alice\\n\", \"Borys\\n\", \"Alice\\n\", \"Borys\\n\", \"Alice\\n\", \"Alice\\n\", \"Alice\\n\", \"Borys\\n\", \"Borys\\n\", \"Alice\\n\", \"Borys\\n\", \"Borys\\n\", \"Alice\\n\", \"Borys\\n\", \"Alice\\n\", \"Borys\\n\", \"Borys\\n\", \"Borys\\n\", \"Alice\\n\", \"Borys\", \"Borys\", \"Alice\"], \"remote\": false}", "task": "code", "language": "PYTHON", "solutions": ["(n, a, b) = map(int, input().split())\nprint('Alice' if (a - b - 1) % 2 == 1 else 'Borys')\n", "(n, a, b) = map(int, input().split())\nprint('Alice' if abs(a - b) % 2 == 0 else 'Borys')\n", "(N, A, B) = map(int, input().split())\nprint('Borys' if (B - A) % 2 else 'Alice')\n", "print('ABloircyes'[eval(input()[2:].replace(' ', '-')) % 2::2])\n", "(n, a, b) = map(int, input().split())\nans = ['Alice', 'Borys']\nprint(ans[(a - b) % 2])\n", "(N, A, B) = [int(x) for x in input().split()]\nprint('Alice' if (B - A - 1) % 2 == 1 else 'Borys')\n", "(n, a, b) = map(int, input().split())\nprint('ABloircyes'[(b - a) % 2::2])\n", "(_, A, B) = list(map(int, input().split()))\nx = A + B\nprint('Alice' if x % 2 == 0 else 'Borys')\n", "(N, A, B) = map(int, input().split(' '))\nprint(['Alice', 'Borys'][(B - A) % 2])\n", "(a, b, c) = map(int, input().split(' '))\nif (c - b) % 2 == 0:\n\tprint('Alice')\nelse:\n\tprint('Borys')\n", "(N, A, B) = map(int, input().split())\nans = 'Alice'\nif (B - A) % 2 == 1:\n\tans = 'Borys'\nprint(ans)\n", "(N, A, B) = map(int, input().split())\nC = B - A - 1\nif C % 2 != 0:\n\tprint('Alice')\nelse:\n\tprint('Borys')\n", "(n, x, y) = map(int, input().split())\nif (x - y) % 2 == 0:\n\tprint('Alice')\nelse:\n\tprint('Borys')\n", "(a, b, c) = map(int, input().split())\nprint('ABloircyes'[(c - b) % 2::2])\n", "(a, b, c) = [int(i) for i in input().split()]\nif abs(b - c) % 2 == 0:\n\tprint('Alice')\nelse:\n\tprint('Borys')\n", "a = list(map(int, input().split()))\nprint('ABloircyes'[(a[2] - a[1]) % 2::2])\n", "(n, a, b) = (int(x) for x in input().split())\nprint('Alice' if (b - a) % 2 == 0 else 'Borys')\n", "(n, a, b) = list(map(int, input().split()))\nans = 'Borys' if (b - a) % 2 else 'Alice'\nprint(ans)\n", "(_, a, b) = map(int, input().split())\nprint(['Alice', 'Borys'][abs(b - a) % 2 == 1])\n", "(a, b, c) = map(int, input().split())\nprint('Alice' if abs(b - c) % 2 != 1 else 'Borys')\n", "(n, a, b) = [int(i) for i in input().split()]\nprint(['Alice', 'Borys'][(a - b) % 2])\n", "(_, a, b) = map(int, input().split())\nprint(['Alice', 'Borys'][(b - a) % 2])\n", "(N, A, B) = map(int, input().split())\nprint('ABloircyes'[(B - A) % 2::2])\n", "(_, A, B) = map(int, input().split())\nprint('ABloircyes'[(B - A) % 2::2])\n", "(N, A, B) = [int(i) for i in input().split()]\nprint('Alice') if (A - B) % 2 == 0 else print('Borys')\n", "(N, a, b) = map(int, input().split())\nif (b - a - 1) % 2 == 0:\n\tprint('Borys')\nelse:\n\tprint('Alice')\n", "(n, a, b) = list(map(int, input().split()))\nprint(['Alice', 'Borys'][(b - a) % 2])\n", "(_, x, y) = map(int, input().split())\nprint(['Alice', 'Borys'][(x - y) % 2])\n", "(_, A, B) = map(int, input().split())\nprint(['Alice', 'Borys'][(B - A) % 2])\n", "(_, a, b) = map(int, open(0).read().split())\nprint('Borys' if (a - b) % 2 else 'Alice')\n", "(x, y, z) = list(map(int, input().split()))\nk = y - z\nif k % 2 == 1:\n\tprint('Borys')\nelse:\n\tprint('Alice')\n", "print('ABloircyes'[sum(map(int, input().split()[1:])) % 2::2])\n", "(N, A, B) = map(int, input().split())\nprint(('Borys', 'Alice')[(abs(A - B) - 1) % 2 == 1])\n"], "verify": 1, "wrong_type": null} +{"prompt": "<\uff5cUser\uff5c>problem\n\nThere are $ N $ propositions, named $ 1, 2, \\ cdots, N $, respectively. Also, $ M $ information about the propositions is given. The $ i $ th information is \"$ a_i $$\". Given in the form \"b_i $\", which means that $ a_i $ is $ b_i $. (\"If\" is a logical conditional and the transition law holds.) $ For each proposition $ i $ Output all propositions that have the same value as i $ in ascending order. However, proposition $ i $ and proposition $ i $ are always the same value. Proposition $ X $ and proposition $ Y $ have the same value as \"$ if $ X $\". It means \"Y $\" and \"$ X $ if $ Y $\".\n\n\n\noutput\n\nOn the $ i $ line, output all propositions that have the same value as the proposition $ i $, separated by blanks in ascending order. Also, output a line break at the end of each line.\n\nExample\n\nInput\n\n5 2\n1 2\n2 1\n\n\nOutput\n\n1 2\n1 2\n3\n4\n5\n<\uff5cAssistant\uff5c>\n", "question": "problem\n\nThere are $ N $ propositions, named $ 1, 2, \\ cdots, N $, respectively. Also, $ M $ information about the propositions is given. The $ i $ th information is \"$ a_i $$\". Given in the form \"b_i $\", which means that $ a_i $ is $ b_i $. (\"If\" is a logical conditional and the transition law holds.) $ For each proposition $ i $ Output all propositions that have the same value as i $ in ascending order. However, proposition $ i $ and proposition $ i $ are always the same value. Proposition $ X $ and proposition $ Y $ have the same value as \"$ if $ X $\". It means \"Y $\" and \"$ X $ if $ Y $\".\n\n\n\noutput\n\nOn the $ i $ line, output all propositions that have the same value as the proposition $ i $, separated by blanks in ascending order. Also, output a line break at the end of each line.\n\nExample\n\nInput\n\n5 2\n1 2\n2 1\n\n\nOutput\n\n1 2\n1 2\n3\n4\n5", "query_id": "deepcoder_taco_446be192f6149c3de610a3122f476230", "id": "deepcoder_taco_446be192f6149c3de610a3122f476230", "starter_code": null, "input_output": "{\"inputs\": [\"5 2\\n1 2\\n2 2\", \"3 2\\n1 2\\n2 1\", \"3 2\\n1 2\\n3 1\", \"6 2\\n1 2\\n2 1\", \"9 2\\n1 2\\n2 4\", \"6 2\\n1 1\\n2 1\", \"17 2\\n1 2\\n2 4\", \"4 2\\n2 2\\n3 1\", \"26 2\\n2 2\\n4 4\", \"12 2\\n2 2\\n4 4\", \"1 0\\n2 2\\n1 1\", \"17 2\\n1 2\\n2 1\", \"7 1\\n2 1\\n3 0\", \"28 2\\n1 2\\n4 6\", \"18 2\\n1 3\\n2 2\", \"10 2\\n1 2\\n4 6\", \"48 0\\n4 3\\n8 4\", \"94 0\\n4 3\\n0 3\", \"178 0\\n4 3\\n0 3\", \"8 1\\n4 1\\n1 -1\", \"211 0\\n4 2\\n0 3\", \"9 2\\n1 2\\n2 1\", \"42 2\\n2 2\\n4 4\", \"2 0\\n1 3\\n1 1\", \"11 1\\n1 2\\n2 1\", \"2 2\\n1 2\\n2 1\", \"5 2\\n2 3\\n3 2\", \"20 2\\n1 1\\n4 6\", \"14 2\\n1 2\\n2 1\", \"39 0\\n2 2\\n8 0\", \"22 2\\n1 1\\n4 6\", \"24 0\\n4 5\\n8 4\", \"14 0\\n7 1\\n-1 0\", \"15 1\\n1 1\\n0 -1\", \"21 1\\n1 1\\n2 1\", \"38 1\\n2 2\\n6 4\", \"77 0\\n2 2\\n8 0\", \"75 0\\n5 2\\n0 3\", \"13 0\\n2 0\\n6 4\", \"65 0\\n5 2\\n0 3\", \"25 1\\n2 1\\n1 4\", \"40 0\\n4 0\\n14 3\", \"5 2\\n1 2\\n2 4\", \"5 2\\n2 2\\n2 1\", \"3 2\\n1 2\\n2 2\", \"5 2\\n1 2\\n3 1\", \"5 2\\n2 2\\n3 1\", \"17 2\\n1 2\\n4 4\", \"4 1\\n2 2\\n3 1\", \"5 2\\n2 2\\n2 2\", \"5 2\\n2 2\\n1 1\", \"3 0\\n1 2\\n3 1\", \"6 2\\n1 2\\n2 2\", \"17 2\\n1 4\\n2 4\", \"17 2\\n2 2\\n4 4\", \"4 1\\n2 2\\n3 0\", \"5 2\\n2 2\\n3 2\", \"5 0\\n2 2\\n1 1\", \"3 0\\n1 1\\n3 1\", \"6 2\\n1 3\\n2 2\", \"4 1\\n2 1\\n3 0\", \"5 2\\n2 2\\n1 2\", \"6 0\\n2 2\\n1 1\", \"6 2\\n1 3\\n2 3\", \"6 0\\n0 2\\n1 1\", \"6 2\\n1 5\\n2 3\", \"12 2\\n2 2\\n3 4\", \"6 0\\n0 4\\n1 1\", \"6 2\\n2 5\\n2 3\", \"5 2\\n1 2\\n3 2\", \"6 1\\n1 2\\n2 1\", \"3 2\\n1 2\\n2 3\", \"5 1\\n1 2\\n3 1\", \"17 0\\n1 4\\n2 4\", \"17 2\\n1 2\\n4 6\", \"5 2\\n1 2\\n5 1\", \"4 1\\n2 2\\n6 1\", \"4 0\\n1 2\\n3 1\", \"4 2\\n1 2\\n2 2\", \"4 1\\n2 2\\n2 0\", \"5 2\\n2 4\\n3 2\", \"5 0\\n2 4\\n1 1\", \"3 0\\n1 1\\n3 0\", \"9 2\\n1 3\\n2 2\", \"26 2\\n2 2\\n8 4\", \"6 0\\n2 1\\n1 1\", \"6 0\\n0 2\\n1 2\", \"12 2\\n2 2\\n3 1\", \"6 0\\n1 4\\n1 1\", \"6 2\\n2 5\\n1 3\", \"5 0\\n1 2\\n3 2\", \"6 0\\n1 2\\n2 1\", \"3 2\\n1 3\\n2 2\", \"5 1\\n1 3\\n3 1\", \"17 0\\n0 4\\n2 4\", \"4 1\\n2 2\\n6 2\", \"1 0\\n2 2\\n1 0\", \"4 0\\n1 2\\n5 1\", \"7 2\\n1 2\\n2 2\", \"6 0\\n1 2\\n3 1\", \"5 2\\n1 2\\n2 1\"], \"outputs\": [\"1\\n2\\n3\\n4\\n5\\n\", \"1 2\\n1 2\\n3\\n\", \"1\\n2\\n3\\n\", \"1 2\\n1 2\\n3\\n4\\n5\\n6\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n\", \"1\\n2\\n3\\n4\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n23\\n24\\n25\\n26\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n\", \"1\\n\", \"1 2\\n1 2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n23\\n24\\n25\\n26\\n27\\n28\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n23\\n24\\n25\\n26\\n27\\n28\\n29\\n30\\n31\\n32\\n33\\n34\\n35\\n36\\n37\\n38\\n39\\n40\\n41\\n42\\n43\\n44\\n45\\n46\\n47\\n48\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n23\\n24\\n25\\n26\\n27\\n28\\n29\\n30\\n31\\n32\\n33\\n34\\n35\\n36\\n37\\n38\\n39\\n40\\n41\\n42\\n43\\n44\\n45\\n46\\n47\\n48\\n49\\n50\\n51\\n52\\n53\\n54\\n55\\n56\\n57\\n58\\n59\\n60\\n61\\n62\\n63\\n64\\n65\\n66\\n67\\n68\\n69\\n70\\n71\\n72\\n73\\n74\\n75\\n76\\n77\\n78\\n79\\n80\\n81\\n82\\n83\\n84\\n85\\n86\\n87\\n88\\n89\\n90\\n91\\n92\\n93\\n94\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n23\\n24\\n25\\n26\\n27\\n28\\n29\\n30\\n31\\n32\\n33\\n34\\n35\\n36\\n37\\n38\\n39\\n40\\n41\\n42\\n43\\n44\\n45\\n46\\n47\\n48\\n49\\n50\\n51\\n52\\n53\\n54\\n55\\n56\\n57\\n58\\n59\\n60\\n61\\n62\\n63\\n64\\n65\\n66\\n67\\n68\\n69\\n70\\n71\\n72\\n73\\n74\\n75\\n76\\n77\\n78\\n79\\n80\\n81\\n82\\n83\\n84\\n85\\n86\\n87\\n88\\n89\\n90\\n91\\n92\\n93\\n94\\n95\\n96\\n97\\n98\\n99\\n100\\n101\\n102\\n103\\n104\\n105\\n106\\n107\\n108\\n109\\n110\\n111\\n112\\n113\\n114\\n115\\n116\\n117\\n118\\n119\\n120\\n121\\n122\\n123\\n124\\n125\\n126\\n127\\n128\\n129\\n130\\n131\\n132\\n133\\n134\\n135\\n136\\n137\\n138\\n139\\n140\\n141\\n142\\n143\\n144\\n145\\n146\\n147\\n148\\n149\\n150\\n151\\n152\\n153\\n154\\n155\\n156\\n157\\n158\\n159\\n160\\n161\\n162\\n163\\n164\\n165\\n166\\n167\\n168\\n169\\n170\\n171\\n172\\n173\\n174\\n175\\n176\\n177\\n178\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n23\\n24\\n25\\n26\\n27\\n28\\n29\\n30\\n31\\n32\\n33\\n34\\n35\\n36\\n37\\n38\\n39\\n40\\n41\\n42\\n43\\n44\\n45\\n46\\n47\\n48\\n49\\n50\\n51\\n52\\n53\\n54\\n55\\n56\\n57\\n58\\n59\\n60\\n61\\n62\\n63\\n64\\n65\\n66\\n67\\n68\\n69\\n70\\n71\\n72\\n73\\n74\\n75\\n76\\n77\\n78\\n79\\n80\\n81\\n82\\n83\\n84\\n85\\n86\\n87\\n88\\n89\\n90\\n91\\n92\\n93\\n94\\n95\\n96\\n97\\n98\\n99\\n100\\n101\\n102\\n103\\n104\\n105\\n106\\n107\\n108\\n109\\n110\\n111\\n112\\n113\\n114\\n115\\n116\\n117\\n118\\n119\\n120\\n121\\n122\\n123\\n124\\n125\\n126\\n127\\n128\\n129\\n130\\n131\\n132\\n133\\n134\\n135\\n136\\n137\\n138\\n139\\n140\\n141\\n142\\n143\\n144\\n145\\n146\\n147\\n148\\n149\\n150\\n151\\n152\\n153\\n154\\n155\\n156\\n157\\n158\\n159\\n160\\n161\\n162\\n163\\n164\\n165\\n166\\n167\\n168\\n169\\n170\\n171\\n172\\n173\\n174\\n175\\n176\\n177\\n178\\n179\\n180\\n181\\n182\\n183\\n184\\n185\\n186\\n187\\n188\\n189\\n190\\n191\\n192\\n193\\n194\\n195\\n196\\n197\\n198\\n199\\n200\\n201\\n202\\n203\\n204\\n205\\n206\\n207\\n208\\n209\\n210\\n211\\n\", \"1 2\\n1 2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n23\\n24\\n25\\n26\\n27\\n28\\n29\\n30\\n31\\n32\\n33\\n34\\n35\\n36\\n37\\n38\\n39\\n40\\n41\\n42\\n\", \"1\\n2\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n\", \"1 2\\n1 2\\n\", \"1\\n2 3\\n2 3\\n4\\n5\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n\", \"1 2\\n1 2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n23\\n24\\n25\\n26\\n27\\n28\\n29\\n30\\n31\\n32\\n33\\n34\\n35\\n36\\n37\\n38\\n39\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n23\\n24\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n23\\n24\\n25\\n26\\n27\\n28\\n29\\n30\\n31\\n32\\n33\\n34\\n35\\n36\\n37\\n38\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n23\\n24\\n25\\n26\\n27\\n28\\n29\\n30\\n31\\n32\\n33\\n34\\n35\\n36\\n37\\n38\\n39\\n40\\n41\\n42\\n43\\n44\\n45\\n46\\n47\\n48\\n49\\n50\\n51\\n52\\n53\\n54\\n55\\n56\\n57\\n58\\n59\\n60\\n61\\n62\\n63\\n64\\n65\\n66\\n67\\n68\\n69\\n70\\n71\\n72\\n73\\n74\\n75\\n76\\n77\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n23\\n24\\n25\\n26\\n27\\n28\\n29\\n30\\n31\\n32\\n33\\n34\\n35\\n36\\n37\\n38\\n39\\n40\\n41\\n42\\n43\\n44\\n45\\n46\\n47\\n48\\n49\\n50\\n51\\n52\\n53\\n54\\n55\\n56\\n57\\n58\\n59\\n60\\n61\\n62\\n63\\n64\\n65\\n66\\n67\\n68\\n69\\n70\\n71\\n72\\n73\\n74\\n75\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n23\\n24\\n25\\n26\\n27\\n28\\n29\\n30\\n31\\n32\\n33\\n34\\n35\\n36\\n37\\n38\\n39\\n40\\n41\\n42\\n43\\n44\\n45\\n46\\n47\\n48\\n49\\n50\\n51\\n52\\n53\\n54\\n55\\n56\\n57\\n58\\n59\\n60\\n61\\n62\\n63\\n64\\n65\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n23\\n24\\n25\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n23\\n24\\n25\\n26\\n27\\n28\\n29\\n30\\n31\\n32\\n33\\n34\\n35\\n36\\n37\\n38\\n39\\n40\\n\", \"1\\n2\\n3\\n4\\n5\\n\", \"1\\n2\\n3\\n4\\n5\\n\", \"1\\n2\\n3\\n\", \"1\\n2\\n3\\n4\\n5\\n\", \"1\\n2\\n3\\n4\\n5\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n\", \"1\\n2\\n3\\n4\\n\", \"1\\n2\\n3\\n4\\n5\\n\", \"1\\n2\\n3\\n4\\n5\\n\", \"1\\n2\\n3\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n\", \"1\\n2\\n3\\n4\\n\", \"1\\n2\\n3\\n4\\n5\\n\", \"1\\n2\\n3\\n4\\n5\\n\", \"1\\n2\\n3\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n\", \"1\\n2\\n3\\n4\\n\", \"1\\n2\\n3\\n4\\n5\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n\", \"1\\n2\\n3\\n4\\n5\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n\", \"1\\n2\\n3\\n\", \"1\\n2\\n3\\n4\\n5\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n\", \"1\\n2\\n3\\n4\\n5\\n\", \"1\\n2\\n3\\n4\\n\", \"1\\n2\\n3\\n4\\n\", \"1\\n2\\n3\\n4\\n\", \"1\\n2\\n3\\n4\\n\", \"1\\n2\\n3\\n4\\n5\\n\", \"1\\n2\\n3\\n4\\n5\\n\", \"1\\n2\\n3\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n18\\n19\\n20\\n21\\n22\\n23\\n24\\n25\\n26\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n\", \"1\\n2\\n3\\n4\\n5\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n\", \"1\\n2\\n3\\n\", \"1\\n2\\n3\\n4\\n5\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n10\\n11\\n12\\n13\\n14\\n15\\n16\\n17\\n\", \"1\\n2\\n3\\n4\\n\", \"1\\n\", \"1\\n2\\n3\\n4\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n7\\n\", \"1\\n2\\n3\\n4\\n5\\n6\\n\", \"1 2\\n1 2\\n3\\n4\\n5\"], \"remote\": false}", "task": "code", "language": "PYTHON", "solutions": ["import sys\ninput = sys.stdin.readline\n\ndef inpl():\n\treturn list(map(int, input().split()))\n(N, M) = inpl()\nG = [[] for _ in range(N)]\nrG = [[] for _ in range(N)]\nfor i in range(M):\n\t(a, b) = inpl()\n\tG[a - 1].append(b - 1)\n\trG[b - 1].append(a - 1)\n\ndef SCC(G, rG):\n\tN = len(G)\n\n\tdef dfs(i):\n\t\tnonlocal t, rorder, searched\n\t\tsearched[i] = True\n\t\tfor j in G[i]:\n\t\t\tif not searched[j]:\n\t\t\t\tdfs(j)\n\t\trorder[t] = i\n\t\tt += 1\n\n\tdef rdfs(i):\n\t\tnonlocal t, group, g\n\t\tgroup[i] = g\n\t\tfor j in rG[i]:\n\t\t\tif group[j] == -1:\n\t\t\t\trdfs(j)\n\tt = 0\n\trorder = [-1] * N\n\tsearched = [0] * N\n\tgroup = [-1] * N\n\tfor i in range(N):\n\t\tif not searched[i]:\n\t\t\tdfs(i)\n\tg = 0\n\tfor i in range(N - 1, -1, -1):\n\t\tif group[rorder[i]] == -1:\n\t\t\trdfs(rorder[i])\n\t\t\tg += 1\n\treturn (group, g)\n(group, g) = SCC(G, rG)\nans = [[] for _ in range(g)]\nfor i in range(N):\n\tans[group[i]].append(i + 1)\nfor i in range(N):\n\tprint(*ans[group[i]])\n", "def scc(N, G, RG):\n\torder = []\n\tused = [0] * N\n\tgroup = [None] * N\n\n\tdef dfs(s):\n\t\tused[s] = 1\n\t\tfor t in G[s]:\n\t\t\tif not used[t]:\n\t\t\t\tdfs(t)\n\t\torder.append(s)\n\n\tdef rdfs(s, col):\n\t\tgroup[s] = col\n\t\tused[s] = 1\n\t\tfor t in RG[s]:\n\t\t\tif not used[t]:\n\t\t\t\trdfs(t, col)\n\tfor i in range(N):\n\t\tif not used[i]:\n\t\t\tdfs(i)\n\tused = [0] * N\n\tlabel = 0\n\tfor s in reversed(order):\n\t\tif not used[s]:\n\t\t\trdfs(s, label)\n\t\t\tlabel += 1\n\treturn (label, group)\n(N, M) = map(int, input().split())\nE1 = [[] for _ in range(N + 1)]\nE2 = [[] for _ in range(N + 1)]\nfor _ in range(M):\n\t(a, b) = map(int, input().split())\n\tE1[a].append(b)\n\tE2[b].append(a)\n(label, group) = scc(N + 1, E1, E2)\nfor g in group[1:]:\n\tans = []\n\tfor (n, gg) in enumerate(group[1:], 1):\n\t\tif g == gg:\n\t\t\tans.append(n)\n\tprint(' '.join(map(str, ans)))\n", "import sys\nimport math\nfrom bisect import bisect_right as br\nfrom bisect import bisect_left as bl\nsys.setrecursionlimit(1000000000)\nfrom heapq import heappush, heappop, heappushpop\nfrom collections import defaultdict\nfrom itertools import accumulate\nfrom collections import Counter\nfrom collections import deque\nfrom operator import itemgetter\nfrom itertools import permutations\nmod = 10 ** 9 + 7\ninf = float('inf')\n\ndef I():\n\treturn int(sys.stdin.readline())\n\ndef LI():\n\treturn list(map(int, sys.stdin.readline().split()))\n(n, m) = LI()\ngraph = [[] for _ in range(n)]\nfor _ in range(m):\n\t(a, b) = LI()\n\tgraph[a - 1].append(b - 1)\n\ndef dfs(s, lst, graph, check):\n\tlst.append(s)\n\tcheck[s] = False\n\tfor v in graph[s]:\n\t\tif check[v]:\n\t\t\tdfs(v, lst, graph, check)\n\treturn lst\nL = []\nfor i in range(n):\n\tc = [True] * n\n\tl = dfs(i, [], graph, c)\n\tl.sort()\n\tL.append(l)\nans = [[] for _ in range(n)]\nfor i in range(n):\n\tfor j in L[i]:\n\t\tif i in L[j]:\n\t\t\tans[i].append(j + 1)\nfor i in range(n):\n\tprint(*ans[i])\n"], "verify": 1, "wrong_type": null} diff --git a/arealite/tests/data/rlvr_math_dataset.jsonl b/arealite/tests/data/rlvr_math_dataset.jsonl new file mode 100644 index 000000000..f98e0bbfa --- /dev/null +++ b/arealite/tests/data/rlvr_math_dataset.jsonl @@ -0,0 +1,10 @@ +{"prompt": "<\uff5cUser\uff5c>\nBaron Munchausen told a story. \"There were a whole crowd of us. We reached a crossroads. Then half of our group turned left, a third turned right, and a fifth went straight.\" \"But wait, the Duke remarked, the sum of half, a third, and a fifth isn't equal to one, so you are lying!\" The Baron replied, \"I'm not lying, I'm rounding. For example, there are 17 people. I say that a third turned. Should one person split in your opinion? No, with rounding, six people turned. From whole numbers, the closest to the fraction $17 / 3$ is 6. And if I say that half of the 17 people turned, it means 8 or 9 people.\" It is known that Baron Munchausen never lies. What is the largest number of people that could have been in the crowd?\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c>", "task": "math", "query_id": "00006d8f079c739f", "solutions": ["\\boxed{37}"]} +{"prompt": "<\uff5cUser\uff5c>What is the unit digit of the product\n\n$$\n(5+1)\\left(5^{3}+1\\right)\\left(5^{6}+1\\right)\\left(5^{12}+1\\right) ?\n$$\n\n(a) 0 \n(b) 1 \n(c) 2 \n(d) 5 \n(e) 6\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c>", "task": "math", "query_id": "000316109ea516b3", "solutions": ["\\boxed{e}"]} +{"prompt": "<\uff5cUser\uff5c>Given points \\( A(4,0) \\) and \\( B(2,2) \\) are inside the ellipse \\( \\frac{x^{2}}{25}+\\frac{y^{2}}{9}=1 \\), and \\( M \\) is a point on the ellipse, find the maximum value of \\( |MA| + |MB| \\).\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c>", "task": "math", "query_id": "000adcfa66ee4270", "solutions": ["\\boxed{10+2\\sqrt{10}}"]} +{"prompt": "<\uff5cUser\uff5c>There is a schoolbag containing 12 cards labeled $1, 1, 2, 2, \\cdots, 6, 6$. A person draws one card at a time without replacement. If a card is drawn that has the same number as a previously drawn card, both cards are discarded. The process ends when the person has 3 single cards in hand or all cards in the schoolbag have been drawn. Find the probability that all cards in the schoolbag are drawn.\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c>", "task": "math", "query_id": "001354647264e663", "solutions": ["\\boxed{\\frac{9}{385}}"]} +{"prompt": "<\uff5cUser\uff5c>For the sequence of numbers \\( n_{1}, n_{2}, n_{3}, \\ldots \\), the relation \\( n_{i} = 2 n_{i-1} + a \\) holds for all \\( i > 1 \\). If \\( n_{2} = 5 \\) and \\( n_{8} = 257 \\), what is \\( n_{5} \\)?\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c>", "task": "math", "query_id": "0014142e5f3c28a7", "solutions": ["\\boxed{33}"]} +{"prompt": "<\uff5cUser\uff5c>Three players play tic-tac-toe together. In other words, the three players take turns placing an \"A\", \"B\", and \"C\", respectively, in one of the free spots of a \\(3 \\times 3\\) grid, and the first player to have three of their label in a row, column, or diagonal wins. How many possible final boards are there where the player who goes third wins the game? (Rotations and reflections are considered different boards, but the order of placement does not matter.)\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c>", "task": "math", "query_id": "0017c4e9f72d26eb", "solutions": ["\\boxed{148}"]} +{"prompt": "<\uff5cUser\uff5c>Let \\( a_{1}, a_{2}, \\cdots, a_{2014} \\) be a permutation of the positive integers \\( 1, 2, \\cdots, 2014 \\). Define\n\\[ S_{k} = a_{1} + a_{2} + \\cdots + a_{k} \\quad (k=1, 2, \\cdots, 2014). \\]\n\nWhat is the maximum number of odd numbers among \\( S_{1}, S_{2}, \\cdots, S_{2014} \\)?\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c>", "task": "math", "query_id": "00231541a71983cd", "solutions": ["\\boxed{1511}"]} +{"prompt": "<\uff5cUser\uff5c>\nThe polynomial \\( G(x) \\) with real coefficients takes the value 2022 at exactly five distinct points \\( x_{1}", "task": "math", "query_id": "002ba4c0d1ad1b54", "solutions": ["\\boxed{6}"]} +{"prompt": "<\uff5cUser\uff5c>The square of a natural number has 202 digits. The first 100 digits are 1, followed by 101 digits of 2. Determine the last digit and the number.\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c>", "task": "math", "query_id": "0041f14cc37aee13", "solutions": ["\\boxed{5}"]} +{"prompt": "<\uff5cUser\uff5c>The descriptors 'even', 'factors of 240', 'multiple of 3', 'odd', 'prime' and 'square' are to be placed in some order as row and column headings around a grid in positions \\(a, b, c, d, e,\\) and \\(f\\). The digits 1 through 9 are to be placed in the empty cells inside the grid so that each digit satisfies both the relevant row and column headings.\n(i) Show that it is possible to complete the grid.\n(ii) In how many different ways can the grid be completed?\nPlease reason step by step, and put your final answer within \\boxed{}.<\uff5cAssistant\uff5c>", "task": "math", "query_id": "00514ff45cc98a48", "solutions": ["\\boxed{72}"]} diff --git a/arealite/tests/test_engine.py b/arealite/tests/test_engine.py new file mode 100644 index 000000000..03b1559b8 --- /dev/null +++ b/arealite/tests/test_engine.py @@ -0,0 +1,176 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +"""Test script for HF Engine implementation.""" + +import os +from typing import Dict + +import pytest +import torch +from transformers import AutoTokenizer + +from arealite.api.cli_args import ( + EngineBackendConfig, + EngineConfig, + MicroBatchSpec, + OptimizerConfig, + TrainingArgs, +) +from arealite.api.engine_api import EngineFactory +from arealite.api.io_struct import FinetuneSpec +from arealite.utils import compute_varlen_position_indices +from realhf.impl.model.utils.padding import unpad_input + +VOCAB_SIZE = 100 +MODEL_PATH = "Qwen/Qwen2-0.5B" + + +@pytest.fixture(scope="module") +def mock_input(bs: int = 3, min_seqlen: int = 3, max_seqlen: int = 12) -> Dict: + """Create mock input data for testing.""" + seqlens = torch.randint( + min_seqlen, max_seqlen, (bs,), dtype=torch.int, device="cuda:0" + ) + max_seqlen = int(max(seqlens)) + input_ids = torch.randint( + 0, VOCAB_SIZE, (bs, max_seqlen), dtype=torch.long, device="cuda:0" + ) + + attn_mask = torch.zeros((bs, max_seqlen), dtype=torch.bool, device="cuda:0") + attn_mask[ + torch.arange(0, max_seqlen, device="cuda:0").unsqueeze(0) < seqlens.unsqueeze(1) + ] = 1 + + packed_input_ids, indices, cu_seqlens, max_seqlen = unpad_input( + input_ids, attn_mask + ) + + assert torch.allclose( + cu_seqlens, torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int), (1, 0)) + ) + position_ids = compute_varlen_position_indices(int(sum(seqlens)), cu_seqlens) + + return dict( + input_ids=packed_input_ids.unsqueeze(0), + attention_mask=None, + position_ids=position_ids.unsqueeze(0), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + use_cache=False, + ) + + +def mock_loss_fn(logits: torch.Tensor, input_data: Dict) -> torch.Tensor: + """Mock loss function for testing.""" + return torch.mean(logits) + + +@pytest.fixture(params=["hf", "fsdp"], scope="module") +def backend_type(request): + return request.param + + +@pytest.fixture(scope="module") +def engine(backend_type): + os.environ["WORLD_SIZE"] = "1" + os.environ["RANK"] = "0" + os.environ["LOCAL_RANK"] = "0" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "7777" + + engine_config = EngineConfig( + path=MODEL_PATH, + gradient_checkpointing=False, + optimizer=OptimizerConfig(), + backend=EngineBackendConfig(type=backend_type), + ) + + mock_args = TrainingArgs(n_nodes=1, n_gpus_per_node=1) + + engine_factory = EngineFactory(mock_args) + engine = engine_factory.make_engine(engine_config) + ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2) + engine.init_distributed(None, ft_spec) + engine.load_model_from_hf(engine_config.path) + print("✓ Engine created successfully") + yield engine + + +def test_forward_microbatch(engine, mock_input): + x2 = ( + engine.forward( + input_=mock_input, + mb_spec=MicroBatchSpec(n_mbs=2), + aggregate_fn=lambda x: torch.cat(x, dim=1), + ) + .squeeze(0) + .mean(-1) + ) + x1 = ( + engine.forward( + input_=mock_input, + mb_spec=MicroBatchSpec(n_mbs=1), + aggregate_fn=lambda x: torch.cat(x, dim=1), + ) + .squeeze(0) + .mean(-1) + ) + input_ids = mock_input["input_ids"].squeeze(0) + assert x1.shape[0] == input_ids.shape[0] + assert x2.shape[0] == input_ids.shape[0] + assert torch.allclose(x1, x2, atol=1e-1, rtol=1e-2), (x1 - x2).abs().max().item() + + +def test_eval_batch(engine, mock_input): + eval_result = engine.eval_batch( + input_=mock_input, + mb_spec=MicroBatchSpec(n_mbs=2), + loss_fn=mock_loss_fn, + loss_weight_fn=lambda x: x["cu_seqlens"][-1], + ) + assert isinstance(eval_result, torch.Tensor), "Evaluation should return a tensor" + assert eval_result.is_cuda, "Evaluation tensor should be on CUDA device" + assert eval_result is not None, "Evaluation should return a loss value" + print(f"✓ Evaluation successful, loss: {eval_result.item()}") + + +def test_train_batch(tmp_path_factory, engine, mock_input): + path = tmp_path_factory.mktemp("hf_engine_train_batch") + engine.save_optimizer_state(path) + + train_result = engine.train_batch( + input_=mock_input, + mb_spec=MicroBatchSpec(n_mbs=2), + loss_fn=mock_loss_fn, + loss_weight_fn=lambda x: x["cu_seqlens"][-1], + ) + assert isinstance(train_result, dict), "Training should return a dictionary" + assert train_result["grad_norm"] is not None + assert train_result["lr"] is not None + print("✓ Training successful") + + engine.load_optimizer_state(path) + + +@torch.no_grad() +def test_save_load_weights(tmp_path_factory, engine, mock_input): + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) + path = tmp_path_factory.mktemp("hf_engine_test") + + old = engine.forward( + input_=mock_input, + mb_spec=MicroBatchSpec(n_mbs=1), + ) + engine.save_model_to_hf(path=path, tokenizer=tokenizer) + + for name, param in engine.model.named_parameters(): + param.zero_() + + engine.load_model_from_hf(path=path) + new = engine.forward( + input_=mock_input, + mb_spec=MicroBatchSpec(n_mbs=1), + ) + + assert torch.allclose(old, new) diff --git a/arealite/tests/test_grpo.py b/arealite/tests/test_grpo.py new file mode 100644 index 000000000..24881a2a2 --- /dev/null +++ b/arealite/tests/test_grpo.py @@ -0,0 +1,106 @@ +"""Test script for GRPO Trainer implementation.""" + +import pytest +from datasets import load_dataset + +from arealite.api.cli_args import ( + DatasetConfig, + EngineBackendConfig, + EngineConfig, + GRPOTrainerConfig, + OptimizerConfig, + RLVRConfig, + TrainerConfig, + TrainingArgs, +) +from arealite.api.io_struct import FinetuneSpec +from arealite.api.rollout_api import RolloutCollectorFactory +from arealite.impl.trainer.grpo import SpmdGRPOTrainer +from arealite.system.rollout_controller import RolloutController +from arealite.tests.utils import mock_rollout_output +from realhf.base import constants, name_resolve, seeding + +EXPR_NAME = "test_grpo" +TRIAL_NAME = "test_grpo" +MODEL_PATH = "Qwen/Qwen2-0.5B" + + +@pytest.fixture(scope="module") +def args(): + args = TrainingArgs(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME) + constants.set_experiment_trial_names(args.experiment_name, args.trial_name) + seeding.set_random_seed(args.seed, EXPR_NAME) + args.train_dataset = DatasetConfig( + path="openai/gsm8k", + name="main", + split="train", + batch_size=4, + shuffle=True, + pin_memory=True, + num_workers=1, + ) + args.trainer = TrainerConfig(type="grpo", grpo=GRPOTrainerConfig()) + args.trainer.grpo.actor = EngineConfig( + path=MODEL_PATH, + gradient_checkpointing=False, + optimizer=OptimizerConfig(), + backend=EngineBackendConfig(type="hf"), + ) + args.trainer.grpo.ref = EngineConfig( + path=MODEL_PATH, + gradient_checkpointing=False, + backend=EngineBackendConfig(type="hf"), + ) + args.rollout.model_path = MODEL_PATH + args.rollout.server_backend = "sglang" + args.rollout.collector.rlvr = RLVRConfig(solution_path="nothing") + args.rollout.gconfig.max_new_tokens = 16 + name_resolve.reconfigure(args.cluster.name_resolve) + yield args + name_resolve.reset() + + +@pytest.mark.parametrize("kl_ctl", [0.0, 0.1]) +@pytest.mark.parametrize("bs", [4]) +@pytest.mark.parametrize("n_samples", [2]) +@pytest.mark.parametrize("recompute", [False, True]) +@pytest.mark.parametrize("use_decoupled_loss", [False, True]) +def test_train_step(args, kl_ctl, bs, n_samples, recompute, use_decoupled_loss): + args.rollout.gconfig.n_samples = n_samples + args.trainer.grpo.kl_ctl = kl_ctl + args.trainer.grpo.recompute_logprobs = recompute + args.trainer.grpo.use_decoupled_loss = use_decoupled_loss + args.train_dataset.batch_size = bs + # Create mock rollout controller and trainer + rollout_factory = RolloutCollectorFactory(args) + collector = rollout_factory.make_collector(args.rollout.collector) + rollout_controller = RolloutController(args, args.rollout, collector=collector) + dataset = load_dataset("openai/gsm8k", name="main", split="train").select(range(10)) + + trainer = SpmdGRPOTrainer( + args=args, + trainer_config=args.trainer, + train_dataset=dataset, + rollout_controller=rollout_controller, + ) + ft_spec = FinetuneSpec( + total_train_epochs=1, + dataset_size=100, + train_batch_size=args.train_dataset.batch_size, + ) + trainer.actor.init_distributed(None, ft_spec) + trainer.actor.eval() + if trainer.ref is not None: + trainer.ref.init_distributed(None, ft_spec) + trainer.ref.eval() + + rollout_output = mock_rollout_output(bs, n_samples) + stats_list = trainer._train_step(rollout_output) + + # Verify the output + assert isinstance(stats_list, list) + assert len(stats_list) == args.trainer.grpo.ppo_n_minibatches + for stats in stats_list: + assert isinstance(stats, dict) + for k, v in stats.items(): + assert isinstance(v, float) diff --git a/arealite/tests/test_rollout.py b/arealite/tests/test_rollout.py new file mode 100644 index 000000000..3f6bd43fb --- /dev/null +++ b/arealite/tests/test_rollout.py @@ -0,0 +1,172 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import json +from datetime import datetime +from pathlib import Path + +import pytest +from datasets import load_dataset + +from arealite.api.cli_args import ( + DatasetPreprocessor, + GenerationHyperparameters, + GSM8KPreprocessor, + MathCodeSingleStepConfig, + RLVRConfig, + RolloutCollectorConfig, + SGLangConfig, + TrainingArgs, +) +from arealite.api.io_struct import Trajectory +from arealite.api.llm_client_api import LLMClientFactory +from arealite.api.llm_server_api import LLMServerFactory +from arealite.api.rollout_api import RolloutCollectorFactory +from realhf.api.core.data_api import load_hf_tokenizer +from realhf.base import name_resolve, seeding + +EXPR_NAME = "test_rollout" +TRIAL_NAME = "test_rollout" +MODEL_PATH = "Qwen/Qwen2-0.5B" + + +@pytest.fixture(scope="module") +def tokenizer(): + yield load_hf_tokenizer(MODEL_PATH) + + +@pytest.fixture(scope="module") +def args(): + args = TrainingArgs(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME) + args.rollout.model_path = MODEL_PATH + seeding.set_random_seed(args.seed, EXPR_NAME) + name_resolve.reconfigure(args.cluster.name_resolve) + yield args + name_resolve.reset() + + +@pytest.fixture(scope="module") +def sglang_server(args): + args.rollout.sglang = SGLangConfig() + server = LLMServerFactory(args).make_server(args.rollout.llm_service) + server._startup() + yield + server._graceful_exit(0) + + +@pytest.mark.parametrize("task", ["math", "code"]) +@pytest.mark.asyncio +async def test_rlvr_rollout(args, sglang_server, tokenizer, task): + jsonl_file = Path(__file__).parent / "data" / f"rlvr_{task}_dataset.jsonl" + args.rollout.server_backend = "sglang" + args.rollout.gconfig = gconfig = GenerationHyperparameters(max_new_tokens=16) + args.rollout.collector = RolloutCollectorConfig( + type="rlvr", + rlvr=RLVRConfig(reward_type=f"areal-{task}", solution_path=jsonl_file), + ) + llm_client = LLMClientFactory(args).make_client(args.rollout.llm_client) + collector = RolloutCollectorFactory(args).make_collector(args.rollout.collector) + + # Test the rollout collector with the provided JSONL data + with open(jsonl_file, "r") as f: + for i, l in enumerate(f.readlines()): + data = json.loads(l) + env_option = dict( + query_id=data["query_id"], + input_ids=tokenizer.encode(data["prompt"]), + prompt=data["prompt"], + ) + res = await collector.arun_episode( + llm_client=llm_client, + gconfig=gconfig, + env_option=env_option, + ) + assert isinstance(res, Trajectory) + assert isinstance(res.data, dict) + assert res.prompt == env_option + shape = res.data["input_ids"].shape + for k in ["prompt_mask", "logprobs", "versions"]: + assert res.data[k].shape == shape + assert res.stats.episode_length == 1 + assert res.stats.total_reward in [0, 1], res.stats.total_reward + assert res.stats.start_time < datetime.now().timestamp() + + +@pytest.mark.asyncio +async def test_gsm8k_rollout(args, sglang_server, tokenizer): + args.rollout.server_backend = "sglang" + args.rollout.gconfig = gconfig = GenerationHyperparameters(max_new_tokens=16) + args.rollout.collector = RolloutCollectorConfig( + type="rlvr", rlvr=RLVRConfig(reward_type="gsm8k") + ) + collector = RolloutCollectorFactory(args).make_collector(args.rollout.collector) + + args.train_dataset.path = "openai/gsm8k" + args.train_dataset.name = "main" + args.train_dataset.split = "train" + args.train_dataset.preprocessor = DatasetPreprocessor( + "gsm8k_rl", gsm8k=GSM8KPreprocessor("strict") + ) + + from arealite.api.dataset_api import DatasetFactory + + llm_client = LLMClientFactory(args).make_client(args.rollout.llm_client) + dataset = ( + DatasetFactory(args) + .make_dataset(args.train_dataset, rank=0, world_size=1) + .select(range(10)) + ) + for i in range(len(dataset)): + env_option = dataset[i] + res = await collector.arun_episode( + llm_client=llm_client, + gconfig=gconfig, + env_option=env_option, + ) + assert isinstance(res, Trajectory) + assert isinstance(res.data, dict) + assert res.prompt == env_option + shape = res.data["input_ids"].shape + for k in ["prompt_mask", "logprobs", "versions"]: + assert res.data[k].shape == shape + assert res.stats.episode_length == 1 + assert res.stats.total_reward in [0, 1], res.stats.total_reward + assert res.stats.start_time < datetime.now().timestamp() + + +@pytest.mark.parametrize("task", ["math", "code"]) +@pytest.mark.asyncio +async def test_math_code_agentic_rollout(args, task, sglang_server, tokenizer): + jsonl_file = Path(__file__).parent / "data" / f"rlvr_{task}_dataset.jsonl" + args.rollout.server_backend = "sglang" + args.rollout.gconfig = gconfig = GenerationHyperparameters(max_new_tokens=16) + args.rollout.collector = RolloutCollectorConfig( + type="math_code_single_step", + math_code_single_step=MathCodeSingleStepConfig(solution_path=jsonl_file), + ) + + collector = RolloutCollectorFactory(args).make_collector(args.rollout.collector) + llm_client = LLMClientFactory(args).make_client(args.rollout.llm_client) + + # Test the rollout collector with the provided JSONL data + with open(jsonl_file, "r") as f: + for i, l in enumerate(f.readlines()): + data = json.loads(l) + env_option = dict( + query_id=data["query_id"], + input_ids=tokenizer.encode(data["prompt"]), + ) + res = await collector.arun_episode( + llm_client=llm_client, + gconfig=gconfig, + env_option=env_option, + ) + assert isinstance(res, Trajectory) + assert isinstance(res.data, dict) + assert res.prompt == env_option + shape = res.data["input_ids"].shape + for k in ["prompt_mask", "logprobs", "versions"]: + assert res.data[k].shape == shape + assert res.stats.episode_length == 1 + assert res.stats.total_reward in [0, 1], res.stats.total_reward + assert res.stats.start_time < datetime.now().timestamp() diff --git a/arealite/tests/test_rollout_controller.py b/arealite/tests/test_rollout_controller.py new file mode 100644 index 000000000..636f34933 --- /dev/null +++ b/arealite/tests/test_rollout_controller.py @@ -0,0 +1,209 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import time +from copy import deepcopy +from pathlib import Path + +import pytest +import torch.multiprocessing as mp +from datasets import load_dataset +from torchdata.stateful_dataloader import StatefulDataLoader + +from arealite.api.cli_args import RLVRConfig, SGLangConfig, TrainingArgs +from arealite.api.io_struct import Trajectory +from arealite.api.llm_server_api import LLMServerFactory +from arealite.api.rollout_api import RolloutCollectorFactory +from arealite.system.rollout_controller import RolloutController +from arealite.tests.utils import mock_rollout_output +from arealite.utils import concat_padded_tensors +from realhf.api.core.data_api import load_hf_tokenizer +from realhf.base import name_resolve, names, seeding + +EXPR_NAME = "test_rollout_controller" +TRIAL_NAME = "test_rollout_controller" +MODEL_PATH = "Qwen/Qwen2-0.5B" + + +@pytest.fixture(scope="module") +def args(): + args = TrainingArgs(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME) + seeding.set_random_seed(args.seed, EXPR_NAME) + args.rollout.model_path = MODEL_PATH + args.rollout.llm_client.tokenizer_path = MODEL_PATH + args.train_dataset.batch_size = 2 + args.rollout.collector.rlvr = RLVRConfig( + solution_path=str(Path(__file__).parent / "data" / f"rlvr_math_dataset.jsonl") + ) + start_method = mp.get_start_method() + mp.set_start_method("fork", force=True) + name_resolve.reconfigure(args.cluster.name_resolve) + yield args + name_resolve.reset() + mp.set_start_method(start_method, force=True) + + +@pytest.fixture(scope="module") +def sglang_server(args): + args.rollout.sglang = SGLangConfig() + server = LLMServerFactory(args).make_server(args.rollout.llm_service) + server._startup() + yield + server._graceful_exit(0) + + +@pytest.fixture +def dataloader(args): + dataset = load_dataset( + "json", + split="train", + data_files=str(Path(__file__).parent / "data" / f"rlvr_math_dataset.jsonl"), + ) + tokenizer = load_hf_tokenizer(MODEL_PATH) + dataset = dataset.map(lambda x: tokenizer(x["prompt"]), batched=True) + yield StatefulDataLoader( + dataset, + batch_size=args.train_dataset.batch_size, + collate_fn=lambda x: x, + drop_last=True, + ) + + +@pytest.mark.parametrize("num_workers", [1, 4]) +@pytest.mark.parametrize("n_samples", [1, 2]) +def test_generate_batch(args, sglang_server, dataloader, n_samples, num_workers): + args = deepcopy(args) + args.rollout.num_workers = num_workers + args.rollout.gconfig.n_samples = n_samples + args.rollout.gconfig.max_new_tokens = 16 + rollout_factory = RolloutCollectorFactory(args) + collector = rollout_factory.make_collector(args.rollout.collector) + rollout_controller = RolloutController(args, args.rollout, collector=collector) + + data = next(iter(dataloader)) + batch_size = len(data) + result = rollout_controller.generate_batch(batch_size, env_options=data) + + assert len(result) == batch_size * n_samples + assert all(isinstance(traj, Trajectory) for traj in result) + for traj in result: + shape = traj.data["input_ids"].shape + assert len(shape) == 2 + for v in traj.data.values(): + assert v.shape == shape or len(v.shape) == 1 + data = concat_padded_tensors([traj.data for traj in result]) + assert data["input_ids"].shape[0] == batch_size * n_samples + shape = data["input_ids"].shape + assert len(shape) == 2 + for v in data.values(): + assert v.shape == shape or len(v.shape) == 1 + + +@pytest.mark.parametrize("batch_size", [1, 2, 3]) +@pytest.mark.parametrize("n_samples", [1, 2, 4]) +def test_mock_trajs(batch_size, n_samples): + # Test the consistency with mocked rollout output + result = mock_rollout_output(batch_size, n_samples) + assert len(result) == batch_size * n_samples + assert all(isinstance(traj, Trajectory) for traj in result) + for traj in result: + shape = traj.data["input_ids"].shape + assert len(shape) == 2 + for v in traj.data.values(): + assert v.shape == shape or len(v.shape) == 1 + data = concat_padded_tensors([traj.data for traj in result]) + assert data["input_ids"].shape[0] == batch_size * n_samples + shape = data["input_ids"].shape + assert len(shape) == 2 + for v in data.values(): + assert v.shape == shape or len(v.shape) == 1 + + +@pytest.mark.parametrize("n_samples", [1, 4, 16]) +@pytest.mark.parametrize("num_workers", [1, 2, 5]) +def test_async_rollout(args, sglang_server, dataloader, n_samples, num_workers): + args = deepcopy(args) + args.rollout.gconfig.n_samples = n_samples + args.rollout.gconfig.max_new_tokens = 16 + args.train_dataset.batch_size = 2 + args.rollout.max_concurrent_rollouts = 16 + args.rollout.num_workers = num_workers + rollout_factory = RolloutCollectorFactory(args) + collector = rollout_factory.make_collector(args.rollout.collector) + rollout_controller = RolloutController(args, args.rollout, collector=collector) + + # start loop + rollout_controller.start_generate_loop() + assert hasattr(rollout_controller, "_collector_thread") + assert rollout_controller._collector_thread.is_alive() + + # Submit data to workers + data = next(iter(dataloader)) + rollout_controller.submit(data) + + # wait for batch + batch_size = 2 + result = rollout_controller.prepare_batch(batch_size) + assert len(result) == batch_size * n_samples + assert all(isinstance(traj, Trajectory) for traj in result) + for traj in result: + shape = traj.data["input_ids"].shape + assert len(shape) == 2 + for v in traj.data.values(): + assert v.shape == shape or len(v.shape) == 1 + data = concat_padded_tensors([traj.data for traj in result]) + assert data["input_ids"].shape[0] == batch_size * n_samples + shape = data["input_ids"].shape + assert len(shape) == 2 + for v in data.values(): + assert v.shape == shape or len(v.shape) == 1 + + # exit + rollout_controller.stop_generate_loop() + assert rollout_controller._exiting.is_set() + assert not rollout_controller._collector_thread.is_alive() + assert not rollout_controller._worker_processes + + +@pytest.mark.parametrize("ofp", [1, 2, 4, 16]) +def test_async_staleness_control(args, sglang_server, dataloader, ofp): + args = deepcopy(args) + args.rollout.gconfig.n_samples = 2 + args.rollout.gconfig.max_new_tokens = 4 + args.rollout.max_head_offpolicyness = ofp + args.rollout.max_concurrent_rollouts = 100 + rollout_factory = RolloutCollectorFactory(args) + collector = rollout_factory.make_collector(args.rollout.collector) + rollout_controller = RolloutController(args, args.rollout, collector=collector) + name = names.model_version(args.experiment_name, args.trial_name, "actor") + name_resolve.add(name, str(0), replace=True) + + # start loop + rollout_controller.start_generate_loop() + batch_size = args.train_dataset.batch_size + + gen = iter(dataloader) + rollout_controller.submit(next(gen)) + rollout_controller.submit(next(gen)) + # wait for some time + time.sleep(15) + assert len(rollout_controller._buffer) == min( + batch_size * 2, batch_size * (ofp + 1) + ) + + # Update model version + name = names.model_version(args.experiment_name, args.trial_name, "actor") + name_resolve.add(name, str(1), replace=True) + print("Updated model version", flush=True) + + # submit again + rollout_controller.submit(next(gen)) + rollout_controller.submit(next(gen)) + # wait for some time + time.sleep(15) + assert len(rollout_controller._buffer) == min( + batch_size * 4, batch_size * (ofp + 2) + ) + + # exit + rollout_controller.stop_generate_loop() diff --git a/arealite/tests/test_sft.py b/arealite/tests/test_sft.py new file mode 100644 index 000000000..eaa81cd52 --- /dev/null +++ b/arealite/tests/test_sft.py @@ -0,0 +1,107 @@ +"""Test script for FSDP Engine implementation.""" + +import os +from typing import Dict + +import torch +from datasets import load_dataset + +from arealite.api.cli_args import ( + DatasetConfig, + DatasetPreprocessor, + EngineBackendConfig, + EngineConfig, + OptimizerConfig, + SFTTrainerConfig, + TrainerConfig, + TrainingArgs, +) +from arealite.api.dataset_api import DatasetFactory +from arealite.api.trainer_api import TrainerFactory + + +def mock_loss_fn(logits: torch.Tensor, input_data: Dict) -> torch.Tensor: + """Mock loss function for testing.""" + return torch.mean(logits) + + +def mock_loss_weight_fn(logits: torch.Tensor, input_data: Dict) -> float: + """Mock loss weight function for testing.""" + return float(input_data["attention_mask"].sum()) + + +def test_sft(): + """Test SFTTrainer""" + # environment variables for torch distributed + os.environ["WORLD_SIZE"] = "1" + os.environ["RANK"] = "0" + os.environ["LOCAL_RANK"] = "0" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "7777" + + train_dataset = DatasetConfig( + path="openai/gsm8k", + preprocessor=DatasetPreprocessor("gsm8k_sft"), + name="main", + split="train", + batch_size=8, + shuffle=True, + pin_memory=True, + ) + + valid_dataset = DatasetConfig( + path="openai/gsm8k", + preprocessor=DatasetPreprocessor("gsm8k_sft"), + name="main", + split="test", + batch_size=8, + shuffle=False, + pin_memory=True, + ) + + engine_config = EngineConfig( + path="Qwen/Qwen2-0.5B", + gradient_checkpointing=False, + optimizer=OptimizerConfig(), + backend=EngineBackendConfig(type="hf"), + ) + + sft_config = SFTTrainerConfig( + model=engine_config, + ) + + train_config = TrainerConfig( + type="sft", + sft=sft_config, + ) + + args = TrainingArgs( + experiment_name="test-sft", + trial_name="test", + mode="local", + n_nodes=1, + n_gpus_per_node=1, + train_dataset=train_dataset, + valid_dataset=valid_dataset, + trainer=train_config, + ) + + rollout_controller = None + dataset_factory = DatasetFactory(args) + train_dataset = dataset_factory.make_dataset(args.train_dataset, 0, 1) + train_dataset = train_dataset.select(range(100)) + valid_dataset = None + if args.valid_dataset is not None: + valid_dataset = dataset_factory.make_dataset(args.valid_dataset, 0, 1) + valid_dataset = valid_dataset.select(range(100)) + if args.trainer is not None: + trainer_factory = TrainerFactory(args) + trainer = trainer_factory.make_trainer( + args.trainer, + train_dataset=train_dataset, + valid_dataset=valid_dataset, + rollout_controller=rollout_controller, + ) + trainer.train() + + print("All tests passed!") diff --git a/arealite/tests/test_sglang_client.py b/arealite/tests/test_sglang_client.py new file mode 100644 index 000000000..db037c1d0 --- /dev/null +++ b/arealite/tests/test_sglang_client.py @@ -0,0 +1,112 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +import os +import uuid + +import pytest + +from arealite.api.cli_args import ( + EngineBackendConfig, + EngineConfig, + GenerationHyperparameters, + LLMClientConfig, + OptimizerConfig, + SGLangConfig, + TrainingArgs, +) +from arealite.api.engine_api import EngineFactory +from arealite.api.io_struct import FinetuneSpec, LLMRequest, LLMResponse +from arealite.api.llm_client_api import LLMClient +from arealite.api.llm_server_api import LLMServerFactory +from realhf.base import name_resolve, seeding + +EXPR_NAME = "test_sglang_client" +TRIAL_NAME = "test_sglang_client" +MODEL_PATH = "Qwen/Qwen2-0.5B" + + +@pytest.fixture(scope="module") +def args(): + args = TrainingArgs(experiment_name=EXPR_NAME, trial_name=TRIAL_NAME) + args.rollout.model_path = MODEL_PATH + seeding.set_random_seed(args.seed, EXPR_NAME) + name_resolve.reconfigure(args.cluster.name_resolve) + yield args + name_resolve.reset() + + +@pytest.fixture(scope="module") +def sglang_server(args): + args.rollout.sglang = SGLangConfig(mem_fraction_static=0.3) + server = LLMServerFactory(args).make_server(args.rollout.llm_service) + server._startup() + yield + server._graceful_exit(0) + + +@pytest.fixture(scope="module") +def sglang_client(args, sglang_server): + from arealite.system.sglang_client import SGLangClient + + args.rollout.server_backend = "sglang" + llm_client = LLMClientConfig() + client = SGLangClient(args, client_config=llm_client) + yield client + + +@pytest.mark.asyncio +async def test_sglang_generate(sglang_client): + req = LLMRequest( + rid=str(uuid.uuid4()), + text="hello! how are you today", + gconfig=GenerationHyperparameters(max_new_tokens=16), + ) + resp = await sglang_client.agenerate(req) + assert isinstance(resp, LLMResponse) + assert resp.input_tokens == req.input_ids + assert ( + len(resp.output_logprobs) + == len(resp.output_tokens) + == len(resp.output_versions) + ) + assert isinstance(resp.completion, str) + + +@pytest.mark.asyncio +async def test_sglang_update_weights_from_disk(sglang_client: LLMClient): + servers = sglang_client.get_healthy_servers() + assert len(servers) == 1 + await sglang_client.aupdate_weights_from_disk( + server_info=servers[0], path=MODEL_PATH + ) + + +@pytest.fixture(scope="module") +def engine(sglang_server): + os.environ["WORLD_SIZE"] = "1" + os.environ["RANK"] = "0" + os.environ["LOCAL_RANK"] = "0" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "7777" + engine_config = EngineConfig( + path=MODEL_PATH, + gradient_checkpointing=False, + optimizer=OptimizerConfig(), + backend=EngineBackendConfig(type="fsdp"), + ) + + mock_args = TrainingArgs(n_nodes=1, n_gpus_per_node=1) + + engine_factory = EngineFactory(mock_args) + engine = engine_factory.make_engine(engine_config) + ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=100, train_batch_size=2) + engine.init_distributed(None, ft_spec) + print("✓ Engine created successfully") + yield engine + + +def test_sglang_update_weights_from_distributed( + engine, sglang_server, sglang_client: LLMClient +): + engine.update_weights_to(sglang_client) diff --git a/arealite/tests/utils.py b/arealite/tests/utils.py new file mode 100644 index 000000000..c99b3d2cb --- /dev/null +++ b/arealite/tests/utils.py @@ -0,0 +1,36 @@ +import random + +import torch + +from arealite.api.io_struct import Trajectory, TrajStats + + +def mock_rollout_output(bs, n_samples): + trajs = [] + min_seqlen, max_seqlen = 8, 16 + for _ in range(bs * n_samples): + input_len = random.randint(min_seqlen, max_seqlen) + prompt_len = random.randint(1, min_seqlen - 1) + input_ids = torch.randint(0, 100, (input_len,)) + prompt_mask = torch.tensor([1] * prompt_len + [0] * (input_len - prompt_len)) + logprobs = -torch.randn(input_len).abs() + versions = torch.zeros(input_len) + traj = Trajectory( + prompt=None, + data=dict( + input_ids=input_ids.unsqueeze(0), + prompt_mask=prompt_mask.unsqueeze(0), + logprobs=logprobs.unsqueeze(0), + versions=versions.unsqueeze(0), + rewards=torch.tensor([random.random()]), + ), + stats=TrajStats( + start_time=0, + total_reward=0, + episode_length=1, + info={}, + ), + ) + trajs.append(traj) + + return trajs diff --git a/arealite/utils.py b/arealite/utils.py new file mode 100644 index 000000000..35684ebf5 --- /dev/null +++ b/arealite/utils.py @@ -0,0 +1,517 @@ +# Copyright 2025 Ant Group Inc. +# Licensed under the Apache License, Version 2.0 + +# Pad/unpad operations are modified from flash-attention under BSD-3 license. +# Copyright (c) 2023, Tri Dao. + +import os +import time +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +import wandb +from einops import rearrange, repeat +from tensorboardX import SummaryWriter + +from arealite.api.cli_args import MicroBatchSpec, TrainingArgs +from realhf.base import constants, datapack + + +def recorder_list(xs: List, indices: List[int]) -> List: + assert len(set(indices)) == len(xs) + return [xs[i] for i in indices] + + +def dict_map(x: Dict, fn: Callable) -> Dict: + return {k: fn(v) for k, v in x.items()} + + +def dict_of_list2list_of_dict( + dict_of_lists: Dict[str, List[Any]], +) -> List[Dict[str, Any]]: + if not dict_of_lists: + return [] + keys = list(dict_of_lists.keys()) + length = len(dict_of_lists[keys[0]]) + for key, value_list in dict_of_lists.items(): + if len(value_list) != length: + raise ValueError( + f"All lists must have the same length. Key '{key}' has length {len(value_list)}, expected {length}" + ) + return [{key: dict_of_lists[key][i] for key in keys} for i in range(length)] + + +def list_of_dict2dict_of_list( + list_of_dicts: List[Dict[str, Any]], +) -> Dict[str, List[Any]]: + if not list_of_dicts: + return {} + keys = list(list_of_dicts[0].keys()) + for i, dict_item in enumerate(list_of_dicts): + if set(dict_item.keys()) != set(keys): + raise ValueError( + f"All dictionaries must have the same keys. Dictionary at index {i} has keys {set(dict_item.keys())}, expected {set(keys)}" + ) + return {key: [dict_item[key] for dict_item in list_of_dicts] for key in keys} + + +def pad_sequences_to_tensors( + sequence_list: List[Dict[str, torch.Tensor]], pad_value: float = 0.0 +) -> Dict[str, torch.Tensor]: + if not sequence_list: + return {} + max_length = max(len(seq) for item in sequence_list for seq in item.values()) + result = {} + for key in sequence_list[0].keys(): + padded = [ + torch.nn.functional.pad( + item[key], (0, max_length - len(item[key])), value=pad_value + ) + for item in sequence_list + ] + result[key] = torch.stack(padded) + attention_mask = [ + [1] * len(next(iter(item.values()))) + + [0] * (max_length - len(next(iter(item.values())))) + for item in sequence_list + ] + result["attention_mask"] = torch.tensor(attention_mask, dtype=torch.long) + return result + + +def unpad_input( + hidden_states, attention_mask +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + return ( + rearrange(hidden_states, "b s ... -> (b s) ...")[indices], + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def pad_input(hidden_states, indices, batch, seqlen): + output = hidden_states.new_zeros(batch * seqlen) + output[indices] = hidden_states + return rearrange(output, "(b s) ... -> b s ...", b=batch) + + +def concat_padded_tensors( + tensor_dicts: List[Dict[str, torch.Tensor]], pad_value: float = 0.0 +) -> Dict[str, torch.Tensor]: + """Concatenate and pad tensors from multiple padded tensor dictionaries.""" + if not tensor_dicts: + return {} + + # Find max sequence length across all dictionaries + lens = [] + for tensor_dict in tensor_dicts: + for key, tensor in tensor_dict.items(): + if key != "attention_mask" and len(tensor.shape) == 2: + lens.append(tensor.shape[1]) + break + max_length = max(lens) + attn_mask = torch.arange(max_length).unsqueeze(0) < torch.tensor(lens).unsqueeze(1) + + result = {} + # Process each key + for key in tensor_dicts[0].keys(): + tensors_to_concat = [] + for tensor_dict in tensor_dicts: + tensor = tensor_dict[key] + # Skip 1D tensors like rewards + if len(tensor.shape) == 1: + tensors_to_concat.append(tensor) + continue + current_length = tensor.shape[1] + if current_length < max_length: + # Pad tensor to max_length + pad_width = max_length - current_length + if key == "attention_mask": + # Pad attention mask with 0s + padding = torch.zeros( + (tensor.shape[0], pad_width), dtype=tensor.dtype + ) + else: + # Pad feature tensors with pad_value + padding = torch.full( + (tensor.shape[0], pad_width), pad_value, dtype=tensor.dtype + ) + tensor = torch.cat([tensor, padding], dim=1) + tensors_to_concat.append(tensor) + + result[key] = torch.cat(tensors_to_concat, dim=0) + if "attention_mask" not in result: + result["attention_mask"] = attn_mask + return result + + +def to_device(data: Dict[str, torch.Tensor | Any], device) -> Dict[str, torch.Tensor]: + """Move tensors in a dictionary to the specified device.""" + return { + key: value.to(device) if torch.is_tensor(value) else value + for key, value in data.items() + } + + +def unpack_sequence( + x: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + lens: Optional[List[int]] = None, + dim: int = 0, +): + """Unpack a sequence tensor into a list of tensors based on cumulative sequence lengths.""" + if lens is not None: + return torch.split(x, lens, dim=dim) + if cu_seqlens is not None: + return torch.split( + x, (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy().tolist(), dim=dim + ) + raise ValueError("Either cu_seqlens or input_lens must be provided.") + + +def allocate_balanced_mbs(mb_spec: MicroBatchSpec, lens: List[int]) -> List[List[int]]: + group_indices = datapack.ffd_allocate( + lens, mb_spec.max_tokens_per_mb, min_groups=mb_spec.n_mbs + ) + group_indices = sorted([sorted(g) for g in group_indices]) + return group_indices + + +def allocate_balanced_mbs_synced( + mb_spec: MicroBatchSpec, + lens: List[int], + group: Optional[dist.ProcessGroup] = None, +) -> List[List[int]]: + group_indices = allocate_balanced_mbs(mb_spec, lens) + if not dist.is_initialized(): + return group_indices + + all_n_mbs = [None for _ in range(dist.get_world_size(group))] + dist.all_gather_object(all_n_mbs, len(group_indices), group=group) + if all(mbs == len(group_indices) for mbs in all_n_mbs): + return group_indices + return allocate_balanced_mbs_synced( + MicroBatchSpec.new(mb_spec, n_mbs=max(all_n_mbs)), lens + ) + + +@dataclass +class MicroBatchSplitResult: + data: Dict[str, Any] + mb_spec: MicroBatchSpec + mbs: List[Dict[str, Any]] + forward_indices: List[int] + backward_indices: List[int] + + +def split_dict_tensor_with_cu_seqlens( + data: Dict[str, torch.Tensor], + mb_spec: MicroBatchSpec, + group: Optional[dist.ProcessGroup] = None, +) -> MicroBatchSplitResult: + assert "cu_seqlens" in data + cu_seqlens = data["cu_seqlens"] + bs = cu_seqlens.shape[0] - 1 + total_lens = int(cu_seqlens[-1]) + input_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu().numpy() + + # check tensor shape, split only 1d tensors with length "total_lens" + to_split = {} + not_to_split = {} + keys_to_unsqueeze = set() + for key, value in data.items(): + if key == "cu_seqlens" or key == "max_seqlen": + continue + if not torch.is_tensor(value): + not_to_split[key] = value + else: + assert value.numel() == total_lens, (key, value.shape) + if value.shape[0] == 1: + keys_to_unsqueeze.add(key) + to_split[key] = value.squeeze() + else: + to_split[key] = value + + # split + group_indices = allocate_balanced_mbs_synced(mb_spec, input_lens, group=group) + splitted_lens = [ + [input_lens[i] for i in group_index] for group_index in group_indices + ] + group_lens = [sum(x) for x in splitted_lens] + + forward_indices = datapack.flat2d(group_indices) + backward_indices = np.zeros(bs, dtype=np.int64) + backward_indices[forward_indices] = np.arange(bs) + + to_split = dict_map(to_split, lambda x: unpack_sequence(x, cu_seqlens=cu_seqlens)) + to_split = dict_map(to_split, lambda x: recorder_list(x, forward_indices)) + to_split = dict_map(to_split, lambda x: torch.cat(x)) + to_split = dict_map(to_split, lambda x: unpack_sequence(x, lens=group_lens)) + mbs = dict_of_list2list_of_dict(to_split) + + results = [] + # organize splitted micro batches + assert len(mbs) == len(splitted_lens), (len(mbs), len(splitted_lens)) + for i, (mb, lens) in enumerate(zip(mbs, splitted_lens)): + mb = { + k: v if k not in keys_to_unsqueeze else v.unsqueeze(0) + for k, v in mb.items() + } + max_seqlen = max(lens) + lens = torch.tensor(lens, device="cuda") + batch_cu_seqlens = torch.nn.functional.pad( + lens.cumsum(0, dtype=torch.int), (1, 0) + ) + results.append( + { + **mb, + **not_to_split, + "max_seqlen": max_seqlen, + "cu_seqlens": batch_cu_seqlens, + } + ) + return MicroBatchSplitResult( + data=data, + mbs=results, + mb_spec=mb_spec, + forward_indices=forward_indices, + backward_indices=backward_indices, + ) + + +@torch.no_grad() +def compute_varlen_position_indices( + total_seqlen: int, + cu_seqlens: torch.Tensor, + seqlen_offsets: Optional[torch.Tensor] = None, +) -> torch.LongTensor: + indexing_t = torch.arange( + total_seqlen, dtype=torch.long, device=cu_seqlens.device + ).unsqueeze_(0) + indexing_t = (cu_seqlens[:-1].unsqueeze(1) <= indexing_t) & ( + indexing_t < cu_seqlens[1:].unsqueeze(1) + ) + indices = indexing_t.cumsum(1) - 1 + if seqlen_offsets is not None: + indices += seqlen_offsets.unsqueeze(1) + return torch.where(indexing_t, indices, 0).sum(0) + + +@torch.compile +@torch.no_grad() +def calc_entropy(logits, cu_seqlens): + probs = torch.nn.functional.softmax(logits.detach().float(), dim=-1) + entropy = -torch.sum(probs * torch.log(probs + 1e-7), dim=-1) + return entropy + + +@torch.no_grad() +def masked_normalization( + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + dim=None, + unbiased=False, + eps=1e-5, + high_precision=True, + all_reduce=True, + reduce_group=None, +): + dtype = torch.float64 if high_precision else torch.float32 + x = x.to(dtype) + if dim is None: + dim = tuple(range(len(x.shape))) + if mask is None: + factor = torch.tensor( + np.prod([x.shape[d] for d in dim]), dtype=dtype, device=x.device + ) + else: + mask = mask.to(dtype) + x = x * mask + factor = mask.sum(dim, keepdim=True) + x_sum = x.sum(dim=dim, keepdim=True) + x_sum_sq = x.square().sum(dim=dim, keepdim=True) + if dist.is_initialized() and all_reduce: + dist.all_reduce(factor, op=dist.ReduceOp.SUM, group=reduce_group) + dist.all_reduce(x_sum, op=dist.ReduceOp.SUM, group=reduce_group) + dist.all_reduce( + x_sum_sq, + op=dist.ReduceOp.SUM, + group=reduce_group, + ) + mean = x_sum / factor + meansq = x_sum_sq / factor + var = meansq - mean**2 + if unbiased: + var *= factor / (factor - 1) + return ((x - mean) / (var.sqrt() + eps)).float() + + +def gather_logprobs(logits: torch.Tensor, labels: torch.Tensor): + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + return log_probs_labels + + +def init_stats_logging(args: TrainingArgs): + """ + Initialize wandb and/or tensorboard according to config. + If torch.distributed is initialized + + Return: + tensorboard SummaryWriter if args.tensorboard.path is not None + """ + if dist.is_initialized() and dist.get_rank() != 0: + return + + # wandb init, connect to remote wandb host + if args.wandb.mode != "disabled": + wandb.login() + wandb.init( + mode=args.wandb.mode, + entity=args.wandb.entity, + project=args.wandb.project or args.experiment_name, + name=args.wandb.name or args.trial_name, + job_type=args.wandb.job_type, + group=args.wandb.group or f"{args.experiment_name}_{args.trial_name}", + notes=args.wandb.notes, + tags=args.wandb.tags, + config=args.wandb.config, + dir=constants.get_log_path(args), + force=True, + id=f"{args.experiment_name}_{args.trial_name}_train", + resume="allow", + settings=wandb.Settings(start_method="fork"), + ) + # tensorboard logging + summary_writer = None + if args.tensorboard.path is not None: + summary_writer = SummaryWriter(log_dir=args.tensorboard.path) + + return summary_writer + + +def log_wandb_tensorboard(step, data, summary_writer=None): + if dist.is_initialized() and dist.get_rank() != 0: + return + + wandb.log(data, step=step) + if summary_writer is not None: + for key, val in data.items(): + summary_writer.add_scalar(f"{key}", val, step) + + +def close_wandb_tensorboard(summary_writer=None): + if dist.is_initialized() and dist.get_rank() != 0: + return + + wandb.finish() + if summary_writer is not None: + summary_writer.close() + + +@contextmanager +def record_timing(name, timing_stats): + start_time = time.perf_counter() + yield + timing_stats[name] = time.perf_counter() - start_time + + +############### Logging related end ############### + + +############### Model load start ############### + + +def get_state_dict_from_repo_id_or_path(repo_id_or_path: str) -> Dict: + """ + Obtain a state dictionary from either a Hugging Face repo ID or a local path. + + Args: + repo_id_or_path (str): Either a Hugging Face repo ID (e.g., 'username/model-name') + or a local path to a directory containing model weights. + + Returns: + Dict: The combined state dictionary from all .safetensors and .bin files. + """ + from safetensors.torch import load_file as safetensors_load + + state_dict = {} + + # Step 1: Identify if the input is a Hugging Face repo ID or local path + try: + from huggingface_hub.utils import HFValidationError, validate_repo_id + + try: + validate_repo_id(repo_id_or_path) + is_hf_repo = True + except HFValidationError: + is_hf_repo = False + except ImportError: + is_hf_repo = False + + if is_hf_repo: + from huggingface_hub import snapshot_download + + # Step 2: Download the repo if it's a Hugging Face repo ID + local_path = snapshot_download( + repo_id=repo_id_or_path, + ) + else: + # Assume it's a local path + local_path = repo_id_or_path + if not os.path.isdir(local_path): + raise ValueError( + f"Local path {local_path} does not exist or is not a directory, " + f"or {local_path} is a huggingface repo id but huggingface_hub is not installed." + ) + + # Step 3: Load all .safetensors and .bin files + file_paths_to_load = [] + for filename in os.listdir(local_path): + filepath = os.path.join(local_path, filename) + if filename.endswith(".safetensors") or filename.endswith(".bin"): + file_paths_to_load.append(filepath) + + def _load(filepath: str): + if filepath.endswith(".safetensors"): + state_dict = safetensors_load(filepath) + elif filepath.endswith(".bin"): + state_dict = torch.load(filepath, map_location="cpu") + else: + raise ValueError(f"{filepath} is not a torch bin or safetensor file.") + return state_dict + + state_dict = {} + + from concurrent.futures import ThreadPoolExecutor, as_completed + + with ThreadPoolExecutor( + max_workers=min(4, max(1, os.cpu_count() // 8)) + ) as executor: + future_to_checkpoint = { + executor.submit(_load, path): path for path in file_paths_to_load + } + + for future in as_completed(future_to_checkpoint): + path = future_to_checkpoint[future] + try: + sd = future.result() + state_dict.update(sd) + except Exception as e: + raise RuntimeError(f"Error loading checkpoint from {path}: {e}") + return state_dict + + +############### Model load end ############### diff --git a/ci/build_env_image.sh b/ci/build_env_image.sh new file mode 100644 index 000000000..383984592 --- /dev/null +++ b/ci/build_env_image.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash + +set -e + +GIT_COMMIT_SHA=${GIT_COMMIT_SHA:?"GIT_COMMIT_SHA is not set"} + +echo "GIT_COMMIT_SHA: $GIT_COMMIT_SHA" + +# If there is already an image named areal-env, skip. +if docker images --format '{{.Repository}}:{{.Tag}}' | grep -q 'areal-env:latest'; then + echo "Image areal-env already exists, skipping build." + exit 0 +fi + +RUN_ID="areal-$GIT_COMMIT_SHA" +cd "/tmp/$RUN_ID" + +if docker ps -a --format '{{.Names}}' | grep -q "$RUN_ID"; then + docker rm -f $RUN_ID +fi + +docker run \ + --name $RUN_ID \ + --gpus all \ + --shm-size=8g \ + -v $(pwd):/workspace \ + -w /workspace \ + nvcr.io/nvidia/pytorch:25.01-py3 \ + bash -c " + python -m pip install --upgrade pip + pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + pip config unset global.extra-index-url + bash examples/env/scripts/setup-pip-deps.sh + pip uninstall -y transformer-engine + mv ./sglang /sglang + " || { docker rm -f $RUN_ID; exit 1; } + +docker commit $RUN_ID areal-env:latest +docker rm -f $RUN_ID diff --git a/ci/clone_repo.sh b/ci/clone_repo.sh new file mode 100644 index 000000000..c4448822e --- /dev/null +++ b/ci/clone_repo.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +set -e + +GIT_REPO_URL=${GIT_REPO_URL:?"GIT_REPO_URL is not set"} +GIT_COMMIT_SHA=${GIT_COMMIT_SHA:?"GIT_COMMIT_SHA is not set"} + +echo "GIT_REPO_URL: $GIT_REPO_URL" +echo "GIT_COMMIT_SHA: $GIT_COMMIT_SHA" + +RUN_ID="areal-$GIT_COMMIT_SHA" +rm -rf "/tmp/$RUN_ID" +mkdir -p "/tmp/$RUN_ID" +cd "/tmp/$RUN_ID" + +git init +git remote add origin "$GIT_REPO_URL" +git fetch --depth 1 origin "$GIT_COMMIT_SHA" +git checkout FETCH_HEAD diff --git a/ci/test_arealite.sh b/ci/test_arealite.sh new file mode 100644 index 000000000..79734eade --- /dev/null +++ b/ci/test_arealite.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash + +set -e + +GIT_COMMIT_SHA=${GIT_COMMIT_SHA:?"GIT_COMMIT_SHA is not set"} + +echo "GIT_COMMIT_SHA: $GIT_COMMIT_SHA" + +RUN_ID="areal-$GIT_COMMIT_SHA" +cd "/tmp/$RUN_ID" + +if docker ps -a --format '{{.Names}}' | grep -q "$RUN_ID"; then + docker rm -f $RUN_ID +fi + +docker run \ + --name $RUN_ID \ + --gpus all \ + --shm-size=8g \ + -v $(pwd):/workspace \ + -w /workspace \ + areal-env:latest \ + bash -c " + mv /sglang ./sglang + HF_ENDPOINT=https://hf-mirror.com python -m pytest -s arealite/ + " || { docker rm -f $RUN_ID; exit 1; } + +docker rm -f $RUN_ID diff --git a/pyproject.toml b/pyproject.toml index 7597e137a..de1cb0d8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,8 @@ dependencies = [ "hydra-core==1.4.0.dev1", "packaging", "tabulate", + "gymnasium>=1.1.1", + "torchdata", # Monitoring and logging "wandb", diff --git a/realhf/api/cli_args.py b/realhf/api/cli_args.py index 6de76c710..d599d1ff1 100644 --- a/realhf/api/cli_args.py +++ b/realhf/api/cli_args.py @@ -303,7 +303,6 @@ class SGLangConfig: schedule_policy: str = "lpm" schedule_conservativeness: float = 1.0 cpu_offload_gb: int = 0 - hybrid_train: bool = False dtype: str = "float16" kv_cache_dtype: str = "auto" @@ -318,15 +317,19 @@ class SGLangConfig: # and update prometheus metrics decode_log_interval: int = 1 + # Not used. + hybrid_train: bool = False + # Use staticmethod to make OmegaConf happy. @staticmethod def build_cmd( sglang_config: "SGLangConfig", model_path, tp_size, - server_index, base_gpu_id, dist_init_addr: Optional[str] = None, + served_model_name: Optional[str] = None, + skip_tokenizer_init: bool = True, ): from realhf.base import constants, network, pkg_version, seeding from realhf.experiments.common.utils import asdict as conf_as_dict @@ -335,6 +338,8 @@ def build_cmd( args.pop("hybrid_train") args["random_seed"] = seeding.get_seed() + if served_model_name is None: + served_model_name = model_path host_ip = network.gethostip() host = "localhost" if not sglang_config.enable_metrics else host_ip args = dict( @@ -346,9 +351,9 @@ def build_cmd( load_format="auto", trust_remote_code=True, device="cuda", - served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{model_path}", + served_model_name=served_model_name, is_embedding=False, - skip_tokenizer_init=True, + skip_tokenizer_init=skip_tokenizer_init, # Other runtime options tp_size=tp_size, # Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process @@ -556,6 +561,10 @@ class GenerationHyperparameters: default=1.0, metadata={"help": "Sampling temperature. Higher values increase diversity."}, ) + stop_token_ids: List[int] = field( + default_factory=list, + metadata={"help": "Stop generation when encoutering these token ids."}, + ) # Deprecated parameters use_cuda_graph: bool = field( diff --git a/realhf/api/core/data_api.py b/realhf/api/core/data_api.py index ce6d9bf95..f698a76d6 100644 --- a/realhf/api/core/data_api.py +++ b/realhf/api/core/data_api.py @@ -8,6 +8,7 @@ import random import time from contextlib import contextmanager +from functools import lru_cache # NOTE: We don't sue wildcard importing here because the type # `Sequence` has a very similar name to `SequenceSample`. @@ -47,6 +48,7 @@ RL_TASKS = ["math", "code", "rlhf", "stem"] +@lru_cache(maxsize=8) def load_hf_tokenizer( model_name_or_path: str, fast_tokenizer=True, diff --git a/realhf/base/constants.py b/realhf/base/constants.py index 73962e273..cdf78a4de 100644 --- a/realhf/base/constants.py +++ b/realhf/base/constants.py @@ -120,7 +120,6 @@ def get_param_realloc_path(args: "BaseExperimentConfig"): "REAL_IS_REMOTE": "1", # "NCCL_P2P_DISABLE": "1", # "NCCL_IB_DISABLE": "1", - "TRANSFORMERS_OFFLINE": "1", "TOKENIZERS_PARALLELISM": "true", "PYTORCH_KERNEL_CACHE_PATH": PYTORCH_KERNEL_CACHE_PATH, "TRITON_CACHE_DIR": TRITON_CACHE_PATH, diff --git a/realhf/base/gpu_utils.py b/realhf/base/gpu_utils.py index 53c804e68..f85021e4a 100644 --- a/realhf/base/gpu_utils.py +++ b/realhf/base/gpu_utils.py @@ -27,15 +27,16 @@ def gpu_count(): Ad-hoc to frl cluster. """ + try: + import torch + + torch_cnt = torch.cuda.device_count() + except ImportError: + torch_cnt = 0 if platform.system() == "Darwin": return 0 elif platform.system() == "Windows": - try: - import torch - - return torch.cuda.device_count() - except ImportError: - return 0 + return torch_cnt else: dev_directories = list(os.listdir("/dev/")) for cnt in itertools.count(): @@ -43,7 +44,7 @@ def gpu_count(): continue else: break - return cnt + return cnt or torch_cnt def set_cuda_device(device): diff --git a/realhf/base/names.py b/realhf/base/names.py index f66db5da0..64f0219c6 100644 --- a/realhf/base/names.py +++ b/realhf/base/names.py @@ -93,6 +93,14 @@ def gen_servers(experiment_name, trial_name): return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_servers" +def gen_server(experiment_name, trial_name, server_id): + return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_server/{server_id}" + + +def gen_server_root(experiment_name, trial_name): + return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_server/" + + def used_ports(experiment_name, trial_name, host_name): return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/{host_name}/" diff --git a/realhf/base/network.py b/realhf/base/network.py index 5fac8fdb6..703358541 100644 --- a/realhf/base/network.py +++ b/realhf/base/network.py @@ -30,6 +30,7 @@ def find_free_port( trial_name="port", lockfile_root=constants.PORT_LOCKFILE_ROOT, ): + # TODO: user random sampling instead of bind """Find a free port within the specified range, excluding certain ports.""" ports_name = names.used_ports(experiment_name, trial_name, gethostip()) diff --git a/realhf/base/prologue.py b/realhf/base/prologue.py index 62efe5c35..8ae6bf620 100644 --- a/realhf/base/prologue.py +++ b/realhf/base/prologue.py @@ -64,4 +64,4 @@ def get_trial_name(default_name: str = ""): return trial_name -global_init() +# global_init() diff --git a/realhf/impl/dataset/math_parser.py b/realhf/impl/dataset/math_parser.py index 1655395df..e0a4d537f 100644 --- a/realhf/impl/dataset/math_parser.py +++ b/realhf/impl/dataset/math_parser.py @@ -813,7 +813,7 @@ def loadJson(dataDir): return samples -def parse_line(id2info, prompt_str, generated, query_id): +def parse_line(id2info, generated, query_id): info = id2info[query_id.split("@idx:")[0]] label = 0 diff --git a/realhf/impl/model/interface/ppo_interface.py b/realhf/impl/model/interface/ppo_interface.py index b9955c14b..a5662516a 100644 --- a/realhf/impl/model/interface/ppo_interface.py +++ b/realhf/impl/model/interface/ppo_interface.py @@ -298,15 +298,15 @@ def save(self, model: model_api.Model, save_dir: str): ) @torch.no_grad() - def generate( + def compute_logps( self, - model: model_api.Model, input_: SequenceSample, mb_spec: MicroBatchSpec, ) -> SequenceSample: module = model.module module.eval() + self.engine.forward() # Remap the key `packed_prompts` to `packed_input_ids`, # because the pipe runner only recognizes `packed_input_ids`. diff --git a/realhf/scheduler/local/client.py b/realhf/scheduler/local/client.py index ffc461dad..bd5769d2c 100644 --- a/realhf/scheduler/local/client.py +++ b/realhf/scheduler/local/client.py @@ -76,7 +76,7 @@ class LocalSchedulerClient(SchedulerClient): def log_path_of(self, worker_type) -> str: return os.path.join( get_log_path(self.args), - f"{worker_type}-0", + f"{worker_type}-0.log", ) def __init__(self, args): @@ -90,8 +90,8 @@ def __init__(self, args): ).split(",") self._job_counter: Dict[str, int] = defaultdict(int) - self._job_with_gpu: Dict[str, bool] = defaultdict(int) - self._job_env_vars: Dict[str, Dict] = defaultdict(int) + self._job_gpu_cnt: Dict[str, int] = defaultdict(int) + self._job_env_vars: Dict[str, Dict] = defaultdict(dict) self._job_cmd = {} self._job_states = {} @@ -117,12 +117,12 @@ def submit_array( env_vars = {} self._job_counter[worker_type] += count - if worker_type in self._job_with_gpu: - assert self._job_with_gpu[worker_type] == ( - gpu > 0 + if worker_type in self._job_gpu_cnt: + assert self._job_gpu_cnt[worker_type] == ( + gpu ), "All workers of the same type must either use GPU or not use GPU." else: - self._job_with_gpu[worker_type] = gpu > 0 + self._job_gpu_cnt[worker_type] = gpu if worker_type in self._job_env_vars: assert ( @@ -142,10 +142,10 @@ def submit(self, worker_type, cmd, **kwargs): self.submit_array(worker_type, cmd, count=1, **kwargs) def __commit_all(self): - for worker_type, count, use_gpu, env_vars in zip( + for worker_type, count, gpu, env_vars in zip( self._job_counter.keys(), self._job_counter.values(), - self._job_with_gpu.values(), + self._job_gpu_cnt.values(), self._job_env_vars.values(), ): os.makedirs( @@ -154,12 +154,18 @@ def __commit_all(self): mode=0o775, ) for i in range(count): - if use_gpu: - available_device_id = self._gpu_counter % len(self._cuda_devices) - env_vars["CUDA_VISIBLE_DEVICES"] = str( - self._cuda_devices[available_device_id] + if gpu > 0: + # Allocate GPUs in a round-robin manner + visible_devices = [] + for _ in range(gpu): + available_device_id = self._gpu_counter % len( + self._cuda_devices + ) + self._gpu_counter += 1 + visible_devices.append(available_device_id) + env_vars["CUDA_VISIBLE_DEVICES"] = ",".join( + str(self._cuda_devices[j]) for j in visible_devices ) - self._gpu_counter += 1 cmd = ( " ".join(str(k) + "=" + str(v) for k, v in env_vars.items()) + " stdbuf -oL " @@ -278,11 +284,11 @@ def wait( assert worker_type in self._job_counter self._job_counter[worker_type] -= 1 if self._job_counter[worker_type] <= 0: - assert worker_type in self._job_with_gpu + assert worker_type in self._job_gpu_cnt assert worker_type in self._job_env_vars assert worker_type in self._job_cmd self._job_counter.pop(worker_type) - self._job_with_gpu.pop(worker_type) + self._job_gpu_cnt.pop(worker_type) self._job_env_vars.pop(worker_type) self._job_cmd.pop(worker_type) diff --git a/realhf/system/push_pull_stream.py b/realhf/system/push_pull_stream.py index e5fdb0705..7e15cf29d 100644 --- a/realhf/system/push_pull_stream.py +++ b/realhf/system/push_pull_stream.py @@ -25,14 +25,19 @@ class ZMQJsonPusher: hwm: High-water mark for outgoing messages (default: 1000) """ - def __init__(self, host: str = "localhost", port: int = 5555, hwm: int = 1000): + def __init__( + self, host: str = "localhost", port: int = 5555, hwm: int = 1000, bind=False + ): self.host = host self.port = port self.ctx = zmq.Context.instance() self.socket = self.ctx.socket(zmq.PUSH) self.socket.setsockopt(zmq.SNDHWM, hwm) - self.socket.connect(f"tcp://{self.host}:{self.port}") + if not bind: + self.socket.connect(f"tcp://{self.host}:{self.port}") + else: + self.socket.bind(f"tcp://{self.host}:{self.port}") def push(self, data: JSONType) -> None: """ @@ -77,6 +82,7 @@ def __init__( port: int = 5555, default_timeout_ms: int = 1000, hwm: int = 1000, + bind: bool = True, ): self.host = host self.port = port @@ -86,7 +92,10 @@ def __init__( self.socket = self.ctx.socket(zmq.PULL) self.socket.setsockopt(zmq.RCVHWM, hwm) self.socket.setsockopt(zmq.RCVTIMEO, self.default_timeout_ms) - self.socket.bind(f"tcp://{self.host}:{self.port}") + if bind: + self.socket.bind(f"tcp://{self.host}:{self.port}") + else: + self.socket.connect(f"tcp://{self.host}:{self.port}") self.poller = zmq.Poller() self.poller.register(self.socket, zmq.POLLIN) diff --git a/realhf/utils.py b/realhf/utils.py index 47009aee1..7a2aa41c4 100644 --- a/realhf/utils.py +++ b/realhf/utils.py @@ -35,6 +35,7 @@ def load_hf_or_local_file(path: str) -> str: => /root/.cache/huggingface/hub/models--inclusionAI--AReaL-RL-Data/data/boba_106k_0319.jsonl """ + path = str(path) if path.startswith("hf://") or path.startswith("hf-dataset://"): repo_type = "dataset" if path.startswith("hf-dataset://") else "model" hf_path = path.strip().split("://")[1] diff --git a/requirements.txt b/requirements.txt index 4ffab7e39..bf7552bbb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -69,4 +69,6 @@ word2number Pebble timeout-decorator prettytable -swanlab[dashboard] \ No newline at end of file +gymnasium>=1.1.1 +torchdata +swanlab[dashboard]