diff --git a/examples/penguin/grpo_dapo17k_bytedtsinghua_qwen3_4binstruct_nf.yaml b/examples/penguin/grpo_dapo17k_bytedtsinghua_qwen3_4binstruct_nf.yaml new file mode 100644 index 0000000000..9fba418322 --- /dev/null +++ b/examples/penguin/grpo_dapo17k_bytedtsinghua_qwen3_4binstruct_nf.yaml @@ -0,0 +1,273 @@ +grpo: + max_num_epochs: 1 + num_prompts_per_step: 64 + num_generations_per_prompt: 16 + max_rollout_turns: 1 # for multi-turn rollouts. Math Environments just have 1 turn (answering the question) + max_num_steps: 1000000 + normalize_rewards: true + use_leave_one_out_baseline: true + val_period: 10 + val_at_start: true + overlong_filtering: false + max_val_samples: null # inferred from size of val dataset. for multi evals, repeat val ds via `num_repeats` in `ng_prepare_data`. + val_batch_size: null + seed: 42 + use_dynamic_sampling: false + dynamic_sampling_max_gen_batches: 10 + batch_multiplier: 1 + reward_shaping: + enabled: false + overlong_buffer_length: 128 + overlong_buffer_penalty: 1 + max_response_length: ${policy.max_total_sequence_length} + reward_scaling: + enabled: false + source_min: 0.0 + source_max: 1.0 + target_min: 0.0 + target_max: 1.0 + skip_reference_policy_logprobs_calculation: true + +loss_fn: + reference_policy_kl_penalty: 0 + reference_policy_kl_type: "k3" + kl_input_clamp_value: 20.0 + kl_output_clamp_value: 10.0 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 + ratio_clip_c: null + # (default off) loss formulation improvements (docs/guides/grpo.md#loss) + use_on_policy_kl_approximation: false + truncated_importance_sampling_ratio: null + use_importance_sampling_correction: false + token_level_loss: true + +checkpointing: + enabled: true + checkpoint_dir: "results/grpo" + metric_name: "val:accuracy" + higher_is_better: true + keep_top_k: 3 + save_period: 1 + checkpoint_must_save_by: null + +policy: + model_name: "Qwen/Qwen3-4B-Instruct-2507" + tokenizer: + name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default + chat_template_kwargs: null # can be used to pass kwargs to the chat template, e.g., enable_thinking=true + hf_config_overrides: {} + train_global_batch_size: ${mul:${grpo.num_prompts_per_step}, ${grpo.num_generations_per_prompt}} # Match the total rollouts per step + train_micro_batch_size: 1 + logprob_batch_size: 1 + generation_batch_size: 32 # Only used when generating using HF backend + max_total_sequence_length: 32768 + precision: "bfloat16" + logprob_chunk_size: 1024 + + dtensor_cfg: + _v2: false + enabled: true + cpu_offload: False + sequence_parallel: false + activation_checkpointing: true + tensor_parallel_size: 2 + context_parallel_size: 1 + custom_parallel_plan: null + clear_cache_every_n_steps: null + + megatron_cfg: + enabled: false + # We might want to consider setting this value higher (e.g. to 1) and raising the vllm generation max mem utilization + empty_unused_memory_level: 0 + activation_checkpointing: true + converter_type: "Qwen2ForCausalLM" # Apparently this is comptible with Qwen 3 dense models. + tensor_model_parallel_size: 1 + expert_tensor_parallel_size: 1 + expert_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + num_layers_in_first_pipeline_stage: null + num_layers_in_last_pipeline_stage: null + context_parallel_size: 1 + pipeline_dtype: ${policy.precision} + sequence_parallel: false + freeze_moe_router: true + moe_router_dtype: "fp64" + moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo + moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo + #gives ~20% training perf speedup with sequence packing + apply_rope_fusion: True + defer_fp32_logits: true + moe_permute_fusion: false + bias_activation_fusion: True + + optimizer: + optimizer: "adam" + lr: 5.0e-6 + min_lr: 5.0e-7 + weight_decay: 0.01 + bf16: true + fp16: false + params_dtype: "float32" + + #adam + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1e-8 + + #sgd + sgd_momentum: 0.9 + + #distributed optimizer + use_distributed_optimizer: true + use_precision_aware_optimizer: true + + # optimizer cpu offload + optimizer_cpu_offload: false + optimizer_offload_fraction: 0.0 + + clip_grad: ${policy.max_grad_norm} + + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: "constant" + lr_decay_style: "constant" + lr_decay_iters: null + lr_warmup_iters: 13 + lr_warmup_init: 5.0e-7 + + distributed_data_parallel_config: + grad_reduce_in_fp32: false + overlap_grad_reduce: true + overlap_param_gather: true + use_custom_fsdp: false + data_parallel_sharding_strategy: "optim_grads_params" + + env_vars: null + + # See docs/design-docs/sequence-packing-and-dynamic-batching.md + # for more details on dynamic batching and sequence packing. + dynamic_batching: + enabled: False + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + sequence_length_round: 64 + + sequence_packing: + enabled: false + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + + # makes the training sequence length divisible by the tensor parallel size + # this is useful for sequence parallel training + make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size} + max_grad_norm: 1.0 + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 1.0e-6 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + # when using Dtensor, we need to set foreach + # and fused to False + foreach: False + fused: False + + scheduler: + - name: "torch.optim.lr_scheduler.ConstantLR" + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: [] + + generation: + backend: "vllm" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: true + precision: ${policy.precision} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + enable_expert_parallel: false + expert_parallel_size: 1 + gpu_memory_utilization: 0.8 + max_model_len: ${policy.max_total_sequence_length} + enforce_eager: false + use_deep_gemm: False + num_last_layers_in_bf16: 0 + num_first_layers_in_bf16: 0 + expose_http_server: true + skip_tokenizer_init: false + http_server_serving_chat_kwargs: + # This is the tool parser for Qwen 3 4B Instruct. This needs to be changed for other models. + enable_auto_tools: true + tool_parser: hermes + # Enable the appropriate reasoning parser here. Since this model is an instruct model, we comment it out. + # reasoning_parser: deepseek_r1 + vllm_kwargs: + compilation_config: + # when enforce_eager is False, set ++policy.generation.vllm_kwargs.compilation_config.use_inductor=False for better accuracy, + # with the flag, vllm will use the custom CUDA kernels instead of the Triton kernels generated by torch.compile + # for more details, see convergence issue https://github.com/NVIDIA-NeMo/RL/issues/998 + use_inductor: False + colocated: + # true: generation shares training GPUs + # false: uses dedicated generation resources + enabled: true + # only relevant when enabled is false + resources: + gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1 + num_nodes: null # Decides number of nodes to be dedicated to generation + +data: + train_jsonl_fpath: 3rdparty/Penguin-workspace/Penguin/data/bytedtsinghua_dapo17k/train.jsonl + validation_jsonl_fpath: 3rdparty/Penguin-workspace/Penguin/data/bytedtsinghua_dapo17k/validation.jsonl + shuffle: true + num_workers: 0 + +env: + should_use_penguin: true + should_log_penguin_responses: true # If you have low logging storage, set this to false + penguin: # This is passed into Penguin as the initial_global_config_dict + config_paths: + - responses_api_models/vllm_model/configs/vllm_model_for_training.yaml # Required! And it must be *for_training + - resources_servers/library_judge_math/configs/library_judge_math.yaml + library_judge_math: + resources_servers: + library_judge_math: + judge_model_server: + name: policy_model + should_use_judge: false + +logger: + log_dir: "logs" # Base directory for all logs + num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal + wandb_enabled: true + tensorboard_enabled: false + mlflow_enabled: false # Disable MLflow logging + swanlab_enabled: false + monitor_gpus: true # If true, will monitor GPU usage and log to wandb and/or tensorboard + wandb: + project: "grpo-dev" + name: "grpo-dev-logger" + tensorboard: {} + mlflow: + experiment_name: "grpo-dev" + run_name: "grpo-dev-logger" + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + +cluster: + gpus_per_node: 8 + num_nodes: 8 diff --git a/examples/penguin/run_grpo_penguin.py b/examples/penguin/run_grpo_penguin.py new file mode 100644 index 0000000000..96d33e9528 --- /dev/null +++ b/examples/penguin/run_grpo_penguin.py @@ -0,0 +1,297 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import os +import pprint +from itertools import chain, repeat +from typing import Optional + +# Increase the W&B single object size warning threshold. Initially 100_000 (100 KB) -> 10_000_000 (10 MB) +import wandb.util + +wandb.util.VALUE_BYTES_LIMIT = 10_000_000 + +import ray +from omegaconf import OmegaConf +from wandb import Table + +from nemo_rl.algorithms.grpo import ( + ColocatablePolicyInterface, + EnvironmentInterface, + GenerationInterface, + Logger, + MasterConfig, + StatefulDataLoader, + TokenizerType, + _should_use_penguin, + grpo_train, + refit_policy_generation, + setup, +) +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.data.datasets import AllTaskProcessedDataset +from nemo_rl.data.interfaces import DatumSpec +from nemo_rl.distributed.ray_actor_environment_registry import ( + get_actor_python_env, +) +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.environments.penguin import ( + Penguin, + PenguinConfig, + penguin_example_to_nemo_rl_datum_spec, + setup_penguin_config, +) +from nemo_rl.experience.rollouts import run_async_penguin_rollout +from nemo_rl.models.generation import configure_generation_config +from nemo_rl.utils.config import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import get_next_experiment_dir + +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) + + +def parse_args() -> tuple[argparse.Namespace, list[str]]: + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Run GRPO training with configuration") + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + + # Parse known args for the script + args, overrides = parser.parse_known_args() + + return args, overrides + + +def setup_single_penguin_dataset( + jsonl_fpath: str, tokenizer, num_repeats: Optional[int] = None +): + with open(jsonl_fpath) as f: + penguin_examples = list(map(json.loads, f)) + + print(f"Loaded data at {jsonl_fpath}. Found {len(penguin_examples)} examples") + + if num_repeats: + previous_length = len(penguin_examples) + penguin_examples = list( + chain.from_iterable( + repeat(penguin_example, num_repeats) + for penguin_example in penguin_examples + ) + ) + print( + f"Repeating examples (in a pattern of abc to aabbcc) for {jsonl_fpath} from {previous_length} to {len(penguin_examples)}!" + ) + + nemo_rl_compatible_examples: list[DatumSpec] = [ + penguin_example_to_nemo_rl_datum_spec(penguin_example, idx) + for idx, penguin_example in enumerate(penguin_examples) + ] + + passthrough_task_processor = lambda datum_dict, *args, **kwargs: datum_dict + return AllTaskProcessedDataset( + nemo_rl_compatible_examples, + tokenizer, + None, + passthrough_task_processor, + ) + + +# These types are directly imported from grpo_train since if something about the architecture changes we want to immediately fail. +def collect_trajectories( + policy: ColocatablePolicyInterface, + policy_generation: GenerationInterface, + val_dataloader: StatefulDataLoader, + tokenizer: TokenizerType, + val_task_to_env: dict[str, EnvironmentInterface], + logger: Logger, + master_config: MasterConfig, +) -> None: + """Run trajectory collection.""" + # common config/state items + colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"] + refit_policy_generation(policy, policy_generation, colocated_inference) + + log_filename = "trajectory_collection.jsonl" + + print("\nšŸ” Running trajectory collection...", flush=True) + generation_config = master_config["policy"]["generation"] + for val_batch in val_dataloader: + penguin_rollout_result = run_async_penguin_rollout( + policy_generation=policy_generation, + input_batch=val_batch, + tokenizer=tokenizer, + task_to_env=val_task_to_env, + max_seq_len=None, + generation_config=generation_config, + max_rollout_turns=None, + greedy=False, + ) + + rows_to_log: list[str] = [] + for key, value in penguin_rollout_result.rollout_metrics.items(): + if "full_result" not in key: + continue + + value: Table + data: list[list[str]] = value.data # (n, 1) + rows_to_log.extend(v[0] for v in data) + + logger.log_string_list_as_jsonl(rows_to_log, log_filename) + + # TODO: eventually as trajectory collection use cases exceed 4 hours, we can leverage the dataloader save functionality to resume + # And also leverage the TimeoutChecker functionality as well + + policy_generation.finish_generation() + + +def main() -> None: + """Main entry point.""" + # Parse arguments + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), + "grpo_dapo17k_bytedtsinghua_qwen3_4binstruct_nf.yaml", + ) + + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + print(f"Overrides: {overrides}") + config = parse_hydra_overrides(config, overrides) + + config: MasterConfig = OmegaConf.to_container(config, resolve=True) + print("Applied CLI overrides") + + # Get the next experiment directory with incremented ID + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + print(f"šŸ“Š Using log directory: {config['logger']['log_dir']}") + if config["checkpointing"]["enabled"]: + print( + f"šŸ“Š Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" + ) + + # setup tokenizer + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + assert config["policy"]["generation"] is not None, ( + "A generation config is required for GRPO" + ) + config["policy"]["generation"] = configure_generation_config( + config["policy"]["generation"], tokenizer + ) + + # Penguin specific config setup. + setup_penguin_config(config, tokenizer) + + # We assert here since this is right after the final config has been materialized. + assert _should_use_penguin(config) + + print("\nā–¶ Setting up data...") + train_dataset = setup_single_penguin_dataset( + jsonl_fpath=config["data"]["train_jsonl_fpath"], + tokenizer=tokenizer, + ) + val_dataset = setup_single_penguin_dataset( + jsonl_fpath=config["data"]["validation_jsonl_fpath"], + tokenizer=tokenizer, + ) + + # Validation dataset config setup. + if config["grpo"]["max_val_samples"] is not None: + raise ValueError( + """A non-null `grpo.max_val_samples` parameter is not supported. + +Gym principle is that there is no hidden data pre or post processing from you. What you see is what you get. + +The validation set you pass in will directly be used for validation with no additional preprocessing. If you want to have some number of repetitions, please include that in your dataset, via ``num_repeats``, in your dataset config and `ng_prepare_data` will prepare it accordingly.""" + ) + + print( + f"Setting `grpo.max_val_samples` and `grpo.val_batch_size` to the length of the validation dataset, which is {len(val_dataset)}" + ) + config["grpo"]["max_val_samples"] = len(val_dataset) + config["grpo"]["val_batch_size"] = config["grpo"]["max_val_samples"] + + # Print config + print("Final config:") + pprint.pprint(config) + + init_ray() + + ( + policy, + policy_generation, + cluster, + dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + grpo_state, + master_config, + ) = setup(config, tokenizer, train_dataset, val_dataset) + + is_trajectory_collection = ( + config["env"]["penguin"].pop("is_trajectory_collection", False) or False + ) + penguin_config = PenguinConfig( + model_name=policy_generation.cfg["model_name"], + base_urls=policy_generation.dp_openai_server_base_urls, + initial_global_config_dict=config["env"]["penguin"], + ) + penguin = Penguin.options( + runtime_env={ + "py_executable": get_actor_python_env( + "nemo_rl.environments.penguin.Penguin" + ), + } + ).remote(penguin_config) + # Blocking wait for penguin to spin up + ray.get(penguin.health_check.remote()) + task_to_env = {"penguin": penguin} + val_task_to_env = task_to_env + + if is_trajectory_collection: + collect_trajectories( + policy=policy, + policy_generation=policy_generation, + val_dataloader=val_dataloader, + tokenizer=tokenizer, + val_task_to_env=val_task_to_env, + logger=logger, + master_config=master_config, + ) + else: + grpo_train( + policy, + policy_generation, + dataloader, + val_dataloader, + tokenizer, + loss_fn, + task_to_env, + val_task_to_env, + logger, + checkpointer, + grpo_state, + master_config, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/penguin/run_penguin_single_node_sanity_tests.sh b/examples/penguin/run_penguin_single_node_sanity_tests.sh new file mode 100755 index 0000000000..1337cf3102 --- /dev/null +++ b/examples/penguin/run_penguin_single_node_sanity_tests.sh @@ -0,0 +1,33 @@ +# Fail on errors +set -e + +uv sync --group={build,docs,dev,test} --extra penguin + +# Stop pesky previous Ray servers that may have not been able to spin down from previous users. +uv run ray stop --force +uv run python -c "import ray; ray.shutdown()" + +# The first time I ran this, it took roughly 5 mins to setup the vLLM deps. +# This took me 2-3 mins to run this one test. +# NeMo RL test. This should pass no matter what the Gym setup is. +./tests/run_unit.sh unit/models/generation/test_vllm_generation.py::test_vllm_generate_text + +# NeMo Gym uses an OpenAI compatible endpoint under the hood. This tests the implementation for this server. +./tests/run_unit.sh unit/models/generation/test_vllm_generation.py::test_vllm_http_server + +# NeMo Gym communicates not using token ids, but in OpenAI schema. There are some edge cases we need to handle (e.g. token merging upon retokenization, multiple most efficient retokenizations, etc). +./tests/run_unit.sh unit/models/generation/test_vllm_generation.py::test_VllmAsyncGenerationWorker_replace_prefix_tokens +./tests/run_unit.sh unit/models/generation/test_vllm_generation.py::test_replace_prefix_tokens_empty_model_prefix_returns_template +./tests/run_unit.sh unit/models/generation/test_vllm_generation.py::test_replace_prefix_tokens_missing_eos_in_template_prefix_raises +./tests/run_unit.sh unit/models/generation/test_vllm_generation.py::test_replace_prefix_tokens_tokenizer_without_eos_raises +./tests/run_unit.sh unit/models/generation/test_vllm_generation.py::test_replace_prefix_tokens_uses_last_eos_in_template_prefix +./tests/run_unit.sh unit/models/generation/test_vllm_generation.py::test_vllm_http_server_correct_merged_tokens_matches_baseline + +# NeMo RL test. This should pass no matter what the Gym setup is. +./tests/run_unit.sh unit/environments/test_math_environment.py::test_math_env_step_basic + +# NeMo Gym integrates directly into NeMo RL as an Environment since that is the cleanest way. This tests the NeMo Gym integration logic and correctness. +./tests/run_unit.sh unit/environments/test_penguin.py::test_penguin_sanity + +# NeMo Gym uses a separate rollout loop inside grpo_train in NeMo RL. This tests the e2e rollout functionality and correctness. +./tests/run_unit.sh unit/experience/test_rollouts.py::test_run_async_penguin_rollout diff --git a/nemo_rl/environments/penguin.py b/nemo_rl/environments/penguin.py index a53c3d89b9..1f7462a866 100644 --- a/nemo_rl/environments/penguin.py +++ b/nemo_rl/environments/penguin.py @@ -16,10 +16,12 @@ import ray import torch +from transformers import PreTrainedTokenizerBase from nemo_rl.data.interfaces import DatumSpec from nemo_rl.distributed.virtual_cluster import _get_free_port_local, _get_node_ip_local from nemo_rl.environments.interfaces import EnvironmentInterface +from nemo_rl.utils.timer import Timer class PenguinConfig(TypedDict): @@ -46,7 +48,9 @@ def __init__(self, cfg: PenguinConfig): RELATIVE_PATH = "nemo_rl/environments/penguin.py" assert __file__.endswith(RELATIVE_PATH) - initial_global_config_dict = self.cfg["initial_global_config_dict"] + initial_global_config_dict = ( + self.cfg.get("initial_global_config_dict") or dict() + ) # Policy information initial_global_config_dict["policy_model_name"] = self.cfg["model_name"] initial_global_config_dict["policy_api_key"] = ( @@ -54,13 +58,13 @@ def __init__(self, cfg: PenguinConfig): ) initial_global_config_dict["policy_base_url"] = self.cfg["base_urls"] - initial_global_config_dict["global_aiohttp_connector_limit_per_host"] = ( - initial_global_config_dict.get("global_aiohttp_connector_limit_per_host") - or 1024 + initial_global_config_dict.setdefault( + "global_aiohttp_connector_limit_per_host", 16_384 ) - initial_global_config_dict["global_aiohttp_connector_limit"] = ( - initial_global_config_dict["global_aiohttp_connector_limit_per_host"] - * len(self.cfg["base_urls"]) + initial_global_config_dict.setdefault("global_aiohttp_connector_limit", 65_536) + print( + f"""Set global_aiohttp_connector_limit_per_host={initial_global_config_dict["global_aiohttp_connector_limit_per_host"]} and global_aiohttp_connector_limit={initial_global_config_dict["global_aiohttp_connector_limit"]}. +Depending on your data shape, you may want to change these values.""" ) # Get Ray head node address if Ray is initialized @@ -73,11 +77,6 @@ def __init__(self, cfg: PenguinConfig): initial_global_config_dict["ray_head_node_address"] = ray_context.gcs_address print(f"Ray head node address: {ray_context.gcs_address}") - print( - f"""Set `global_aiohttp_connector_limit_per_host` to a flat {initial_global_config_dict["global_aiohttp_connector_limit_per_host"]}. -Since there are {len(self.cfg["base_urls"])} data-parallel vLLM worker instances, the `global_aiohttp_connector_limit` has been set to {len(self.cfg["base_urls"])} * {initial_global_config_dict["global_aiohttp_connector_limit_per_host"]} = {initial_global_config_dict["global_aiohttp_connector_limit"]}.""" - ) - # Head server initial_global_config_dict[HEAD_SERVER_KEY_NAME] = { "host": "0.0.0.0", @@ -104,17 +103,43 @@ def __init__(self, cfg: PenguinConfig): def health_check(self) -> bool: return True - async def run_rollouts(self, penguin_examples: list[dict]) -> list[dict]: - penguin_results = await self.rch.run_examples( + async def run_rollouts( + self, + penguin_examples: list[dict], + tokenizer: PreTrainedTokenizerBase, + timer_prefix: str, + ) -> list[dict]: + timer = Timer() + + penguin_result_iterator = self.rch.run_examples( examples=penguin_examples, head_server_config=self.head_server_config ) - nemo_rl_results = list( - map(self._postprocess_penguin_to_nemo_rl_result, penguin_results) + timer.start("_run_rollouts_total") + nemo_rl_results = [] + for task in penguin_result_iterator: + with timer.time(label=f"{timer_prefix}/await_results"): + penguin_result = await task + + with timer.time(label=f"{timer_prefix}/postprocess_results"): + nemo_rl_result = self._postprocess_penguin_to_nemo_rl_result( + penguin_result, tokenizer + ) + + nemo_rl_results.append(nemo_rl_result) + + timer.stop("_run_rollouts_total") + timing_metrics = timer.get_timing_metrics("sum") + total_time = timing_metrics.pop("_run_rollouts_total") + timing_metrics[f"{timer_prefix}/postprocess_results_pct"] = ( + 100 * timing_metrics[f"{timer_prefix}/postprocess_results"] / total_time ) - return nemo_rl_results - def _postprocess_penguin_to_nemo_rl_result(self, penguin_result: dict) -> dict: + return nemo_rl_results, timing_metrics + + def _postprocess_penguin_to_nemo_rl_result( + self, penguin_result: dict, tokenizer: PreTrainedTokenizerBase + ) -> dict: nemo_rl_message_log = [] seen_token_ids: List[int] = [] for output_item_dict in penguin_result["response"]["output"]: @@ -138,23 +163,34 @@ def _postprocess_penguin_to_nemo_rl_result(self, penguin_result: dict) -> dict: { "role": "user", "content": "", - "token_ids": output_item_dict["prompt_token_ids"][ - len(seen_token_ids) : - ], + "token_ids": torch.tensor( + output_item_dict["prompt_token_ids"][len(seen_token_ids) :] + ), } ) nemo_rl_message_log.append( { "role": "assistant", "content": "", - "token_ids": output_item_dict["generation_token_ids"], - "generation_logprobs": output_item_dict["generation_log_probs"], + "token_ids": torch.tensor(output_item_dict["generation_token_ids"]), + "generation_logprobs": torch.tensor( + output_item_dict["generation_log_probs"] + ), } ) seen_token_ids.extend(nemo_rl_message_log[-2]["token_ids"]) seen_token_ids.extend(nemo_rl_message_log[-1]["token_ids"]) + # We pop to remove larger tensors from logging. + output_item_dict["prompt_str"] = tokenizer.decode( + output_item_dict.pop("prompt_token_ids") + ) + output_item_dict["generation_str"] = tokenizer.decode( + output_item_dict.pop("generation_token_ids") + ) + output_item_dict.pop("generation_log_probs") + return { "message_log": nemo_rl_message_log, "input_message_log": nemo_rl_message_log[:1], diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py index ba0ba9ce95..b8b378542c 100644 --- a/nemo_rl/experience/rollouts.py +++ b/nemo_rl/experience/rollouts.py @@ -48,6 +48,7 @@ GenerationInterface, GenerationOutputSpec, ) +from nemo_rl.utils.timer import Timer TokenizerType = PreTrainedTokenizerBase @@ -934,14 +935,6 @@ async def run_single_sample_with_error_handling(i, sample_state): return asyncio.run(_async_rollout_implementation()) -def _tensorize_by_key(message_logs: list, key: str): - if not message_logs or key not in message_logs[0]: - return - - for m in message_logs: - m[key] = torch.tensor(m[key]) - - @dataclass class AsyncPenguinRolloutResult: input_ids: torch.Tensor @@ -995,6 +988,10 @@ def run_async_penguin_rollout( "Top k is not supported in the generation config in Penguin path!" ) + timer = Timer() + timer_prefix = "timing/rollout" + timer.start(f"{timer_prefix}/total") + for row in penguin_rows: # We may need better handling here. The max tokens set here would be the max new generated tokens, not the total max tokens. # Currently, we just rely on the underlying vLLM engine to do the truncation for us using the max model seq len set in the config. @@ -1007,116 +1004,109 @@ def run_async_penguin_rollout( # Max new tokens, just like max_seq_len above is ignored and we rely on the underlying vLLM engine for truncation. # generation_config["max_new_tokens"] - penguin_environment = task_to_env["penguin"] - results = ray.get(penguin_environment.run_rollouts.remote(penguin_rows)) - - # Tensorize all token ids - for r in results: - _tensorize_by_key(r["input_message_log"], "token_ids") - _tensorize_by_key(r["message_log"], "token_ids") - _tensorize_by_key( - [m for m in r["message_log"] if m["role"] == "assistant"], - "generation_logprobs", + with timer.time(f"{timer_prefix}/run_rollouts"): + penguin_environment = task_to_env["penguin"] + results, rollout_loop_timing_metrics = ray.get( + penguin_environment.run_rollouts.remote( + penguin_rows, tokenizer, timer_prefix + ) ) # Prepare for the rollout metrics calculation below. Not strictly necessary here, but good to have parity with `run_async_multi_turn_rollout` - batch_size = len(penguin_rows) - max_total_tokens_per_sample = policy_generation.cfg["vllm_cfg"]["max_model_len"] - all_sample_metrics = [ - { - "total_reward": r["full_result"]["reward"], - "assistant_tokens": sum( - len(m["token_ids"]) - for m in r["message_log"] - if m["role"] == "assistant" - ), - "total_tokens": sum(len(m["token_ids"]) for m in r["message_log"]), - "turn_count": sum(1 for m in r["message_log"] if m["role"] == "user"), - "hit_max_tokens": sum(len(m["token_ids"]) for m in r["message_log"]) - == max_total_tokens_per_sample, - } - for r in results - ] + with timer.time(f"{timer_prefix}/prepare_for_metrics_calculation"): + batch_size = len(penguin_rows) + max_total_tokens_per_sample = policy_generation.cfg["vllm_cfg"]["max_model_len"] + all_sample_metrics = [ + { + "total_reward": r["full_result"]["reward"], + "assistant_tokens": sum( + len(m["token_ids"]) + for m in r["message_log"] + if m["role"] == "assistant" + ), + "total_tokens": sum(len(m["token_ids"]) for m in r["message_log"]), + "turn_count": sum(1 for m in r["message_log"] if m["role"] == "user"), + "hit_max_tokens": sum(len(m["token_ids"]) for m in r["message_log"]) + == max_total_tokens_per_sample, + } + for r in results + ] # Aggregate metrics across all samples - rollout_metrics = { - **_calculate_single_metric( - [m["turn_count"] for m in all_sample_metrics], - batch_size, - "turns_per_sample", - ), - **_calculate_single_metric( - [m["total_tokens"] for m in all_sample_metrics], - batch_size, - "total_tokens_per_sample", - ), - **_calculate_single_metric( - [m["assistant_tokens"] for m in all_sample_metrics], - batch_size, - "gen_tokens_per_sample", - ), - **_calculate_single_metric( - [m["total_reward"] for m in all_sample_metrics], batch_size, "total_reward" - ), - "natural_termination_rate": sum( - not m["hit_max_tokens"] for m in all_sample_metrics - ) - / batch_size, - "truncation_rate": sum(m["hit_max_tokens"] for m in all_sample_metrics) - / batch_size, - # TODO enable this metric. We don't have a clear handle on which tokens are user or tool role. - # We would probably need to re-tokenize the messages post-hoc to kind of figure this out. - # "mean_env_tokens_per_sample": sum( - # m["env_tokens"] for m in all_sample_metrics - # ) - # / batch_size, - } + with timer.time(f"{timer_prefix}/aggregate_metrics"): + rollout_metrics = { + **rollout_loop_timing_metrics, + **_calculate_single_metric( + [m["turn_count"] for m in all_sample_metrics], + batch_size, + "turns_per_sample", + ), + **_calculate_single_metric( + [m["total_tokens"] for m in all_sample_metrics], + batch_size, + "total_tokens_per_sample", + ), + **_calculate_single_metric( + [m["assistant_tokens"] for m in all_sample_metrics], + batch_size, + "gen_tokens_per_sample", + ), + **_calculate_single_metric( + [m["total_reward"] for m in all_sample_metrics], + batch_size, + "total_reward", + ), + "natural_termination_rate": sum( + not m["hit_max_tokens"] for m in all_sample_metrics + ) + / batch_size, + "truncation_rate": sum(m["hit_max_tokens"] for m in all_sample_metrics) + / batch_size, + # TODO enable this metric. We don't have a clear handle on which tokens are user or tool role. + # We would probably need to re-tokenize the messages post-hoc to kind of figure this out. + # "mean_env_tokens_per_sample": sum( + # m["env_tokens"] for m in all_sample_metrics + # ) + # / batch_size, + } # Per-agent misc metrics - agent_to_results: dict[str, list[dict]] = defaultdict(list) - for penguin_row, result in zip(penguin_rows, results): - agent_name = penguin_row["agent_ref"]["name"] - agent_to_results[agent_name].append(result["full_result"]) - - per_agent_metrics = {} - for agent_name, agent_results in agent_to_results.items(): - keys = agent_results[0].keys() - for key in keys: - values = [] - for r in agent_results: - if isinstance(r.get(key), (bool, int, float)): - values.append(float(r[key])) - - if values: - per_agent_metrics.update( - _calculate_single_metric( - values, len(agent_results), f"{agent_name}/{key}" + with timer.time(f"{timer_prefix}/per_agent_misc_metrics"): + agent_to_results: dict[str, list[dict]] = defaultdict(list) + for penguin_row, result in zip(penguin_rows, results): + agent_name = penguin_row["agent_ref"]["name"] + agent_to_results[agent_name].append(result["full_result"]) + + per_agent_metrics = {} + for agent_name, agent_results in agent_to_results.items(): + keys = agent_results[0].keys() + for key in keys: + values = [ + float(r[key]) + for r in agent_results + if isinstance(r.get(key), (bool, int, float)) + ] + if values: + per_agent_metrics.update( + _calculate_single_metric( + values, len(agent_results), f"{agent_name}/{key}" + ) ) - ) - # Log the full result - to_log = [] - for r in agent_results: - r = copy.deepcopy(r) - # Remove tokens from logging - for output_item in r["response"]["output"]: - output_item.pop("prompt_token_ids", None) - output_item.pop("generation_token_ids", None) - output_item.pop("generation_log_probs", None) - - r = json.dumps(r, separators=((",", ":"))) - to_log.append([r]) - - per_agent_metrics[f"{agent_name}/full_result"] = Table( - data=to_log, columns=["Full result"] - ) + # Log the full result + to_log = [[json.dumps(r, separators=((",", ":")))] for r in agent_results] + per_agent_metrics[f"{agent_name}/full_result"] = Table( + data=to_log, columns=["Full result"] + ) - rollout_metrics.update(per_agent_metrics) + rollout_metrics.update(per_agent_metrics) # Necessary for downstream nemo rl logging/printing. rollout_metrics["mean_gen_tokens_per_sample"] = rollout_metrics[ "gen_tokens_per_sample/mean" ] + timer.stop(f"{timer_prefix}/total") + rollout_metrics.update(timer.get_timing_metrics("sum")) # Convert LLMMessageLogType to FlatMessagesType for generation input_batch_for_input_ids = BatchedDataDict[DatumSpec]( diff --git a/tests/unit/environments/test_penguin.py b/tests/unit/environments/test_penguin.py index 7f4afa0958..abc238d145 100644 --- a/tests/unit/environments/test_penguin.py +++ b/tests/unit/environments/test_penguin.py @@ -142,7 +142,12 @@ def penguin_sanity_test_data(): not PENGUIN_INSTALLED, reason="Skipping Penguin test since Penguin is not installed!", ) -def test_penguin_sanity(penguin, penguin_sanity_test_data, penguin_vllm_generation): +def test_penguin_sanity( + penguin, + penguin_sanity_test_data, + penguin_vllm_generation, + penguin_tokenizer, # noqa: F811 +): """Test basic functionality of MathEnvironment step with simple messages.""" # We need to match NeMo RL generation config params before sending to Penguin @@ -154,11 +159,21 @@ def test_penguin_sanity(penguin, penguin_sanity_test_data, penguin_vllm_generati ] example["responses_create_params"]["top_p"] = generation_config["top_p"] - actual_result = ray.get( - penguin.run_rollouts.remote(penguin_sanity_test_data["input"]) + actual_result, _ = ray.get( + penguin.run_rollouts.remote( + penguin_sanity_test_data["input"], penguin_tokenizer, "" + ) ) expected_result = penguin_sanity_test_data["expected_output"] + # These are tensors originally and we swap them back to a list for comparison below + for d in actual_result: + for message in d["input_message_log"]: + message["token_ids"] = message["token_ids"].tolist() + # Right now, we don't need to swap the token ids in the message log since they pointto the same underlying dictionary as above. + # for message in d["message_log"][:1]: + # message["token_ids"] = message["token_ids"].tolist() + def _standardize_single_result(d: dict): d = deepcopy(d) d.pop("full_result", None) @@ -170,6 +185,10 @@ def _standardize_single_result(d: dict): message["token_ids"] = [] if "generation_logprobs" in message: message["generation_logprobs"] = [] + if "prompt_str" in message: + message["prompt_str"] = "dummy prompt_str" + if "generation_str" in message: + message["generation_str"] = "dummy generation_str" return d diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py index 0aa44e68cf..da515ba10f 100644 --- a/tests/unit/experience/test_rollouts.py +++ b/tests/unit/experience/test_rollouts.py @@ -782,6 +782,14 @@ def test_run_async_penguin_rollout( }, "rollout_metrics": { # core metrics + "timing/rollout/total": 0.0, + "timing/rollout/run_rollouts": 0.0, + "timing/rollout/await_results": 0.0, + "timing/rollout/postprocess_results": 0.0, + "timing/rollout/postprocess_results_pct": 0.0, + "timing/rollout/prepare_for_metrics_calculation": 0.0, + "timing/rollout/aggregate_metrics": 0.0, + "timing/rollout/per_agent_misc_metrics": 0.0, "mean_gen_tokens_per_sample": None, "turns_per_sample/mean": 2.0, "turns_per_sample/max": 2,