diff --git a/.gitignore b/.gitignore index 2628699139..38276ac019 100644 --- a/.gitignore +++ b/.gitignore @@ -1,156 +1,29 @@ -# log and data files -*.model -*.pkl -*.pt -.DS_Store -.hydra -.bash_history.local +# Adding to .gitignore helps reduce the size of your working_dir -# Byte-compiled / optimized / DLL files +.git +*.out +*.log +*.tar +*.tar.gz +.venv +venv __pycache__/ -*.py[cod] -*$py.class -**.pyc - -# C extensions -*.so - -# Distribution / packaging -.idea -.Python -wandb +_build/ build/ -develop-eggs/ +apidocs/ dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -#parts/ -sdist/ -var/ -wheels/ -pip-wheel-metadata/ -share/python-wheels/ *.egg-info/ -.installed.cfg -*.egg -MANIFEST +*.vscode/ -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ +# Test .coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/build - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# Override Jupyter in Github Language states for more accurate estimate of repo code. -# Reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code -*.ipynb linguist-generated - -# IPython -profile_default/ -ipython_config.py - -# pyenv -.python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don’t work, or not -# install all needed dependencies. -#Pipfile.lock - -# celery beat schedule file -celerybeat-schedule - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -venv/ -env.bak/ -venv.bak/ - -# VSCode project settins -.vscode/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site -/docs/html -/docs/docs_zh/zh - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# Emacs backup files -*~ - -*.tar.gz - -# Test data. -tests/.data -tests/data - -# outputs folder -wandb -# Checkpoints, config files and temporary files created in tutorials. -.hydra/ +# Cache +uv_cache/ +hf_home/ +*logs/ +datasets/ +docker/ +wandb/ +checkpoints/ +results/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000..ba48b680ae --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,141 @@ +# Contributing To Nemo-Reinforcer + +Thanks for your interest in contributing to Nemo-Reinforcer! + +## Setting Up + +### Development Environment + +1. **Build and run the Docker container**: +```bash +docker buildx build -t nemo-reinforcer -f Dockerfile . +# Run the container with your local nemo-reinforcer directory mounted +docker run -it --gpus all -v /path/to/nemo-reinforcer:/workspace/nemo-reinforcer nemo-reinforcer +``` + +2. **Install the package in development mode**: +```bash +cd /workspace/nemo-reinforcer +pip install -e . +``` + +## Making Changes + +### Workflow: Clone and Branch (No Fork Required) + +#### Before You Start: Install pre-commit + +From the [`nemo-reinforcer` root directory](.), run: +```bash +python3 -m pip install pre-commit +pre-commit install +``` + +Pre-commit checks (using `ruff`) will help ensure your code follows our formatting and style guidelines. + +We follow a direct clone and branch workflow for now: + +1. Clone the repository directly: + ```bash + git clone https://github.com/NVIDIA/nemo__placeholder + cd nemo-reinforcer + ``` + +2. Create a new branch for your changes: + ```bash + git checkout -b your-feature-name + ``` + +3. Make your changes and commit them: + ```bash + git add . + git commit --signoff -m "Your descriptive commit message" + ``` + +We require signing commits with `--signoff` (or `-s` for short). See [Signing Your Work](#signing-your-work) for details. + +4. Push your branch to the repository: + ```bash + git push origin feature/your-feature-name + ``` + +5. Create a pull request from your branch to the `main` branch. + +### Design Documentation Requirement + +**Important**: All new key features (ex: enabling a new parallelization technique, enabling a new RL algorithm) must include documentation update (either a new doc or updating an existing one). This document update should: + +- Explain the motivation and purpose of the feature +- Outline the technical approach and architecture +- Provide clear usage examples and instructions for users +- Document internal implementation details where appropriate + +This ensures that all significant changes are well-thought-out and properly documented for future reference. Comprehensive documentation serves two critical purposes: + +1. **User Adoption**: Helps users understand how to effectively use the library's features in their projects +2. **Developer Extensibility**: Enables developers to understand the internal architecture and implementation details, making it easier to modify, extend, or adapt the code for their specific use cases + +Quality documentation is essential for both the usability of Nemo-Reinforcer and its ability to be customized by the community. + +## Code Quality + +- Follow the existing code style and conventions +- Write tests for new features +- Update documentation to reflect your changes +- Ensure all tests pass before submitting a PR +- Do not add arbitrary defaults for configs, be as explicit as possible. + + +## Signing Your Work + +* We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. + + * Any contribution which contains commits that are not Signed-Off will not be accepted. + +* To sign off on a commit you simply use the `--signoff` (or `-s`) option when committing your changes: + ```bash + $ git commit -s -m "Add cool feature." + ``` + This will append the following to your commit message: + ``` + Signed-off-by: Your Name + ``` + +* Full text of the DCO: + + ``` + Developer Certificate of Origin + Version 1.1 + + Copyright (C) 2004, 2006 The Linux Foundation and its contributors. + + Everyone is permitted to copy and distribute verbatim copies of this + license document, but changing it is not allowed. + + + Developer's Certificate of Origin 1.1 + + By making a contribution to this project, I certify that: + + (a) The contribution was created in whole or in part by me and I + have the right to submit it under the open source license + indicated in the file; or + + (b) The contribution is based upon previous work that, to the best + of my knowledge, is covered under an appropriate open source + license and I have the right under that license to submit that + work with modifications, whether created in whole or in part + by me, under the same open source license (unless I am + permitted to submit under a different license), as indicated + in the file; or + + (c) The contribution was provided directly to me by some other + person who certified (a), (b) or (c) and I have not modified + it. + + (d) I understand and agree that this project and the contribution + are public and that a record of the contribution (including all + personal information I submit with it, including my sign-off) is + maintained indefinitely and may be redistributed consistent with + this project or the open source license(s) involved. + ``` diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000..261eeb9e9f --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 7ce63a5cf4..4668a90ede 100644 --- a/README.md +++ b/README.md @@ -1 +1,55 @@ -# _placeholder +# Nemo-Reinforcer: A Scalable and Efficient Post-Training Library for Models Ranging from 1 GPU to 1000s, and from Tiny to >100B Parameters + + +- [Nemo-Reinforcer: A Scalable and Efficient Post-Training Library for Models Ranging from 1 GPU to 1000s, and from Tiny to \>100B Parameters](#nemo-reinforcer-a-scalable-and-efficient-post-training-library-for-models-ranging-from-1-gpu-to-1000s-and-from-tiny-to-100b-parameters) + - [Features](#features) + - [Installation](#installation) + - [Cluster Start](#cluster-start) + +**Nemo-Reinforcer** is a scalable and efficient post-training library designed for models ranging from 1 GPU to thousands, and from tiny to over 100 billion parameters. + +What you can expect: + +- **Seamless integration with HuggingFace** for ease of use, allowing users to leverage a wide range of pre-trained models and tools. +- **High-performance implementation with Megatron core**, supporting various parallelism techniques for large models (>100B) and large context lengths. +- **Efficient resource management using Ray**, enabling scalable and flexible deployment across different hardware configurations. +- **Flexibility** with a modular design that allows easy integration and customization. +- **Comprehensive documentation** that is both detailed and user-friendly, with practical examples. + +## Features + +_βœ… Available now | πŸ”œ Coming in v0.2_ + +- βœ… **Fast Generation** - vLLM backend for optimized inference +- βœ… **HuggingFace Integration** - Works with 1-8B models (Qwen1.5, Llama) +- βœ… **Distributed Training** - FSDP support and Ray-based infrastructure +- βœ… **Environment Support** - Support for multi-environment training. +- βœ… **Learning Algorithms** - GRPO (Group Relative Policy Optimization) and SFT (Supervised Fine-Tuning) +- βœ… **Worker Isolation** - Process isolation between RL Actors (no worries about global state) +- πŸ”œ **Larger Model Support** - Native PyTorch support for models up to 70B parameters +- πŸ”œ **Advanced Parallelism** - FSDP2, TP, SP, and sequence packing for efficient training +- πŸ”œ **Environment Isolation** - Dependency isolation between components +- πŸ”œ **DPO Algorithm** - Direct Preference Optimization for alignment + +## Installation + +```sh +# For faster setup we use `uv` +pip install uv + +# Specify a virtual env that uses Python 3.12 +uv venv -p python3.12.9 .venv +# Install NeMo-Reinforcer with vllm +uv pip install -e . +# Install NeMo-Reinforcer with dev/test dependencies +uv pip install -e '.[dev,test]' + +# Use uv run to launch any runs. +# Note that it is recommended to not activate the venv and instead use `uv run` since +# it ensures consistent environment usage across different shells and sessions. +uv run python examples/run_grpo.py +``` + +## Cluster Start + +Please visit [Cluster Start](docs/cluster.md) for how to get started on Slurm or Kubernetes. diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000000..9031b0a9be --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,7 @@ +ARG BASE_IMAGE=anyscale/ray:2.43.0-py312-cu125 +FROM ${BASE_IMAGE} + +RUN sudo apt-get update && sudo apt-get install -y jq + +RUN pip install --no-cache-dir uv +RUN echo "unset RAY_RUNTIME_ENV_HOOK" >> /home/ray/.bashrc diff --git a/docs/adding_new_models.md b/docs/adding_new_models.md new file mode 100644 index 0000000000..e80fbb1a79 --- /dev/null +++ b/docs/adding_new_models.md @@ -0,0 +1,121 @@ +# Adding New Models + +This guide outlines how to integrate and validate a new model within **NeMo-Reinforcer**. Each new model must pass a standard set of compatibility tests before being considered ready to be used in RL pipelines. + +## Importance of Log Probability Consistency in Training and Inference + +In on-policy RL, we sample tokens (actions) from the latest version of the policy, meaning the sampling distribution of token probabilities produced by the inference framework must closely match those from the training framework. If the inference framework produces significantly different probabilities, we effectively sample from a different distribution, leading to errors in the loss estimation. + +As an example, we would see errors in naive KL estimation: + +$$\text{KL} = \mathbb{E}_{x \sim \pi}[\pi(x) - \pi_{\text{ref}}(x)]$$ + +When summed/integrated, replacing the $x \sim \pi$ with $x \sim \pi_{\text{wrong}}$ leads to an error of: + +$$\sum_{x} \left( \pi(x) - \pi_{\text{ref}}(x) \right) \left( \pi_{\text{wrong}}(x) - \pi(x) \right)$$ + +So, to verify correctness, we calculate + +$$ +\frac{1}{n}\sum_{i=1}^{n}\exp\left(\left\|\text{lp_train_fwk}_i - \text{lp_infer_fwk}_i\right\|\right) +$$ + +as a measure of multiplicative probability error for sampled tokens. Note that this is not exhaustive (the sampling framework could lack distribution support and we wouldn't catch it here) + +## Understanding Discrepancies Between Backends + +When validating models across different backends, you may encounter discrepancies in log probabilities. These differences can stem from various sources with effects ranging from negligible to significant: + +- **Numerical precision differences**: Training and inference backends may differ in precision formats (FP32, FP16, BF16, FP8). + - Training may use mixed precision while the inference backend may not + - High-precision training with FP8 inference may not be numerically stable for certain models + - Differences can occur at the layer level, with some layers in FP32 while others use lower precision + +- **Implementation variations**: Subtle differences in how layer implementations like softmax, layer normalization, or attention mechanisms are implemented. + - Attention/Norm layers (which could be fused) in TransformerEngine may not be bit-wise identical to implementations in inference backends + - Inference backends may re-implement kernels (e.g., for SSM layers) leading to differences + - Softmax in training frameworks may be calculated differently than in inference backends for numerical stability + +- **KV/Prefill cache handling**: Differences in how key-value/prefill caches are managed during autoregressive generation. + - In some cases, disabling the inference backend cache can resolve discrepancies + +- **Parallelism effects**: Parallelisms like Tensor parallelism may introduce small variations + +- **Inherent non-determinism**: Some neural network operations are inherently non-deterministic (e.g., `torch.cumsum`) + +- **Prefill/Decoding kernel mismatch**: Different kernels for prefill and decoding phases may produce different log probabilities. + - Training frameworks typically use prefill kernels, while inference backends may use both prefill kernels and specialized decoding kernels + +- **Imperfect Refit**: Weight conversion from the training framework to the inference backend may be incomplete or data formats may be incorrect + - If weights are reshaped or reordered incorrectly, generations tend to be very wrong + - In some cases, if some weights in the inference backend are not refit after each training step, the error between training and inference log probabilities can diverge as training progresses + +- **Batch size**: In some cases, `batch_size>1` may produce larger errors than `batch_size=1` + +When investigating discrepancies beyond the acceptable threshold, focus on these areas and determine whether the differences appear systematically or only in specific contexts. + + +--- + +## 1. Hugging Face–Based Models + +### Validation Workflow + +When validating Hugging Face-based models, perform the following checks: + +- **Compare log probabilities** + Ensure the generation log probabilities from inference backends like **vLLM** match those computed by HuggingFace. This comparison helps diagnose potential mismatches. + +- **Test parallelism** + Verify consistency with other parallelism settings. + +- **Variance** + Repeat tests multiple times (e.g., 10 runs) to confirm that behavior is deterministic or within acceptable variance. + +- **Check sequence lengths** + Perform inference on sequence lengths of 100, 1,000, and 10,000 tokens. + Ensure the model behaves consistently at each length. + +- **Use real and dummy data** + - **Real data:** Tokenize and generate from actual text samples. + - **Dummy data:** Simple numeric sequences to test basic generation. + +- **Vary sampling parameters** + Test both greedy and sampling generation modes. + Adjust temperature and top-p to confirm output consistency across backends. + +- **Test different batch sizes** + Try with batch sizes of 1, 8, and 32 to ensure consistent behavior across different batch configurations. + +--- + +## 2. Megatron Models + +### Additional Validation + +- **Compare Megatron outputs** + Ensure the Megatron forward pass aligns with HuggingFace and the generation log probabilities from inference backends like **vLLM**. + +- **Parallel settings** + Match the same parallelism configurations used for the HuggingFace-based tests. + Confirm outputs remain consistent across repeated runs. + +--- + +## 3. Expected Error Threshold + +When comparing log probabilities between training and inference backends, we use an error threshold of `1.05` to determine acceptable variance (for equal precision). An error of `1.0` indicates a perfect match, and values exceeding `1.05` require further investigation. + +When validating your model, you should analyze the results across different configurations. Your analysis should include: + +| Sequence Length | Data Type | Generation Method | Batch Size | HF vs VLLM | Megatron vs VLLM | +|-----------------|------------|-------------------|------------|------------|------------------| +| 100 | Real | Greedy | 1 | 1.02 | 1.01 | +| 100 | Real | Sampling | 8 | 1.03 | 1.02 | +| 100 | Synthetic | Greedy | 1 | 1.01 | 1.02 | +| 1,000 | Real | Greedy | 32 | 1.04 | 1.03 | +| ... | ... | ... | ... | ... | ... | + +--- + +By following these validation steps and ensuring your model's outputs remain consistent across backends, you can confirm that your new model meets **NeMo-Reinforcer**'s requirements. diff --git a/docs/assets/val-log.png b/docs/assets/val-log.png new file mode 100644 index 0000000000..bda6618b8c Binary files /dev/null and b/docs/assets/val-log.png differ diff --git a/docs/autodoc2_docstrings_parser.py b/docs/autodoc2_docstrings_parser.py new file mode 100644 index 0000000000..d550c3dedb --- /dev/null +++ b/docs/autodoc2_docstrings_parser.py @@ -0,0 +1,30 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from docutils import nodes +from myst_parser.parsers.sphinx_ import MystParser +from sphinx.ext.napoleon.docstring import GoogleDocstring + + +class NapoleonParser(MystParser): + def parse(self, input_string: str, document: nodes.document) -> None: + # Get the Sphinx configuration + config = document.settings.env.config + + # Process with Google style + google_parsed = str(GoogleDocstring(input_string, config)) + + return super().parse(google_parsed, document) + + +Parser = NapoleonParser diff --git a/docs/cluster.md b/docs/cluster.md new file mode 100644 index 0000000000..5d2f5cb4b8 --- /dev/null +++ b/docs/cluster.md @@ -0,0 +1,91 @@ +# Cluster start + +- [Cluster start](#cluster-start) + - [Slurm](#slurm) + - [Batched Job Submission](#batched-job-submission) + - [Interactive Launching](#interactive-launching) + - [Kubernetes](#kubernetes) + +## Slurm + +:::{tip} +It is important to set `UV_CACHE_DIR` to a directory that can be read from all workers before +running any `uv run` command. This ensures a fast startup time since all workers can re-use the same cache. + +```sh +export UV_CACHE_DIR=/path/that/all/workers/can/access/uv_cache +``` +::: + +### Batched Job Submission + +```sh +# Run from the root of NeMo-Reinforcer repo +NUM_ACTOR_NODES=1 # Total nodes requested are $NUM_ACTOR_NODES + 1 (+1 for head node) + +COMMAND="bash -c 'uv pip install -e .; uv run ./examples/run_grpo.py'" \ +RAY_DEDUP_LOGS=0 \ +UV_CACHE_DIR=YOUR_UV_CACHE_DIR \ +CONTAINER=YOUR_CONTAINER \ +MOUNTS="$PWD:$PWD" \ +sbatch \ + --nodes=$((NUM_ACTOR_NODES + 1)) \ + --account=YOUR_ACCOUNT \ + --job-name=YOUR_JOBNAME \ + --partition=YOUR_PARTITION \ + --time=1:0:0 \ + --gres=gpu:8 \ + ray.sub +``` + +Notes: +* Some clusters may or may not need `--gres=gpu:8` to be added to the `sbatch` command. +* Setting `UV_CACHE_DIR` to a shared directory accessible by all worker nodes is critical for performance. Without this, the `uv` package manager will need to synchronize dependencies separately for each worker, which can significantly increase startup times and create unnecessary network traffic. + +Which will print the `SLURM_JOB_ID`: +```text +Submitted batch job 1980204 +``` +Make note of the the job submission number. Once the job begins you can track it's process in the driver logs which you can `tail`: +```sh +tail -f 1980204-logs/ray-driver.log +``` + +### Interactive Launching +To run interactively, launch the same command as the [Batched Job Submission](#batched-job-submission) except omit the `COMMAND` line: +```sh +# Run from the root of NeMo-Reinforcer repo +NUM_ACTOR_NODES=1 # Total nodes requested are $NUM_ACTOR_NODES + 1 (+1 for head node) + +RAY_DEDUP_LOGS=0 \ +UV_CACHE_DIR=YOUR_UV_CACHE_DIR \ +CONTAINER=YOUR_CONTAINER \ +MOUNTS="$PWD:$PWD" \ +sbatch \ + --nodes=$((NUM_ACTOR_NODES + 1)) \ + --account=YOUR_ACCOUNT \ + --job-name=YOUR_JOBNAME \ + --partition=YOUR_PARTITION \ + --time=1:0:0 \ + --gres=gpu:8 \ + ray.sub +``` +Which will print the `SLURM_JOB_ID`: +```text +Submitted batch job 1980204 +``` +Once the ray cluster is up, a script should be created to attach to the ray head node, +which you can use launch experiments. +```sh +bash 1980204-attach.sh +``` +Now that you are on the head node, you can launch the command like so: +```sh +uv venv -p python3.12.9 .venv +uv pip install -e . +uv run ./examples/run_grpo.py +``` + +## Kubernetes + +TBD diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000000..e800a2595d --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,76 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +import os +import sys + +project = "NeMo-Reinforcer" +copyright = "2025, NVIDIA Corporation" +author = "NVIDIA Corporation" +release = "0.0.1" + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = [ + "myst_parser", # For our markdown docs + "autodoc2", # Generates API docs + "sphinx.ext.viewcode", # For adding a link to view source code in docs + "sphinx.ext.doctest", # Allows testing in docstrings + "sphinx.ext.napoleon", # For google style docstrings + "sphinx_copybutton", # For copy button in code blocks +] + +templates_path = ["_templates"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + +# -- Options for MyST Parser (Markdown) -------------------------------------- +# MyST Parser settings +myst_enable_extensions = [ + "dollarmath", # Enables dollar math for inline math + "amsmath", # Enables LaTeX math for display mode + "colon_fence", # Enables code blocks using ::: delimiters instead of ``` + "deflist", # Supports definition lists with term: definition format + "fieldlist", # Enables field lists for metadata like :author: Name + "tasklist", # Adds support for GitHub-style task lists with [ ] and [x] +] +myst_heading_anchors = 3 # Generates anchor links for headings up to level 3 + +# -- Options for Autodoc2 --------------------------------------------------- +sys.path.insert(0, os.path.abspath("..")) + +autodoc2_packages = [ + "../nemo_reinforcer", # Path to your package relative to conf.py +] +autodoc2_render_plugin = "myst" # Use MyST for rendering docstrings +autodoc2_output_dir = "apidocs" # Output directory for autodoc2 (relative to docs/) +# This is a workaround that uses the parser located in autodoc2_docstrings_parser.py to allow autodoc2 to +# render google style docstrings. +# Related Issue: https://github.com/sphinx-extensions2/sphinx-autodoc2/issues/33 +autodoc2_docstring_parser_regexes = [ + (r".*", "autodoc2_docstrings_parser"), +] + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = "nvidia_sphinx_theme" diff --git a/docs/design_docs/chat_datasets.md b/docs/design_docs/chat_datasets.md new file mode 100644 index 0000000000..43e2801fdc --- /dev/null +++ b/docs/design_docs/chat_datasets.md @@ -0,0 +1,61 @@ +# Data Format + +## HuggingFace Chat Datasets + +HuggingFace chat datasets are expected to have the following structure: Each example in the dataset should be a dictionary with a `messages` key. `messages` should be a list of dictionaries, each with a `role` and `content` key. `role` is typically one of `system`, `user`, and `assistant`. For example: + +```json +{ + "messages": [ + { + "role": "system", + "content": "This is a helpful system message." + }, + { + "role": "user", + "content": "This is a user's question" + }, + { + "role": "assistant", + "content": "This is the assistant's response." + } + ] +} +``` + +### Chat Templates + +Formatting the data in this way allows us to take advantage of HuggingFace tokenizers' `apply_chat_template` functionality to combine the messages. Chat templates can be used to add special tokens or task-specific information to each example in the dataset. Refer to the [HuggingFace apply_chat_template documentation](https://huggingface.co/docs/transformers/main/en/chat_templating#applychattemplate) for details. + +By default, `apply_chat_template` attempts to apply the `chat_template` associated with the tokenizer. However, in some cases, users might want to specify their own chat template. Also, note that many tokenizers do not have associated `chat_template`s, in which case an explicit chat template is required. Users can specify an explicit chat template string using Jinja format and can pass that string to `apply_chat_template`. +The following is an example using a simple template which prepends a role header to each turn: + +```{testcode} +from transformers import AutoTokenizer + +example_template = "{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{{ content }}{% endfor %}" + +example_input = [ + { + 'role': 'user', + 'content': 'Hello!' + }, + { + 'role': 'assistant', + 'content': 'Hi there!' + } +] +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") +output = tokenizer.apply_chat_template(example_input, chat_template=example_template, tokenize=False) + +## this is the output string we expect +expected_output = '<|start_header_id|>user<|end_header_id|>\n\nHello!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there!<|eot_id|>' +assert output == expected_output +``` + + +```{testoutput} +:hide: +``` + +For more details on creating chat templates, refer to the [HuggingFace documentation](https://huggingface.co/docs/transformers/v4.34.0/en/chat_templating#how-do-i-create-a-chat-template). \ No newline at end of file diff --git a/docs/design_docs/design_and_philosophy.md b/docs/design_docs/design_and_philosophy.md new file mode 100644 index 0000000000..ba1c2e28c9 --- /dev/null +++ b/docs/design_docs/design_and_philosophy.md @@ -0,0 +1,95 @@ +# Design and Philosophy +In this section, we will describe the problems this library aims to solve and motivate/dicuss the Reinforcer APIs. + +## Motivation +Online RL requires coordinating a lot of different pieces of software/models +- Policy Model/Training Framework +- Fast inference Framework (vLLM, SGLANG, TRT-LLM) +- Reward Environments, Critics, etc. + +We refer to each of these pieces of software as an **RL Actor**. + +[TODO @sahilj Diagram] + +Fundamentally, we need to be able to do 4 things between these RL Actors: +- Resource them (provide GPUs/CPUs) +- Isolate them + - RL Actors may each set global variables or have conflicting dependencies, so they each need to live in an isolated process environment with configurable dependencies +- Coordinate them (control) +- Communicate between them (data) + +## Design + +We create composable and hackable abstractions for each layer of the tasks above +- Resourcing -> {py:class}`RayVirtualCluster ` +- Isolation -> {py:class}`RayWorkerGroup ` +- Coordination -> A Single-Process Controller using Ray +- Communication -> Data flows through one of the following: + - the single controller + - a communication scheme set-up by the controller such as + - NCCL Collectives + - Multiprocess Queues + +By creating a common interface for these 4 tasks, **RL algorithm code looks the same from 1 GPU to 1000 GPUs and does not care about the implementation of each RL Actor (Megatron, HF, Grad student with pen and paper)** + +### {py:class}`RayVirtualCluster ` +VirtualCluster provides a basic abstraction on top of Ray Placement Groups that allow you to section off a part of your compute resources for WorkerGroups to run on as though they had their own cluster. They support running just one WorkerGroup on each VirtualCluster, or *colocation*, where multiple WorkerGroups share resources (i.e running policy training(hf) and generation(vllm) on the same GPUs in-turn). + +Minimally, it has has the following core API: +```python +class RayVirtualCluster: +""" + Creates a virtual distributed cluster using Ray placement groups. + + This class simplifies distributed training setup by: + - Creating placement groups that represent logical compute nodes + - Allocating GPU and CPU resources for distributed workers + - Managing communication between distributed processes + + - Bundle: A resource allocation unit (ex: 4 GPUs on a single node) + - Worker: A process that performs computation (model training/inference) + - Node: A physical or virtual machine containing multiple bundles +""" + def __init__(self, bundle_ct_per_node_list: List[int], {other args}): + """ + Initialize a virtual cluster using Ray placement groups. + + Args: + bundle_ct_per_node_list: List specifying GPU bundles per node + (e.g., [2,2] creates 2 nodes with 2 GPU bundles each) + """ + def get_placement_groups(self): + """ + Returns a list of placement groups that have at least one bundle, filtering out empty nodes. + This represents the "virtual cluster" - only nodes that are actually being used. + + Returns: + List of placement groups that have at least one bundle + """ +``` + +### {py:class}`RayWorkerGroup ` +All work is done by "Worker Processes"(Ray Actors) that run on a small unit of resources (usually 1 CPU or 1 CPU+GPU). These workers are managed by *RayWorkerGroup* +```python +class RayWorkerGroup: + """ + Manages a group of distributed Ray worker/actor processes that execute tasks in parallel. + + This class creates and manages Ray actor instances that run on resources + allocated by a RayVirtualCluster. It handles: + - Worker creation and placement on specific GPU resources + - Setting up distributed training environment variables (rank, world size, etc.) + - Executing methods across all workers in parallel + - Collecting and aggregating results + - Support for tied worker groups where multiple workers process the same data + """ +``` +[TODO @sahilj Diagram] + + + +### Single-Controller & Execution Diagram + +## Walking through an implementation of GRPO + + diff --git a/docs/design_docs/generation.md b/docs/design_docs/generation.md new file mode 100644 index 0000000000..e9fa3ee2ee --- /dev/null +++ b/docs/design_docs/generation.md @@ -0,0 +1,137 @@ +# Generation Module + +This doc explains the token generation interface and various backends for the NeMo Reinforcer framework. The generation system is designed with a unified interface that allows different backends (like VLLM, HuggingFace, SGLang, TRT-LLM) to provide token generation capabilities while adhering to the same API. + +## Generation Interface + +The core of the generation system is defined in `interfaces.py`, which establishes an abstract interface that all generation backends must implement. This ensures consistency across different implementations and makes it easy to swap backends without changing the calling code. + +### Key Components + +1. **GenerationConfig**: A TypedDict that defines the configuration for generation: + ```python + class GenerationConfig(TypedDict): + """Configuration for generation.""" + backend: str # The backend to use (e.g., "vllm", "hf") + max_new_tokens: int # Maximum number of tokens to generate + temperature: float # Sampling temperature + top_p: float # Top-p sampling parameter + top_k: int # Top-k sampling parameter + model_name: str # Name or path of the model + ``` + +2. **GenerationDatumSpec**: A TypedDict that defines the input data format: + ```python + class GenerationDatumSpec(TypedDict): + input_ids: torch.Tensor # Input token IDs + attention_mask: torch.Tensor # Attention mask + __extra__: Any # Additional data specific to the backend + ``` + +3. **GenerationOutputSpec**: A TypedDict that defines output data format: + ```python + class GenerationOutputSpec(TypedDict): + output_ids: torch.Tensor + generation_lengths: torch.Tensor # Length of just the generated response part + unpadded_sequence_lengths: torch.Tensor # Length of full valid sequence (input + generated response) + logprobs: torch.Tensor + __extra__: Any # Additional output data specific to the backend + ``` + +4. **GenerationInterface**: An abstract base class that all generation backends must implement: + ```python + class GenerationInterface(ABC): + """Abstract base class defining the interface for RL policies.""" + + @abstractmethod + def generate( + self, data: BatchedDataDict["GenerationDatumSpec"], greedy: bool + ) -> BatchedDataDict["GenerationOutputSpec"]: + pass + + @abstractmethod + def prepare_for_generation(self, *args, **kwargs): + pass + + @abstractmethod + def finish_generation(self, *args, **kwargs): + pass + ``` + +A key thing to note about generation backends is that the generation backend takes in tokens and gives out tokens without dealing with the tokenizer. By ensuring that only tokens are communicated we eliminate the possibility of having different tokenizers (different versions/specs etc) for training and generation framework. + +## VLLM Backend + +The VLLM backend (`models/generation/vllm.py`) implements the {py:class}`GenerationInterface ` to provide efficient text generation using the VLLM library, which is optimized for large language models. + +### VllmGeneration Class + +The {py:class}`VllmGeneration ` class is the main implementation of the {py:class}`GenerationInterface ` for VLLM. It: + +1. Sets up VLLM workers in a distributed environment using Ray +2. Manages the lifecycle of these workers (initialization, generation, shutdown) +3. Distributes inputs to workers and collects outputs +4. Handles weight updates and synchronization + +### VllmGenerationWorker + +The {py:class}`VllmGenerationWorker ` is a Ray actor that: + +1. Initializes and manages a VLLM model instance +2. Performs the actual generation on a GPU +3. Supports dynamic weight updates through IPC handles +4. Implements sleep/wake mechanisms for efficient resource utilization + +### Custom VLLM Extensions + +The {py:class}`UpdatableVllmInternalWorker ` class in `vllm_backend.py` extends the VLLM worker with additional capabilities: + +1. Reporting device IDs to allow mapping of workers to specific GPUs +2. Updating weights from IPC handles for efficient weight sharing +3. Checking if weights have been updated correctly + +## Usage Example + +To use a generation backend: + +```python +from nemo_reinforcer.models.generation.vllm import VllmGeneration, VllmConfig +from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict + +# Set up the configuration +config = VllmConfig( + backend="vllm", + model_name="Qwen/Qwen2.5-1.5B", + max_new_tokens=100, + temperature=0.7, + top_p=1, + vllm_cfg={ + "tensor_parallel_size": 1, + "gpu_memory_utilization": 0.8 + } +) + +# Initialize the cluster and generation backend +cluster = RayVirtualCluster(...) +generator = VllmGeneration(cluster, config) + +# Prepare input data +input_data = BatchedDataDict(...) + +# Generate text +generator.prepare_for_generation() +output = generator.generate(input_data, greedy=False) +generator.finish_generation() +``` + +## Extending with New Backends + +To add a new generation backend: + +1. Create a new class that implements {py:class}`GenerationInterface ` +2. Implement the required methods: {py:method}`generate `, {py:method}`prepare_for_generation `, and {py:method}`finish_generation ` +3. Ensure your implementation works with the standard {py:class}`GenerationConfig ` and {py:class}`GenerationDatumSpec ` structures +4. Register your backend with the system (if needed) to make it accessible + +This modular design allows for easy extension with new backends while maintaining a consistent interface for the rest of the system. diff --git a/docs/design_docs/logger.md b/docs/design_docs/logger.md new file mode 100644 index 0000000000..cf55c442e4 --- /dev/null +++ b/docs/design_docs/logger.md @@ -0,0 +1,80 @@ +# Logger + +## Requirements: + +* Tracking distributed metrics with specified reductions (mean, max, etc) +* Tracking distributed timing with (usually) 'max' reduction across ranks +* Logging: + * WandB + * Tensorboard + +## Overall Design + +Since there is a single controller, the single process running the main training loop will gather the metrics and do the logging. + +To handle multiple logger backends, we will have a {py:class}`LoggerInterface ` interface that the {py:class}`TensorboardLogger ` and {py:class}`WandbLogger ` will implement: + +```python +class LoggerInterface(ABC): + """Abstract base class for logger backends.""" + + @abstractmethod + def log_metrics(self, metrics: Dict[str, Any], step: int, prefix: Optional[str]: "") -> None: + """Log a dictionary of metrics.""" + pass + + @abstractmethod + def log_hyperparams(self, params: dict[str, Any]) -> None: + """Log dictionary of hyperparameters.""" + pass +``` + +A {py:class}`Logger ` wrapper class will also implement {py:class}`LoggerInterface ` and will contain a list of loggers it delegates to when writing logs. This will be the main class the user uses in the training loop. Usage example: + +```python +# Initialize logger with both wandb and tensorboard enabled +logging_config = { + "wandb_enabled": True, + "tensorboard_enabled": False, + + "wandb": { + "project": "grpo-dev", + "name": "grpo-dev-logging", + }, + "tensorboard": { + "log_dir": "logs", + }, +} +logger = Logger( + cfg=logger_config, +) + +# Log metrics, will go to both wandb and tensorboard +logger.log_metrics({ + "loss": 0.123, +}, step=10) +``` + +## Validation Pretty Logging + +The logger supports pretty-formatted logging of validation samples to help visualize model outputs during training. This feature is controlled by the `num_val_samples_to_print` configuration parameter: + +```python +logger: + wandb_enabled: false + tensorboard_enabled: false + num_val_samples_to_print: 10 +``` + +When `num_val_samples_to_print` is set to a value greater than 0, the logger will generate well-formatted text outputs for the specified number of validation samples. This is particularly useful for: + +1. Quickly inspecting model generation quality during training +2. Comparing inputs and outputs side-by-side +3. Tracking validation sample performance over time + +### Example Output + +When enabled, the pretty logging will generate formatted text similar to: + +![Validation Pretty Logging Example](../assets/val-log.png) + diff --git a/docs/design_docs/padding.md b/docs/design_docs/padding.md new file mode 100644 index 0000000000..d5949cf3b5 --- /dev/null +++ b/docs/design_docs/padding.md @@ -0,0 +1,98 @@ +# Padding in NeMo Reinforcer + +## Overview + +This document explains padding in NeMo Reinforcer and why consistent padding is critical for the framework. + +## Padding Approach + +NeMo Reinforcer uses **right padding** for all tensor operations, where padding tokens are added to the right/end of sequences: + +``` +[101, 2054, 2003, 0, 0] # Length 3 +[101, 2054, 2003, 2001, 1996] # Length 5 (no padding needed) +[101, 2054, 0, 0, 0] # Length 2 +``` + +This approach: +1. **Naturally aligns with LLM processing**: Tokens are processed from left to right +2. **Keeps meaningful tokens contiguous**: All valid tokens appear at the beginning of tensors +3. **Simplifies indexing and operations**: Valid token boundaries are easily defined with a single length value + +## Right-Padded Generation Example + +Input (right-padded) β†’ Generation β†’ Final (right-padded): +``` +[101, 2054, 2003, 0, 0] # Original input (length 3) + ↓ +[101, 2054, 2003, 2001, 1996, 4568, 7899, 0] # After generation +|-- input --| |----- generation -----| |pad| +``` + +Corresponding logprobs: +``` +[ 0, 0, 0, -1.2, -0.8, -1.5, -2.1, 0] +|-- zeros for input --| |- gen logprobs -| |pad| +``` + +## Verifying Right Padding + +NeMo Reinforcer provides utilities to verify correct padding: + +```{testcode} +import torch +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.models.generation.interfaces import verify_right_padding + +# For input data (BatchedDataDict containing input_ids and input_lengths) +input_data = BatchedDataDict({ + "input_ids": torch.tensor([ + [101, 2054, 2003, 0, 0], # Example input sequence + [101, 2054, 0, 0, 0] # Another input sequence + ]), + "input_lengths": torch.tensor([3, 2]) # Length of each sequence +}) + +# Check if input data is properly right-padded +is_right_padded, error_msg = verify_right_padding(input_data, pad_value=0) + +# For generation output data (BatchedDataDict containing output_ids and generation_lengths) +output_data = BatchedDataDict({ + "output_ids": torch.tensor([ + [101, 2054, 2003, 2001, 1996, 0, 0], # Example output sequence + [101, 2054, 2001, 4568, 0, 0, 0] # Another output sequence + ]), + "generation_lengths": torch.tensor([2, 2]), # Length of generated response + "unpadded_sequence_lengths": torch.tensor([5, 4]) # Total valid tokens +}) + +# Check if output data is properly right-padded +is_right_padded, error_msg = verify_right_padding(output_data, pad_value=0) + +if not is_right_padded: + print(f"Padding error: {error_msg}") +``` + + +```{testoutput} +:hide: +``` + +The {py:class}`verify_right_padding() ` function checks that: +1. All padding (zeros or padding token provided by the user) appears after valid tokens +2. The padding starts at the position specified by the length tensor + +The function automatically detects whether you're passing input or output data: +- For input data: Requires `input_ids` and `input_lengths` fields +- For output data: Requires `output_ids` and either `generation_lengths` or `unpadded_sequence_lengths` + + +## Best Practices + +1. **Always Use Right Padding**: All components expect this format + +2. **Track Length Tensors**: Include appropriate length tensors with your data + +3. **Verify Padding**: Use {py:class}`verify_right_padding() ` when in doubt + +4. **Mask Padding in Operations**: Use lengths to exclude padding tokens from loss calculations diff --git a/docs/design_docs/uv.md b/docs/design_docs/uv.md new file mode 100644 index 0000000000..d4a88d99f7 --- /dev/null +++ b/docs/design_docs/uv.md @@ -0,0 +1,78 @@ +# `uv` in NeMo-Reinforcer + +Using `uv` for Dependency Management in NeMo-Reinforcer + +## Overview + +`uv` is an incredible tool that simplifies our workflow and is blazingly fast because it's written in Rust. This document outlines why we've adopted `uv` for package management in our repository, particularly for NeMo Reinforcer, and how it helps us manage dependencies across Ray clusters. + +## Why `uv`? + +### Speed and Efficiency + +- Written in Rust, making it significantly faster than traditional Python package managers +- Optimized caching mechanisms that reduce redundant downloads and installations +- Quick environment creation and switching, enabling rapid development cycles + +### Isolated Environments + +- Creates fully isolated Python environments, preventing dependency conflicts between system packages and project-specific packages +- Avoids nuanced dependency situations where a Python script might accidentally use both virtualenv dependencies and system dependencies +- Ensures consistent behavior across different machines and deployment environments + +### Dependency Management in Ray Clusters + +- Enables management of heterogeneous Python environments across a Ray cluster +- Provides flexibility for each actor (worker) to use the specific Python dependencies it requires +- Simplifies propagation of environments to worker nodes without manual setup on each node + +### Container-Free Flexibility + +- Frees us from having to publish many containers for different dependency combinations +- Allows us to define different [dependency groups](https://docs.astral.sh/uv/concepts/projects/dependencies/#dependency-groups) and [extras](https://docs.astral.sh/uv/concepts/projects/dependencies/#optional-dependencies) and select which ones we need dynamically +- Reduces infrastructure complexity and maintenance overhead + +## Implementation in NeMo Reinforcer + +### Worker Configuration + +In our codebase, workers (classes decorated with `@ray.remote`, e.g., `HFPolicyWorker`) define a `DEFAULT_PY_EXECUTABLE` which specifies what dependencies the worker needs. This allows different parts of our application to have their own tailored environments. + +### Supported Python Executables + +We provide several predefined Python executable configurations in {py:class}`PY_EXECUTABLES `: + +```python +# --with-editable .: speeds up the install slightly since editable installs don't require full copies +# --cache-dir $UV_CACHE_DIR: caching isn't propagated by default. This will set it if the user has set it. +class PY_EXECUTABLES: + # This uses the .venv created by `uv`. This is the fastest option, but provides no isolation between workers. + DEFAULT_VENV = f"{os.environ['VIRTUAL_ENV']}/bin/python" + + # TODO: Debug high run-to-run variance latency with these options + # Use NeMo-Reinforcer direct dependencies and nothing from system + DEFAULT = f"uv run --isolated --with-editable . {uv_cache_flag}" + # Use none of NeMo-Reinforcer's dependencies or the system. Useful for workers that only need standard python packages. + BARE_BONES = f"uv run --isolated --no-project --with-editable . {uv_cache_flag}" +``` + +At the moment we **highly recommend** {py:class}`DEFAULT_ENV ` as it results in the fastest bringup of your workload if you are using the `transformers` library and `vllm`. + +### Customization + +If you need a different Python executable configuration, you can override the default one by passing your own in {py:class}`RayWorkerBuilder.__call__ `. This provides flexibility for special use cases without modifying the core configurations. + +## How It Works + +When a Ray job is started: + +1. The driver process runs in the `uv` environment specified at launch +2. Ray detects this environment and propagates it to worker processes +3. Each worker can specify its own environment through `py_executable` in its runtime environment +4. `uv` efficiently sets up these environments on each worker, using caching to minimize setup time + +This approach ensures consistent environments across the cluster while allowing for worker-specific customization when needed. + +## Conclusion + +Using `uv` for dependency management in NeMo Reinforcer provides us with a fast, flexible, and reliable way to handle Python dependencies across distributed Ray clusters. It eliminates many of the traditional pain points of dependency management in distributed systems while enabling heterogeneous environments that can be tailored to specific workloads. diff --git a/docs/docker.md b/docs/docker.md new file mode 100644 index 0000000000..e270372666 --- /dev/null +++ b/docs/docker.md @@ -0,0 +1,7 @@ +# Building Docker Image + +## Docker Build +```sh +cd docker/ +docker buildx build -t nemo-reinforcer -f Dockerfile . +``` diff --git a/docs/documentation.md b/docs/documentation.md new file mode 100644 index 0000000000..668056b13b --- /dev/null +++ b/docs/documentation.md @@ -0,0 +1,74 @@ +# Documentation Development + +- [Documentation Development](#documentation-development) + - [Building](#building) + - [Live Building](#live-building) + - [Running Tests in Python Docstrings](#running-tests-in-python-docstrings) + - [Writing Tests in Python Docstrings](#writing-tests-in-python-docstrings) + + +## Building + +The following sections describe how to set up and build the NeMo-Reinforcer documentation. + +Switch to the documentation source folder and generate HTML output. + +```sh +cd docs/ +uv run --extra docs sphinx-build . _build/html +``` + +* The resulting HTML files are generated in a `_build/html` folder that is created under the project `docs/` folder. +* The generated python API docs are placed in `apidocs` under the `docs/` folder. + +## Live Building + +When writing documentation it can be helpful to serve the documentation and have it update live while you edit. + +To do so run: + +```sh +cd docs/ +uv run --extra docs sphinx-autobuild . _build/html --port 12345 --host 0.0.0.0 +``` + +Open a web browser and go to `http://${HOST_WHERE_SPHINX_COMMAND_RUN}:12345` to view the output. + + +## Running Tests in Python Docstrings + +We also run tests in our python docstrings. You can run them with: + +```sh +cd docs/ +uv run --extra docs sphinx-build -b doctest . _build/doctest +``` + +## Writing Tests in Python Docstrings + +Any code in triple backtick blocks with the `{doctest}` directive will be tested. The format follows Python's doctest module syntax, where `>>>` indicates Python input and the following line shows the expected output. Here's an example: + +```python +def add(x: int, y: int) -> int: + """ + Adds two integers together. + + Args: + x (int): The first integer to add. + y (int): The second integer to add. + + Returns: + int: The sum of x and y. + + Examples: + ```{doctest} + >>> from nemo_reinforcer.made_up_package import add + >>> add(1, 2) + 3 + ``` + + """ + return x + y +``` + + diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md new file mode 100644 index 0000000000..0925548d0d --- /dev/null +++ b/docs/guides/grpo.md @@ -0,0 +1,3 @@ +# GRPO + +placeholder TBD diff --git a/docs/guides/sft.md b/docs/guides/sft.md new file mode 100644 index 0000000000..534c6b1702 --- /dev/null +++ b/docs/guides/sft.md @@ -0,0 +1,65 @@ +# Supervised Fine-tuning in Reinforcer + +## Launch an SFT Run + +The script [examples/run_sft.py](../../examples/run_sft.py) can be used to launch an experiment. This script can either be launched locally or via Slurm. For details on how to set up Ray and launch a job using Slurm, refer to the [cluster documentation](../cluster.md). + +Be sure to launch the job using `uv`. The command to launch an SFT job is as follows: +```bash +uv run examples/run_sft.py --config --output-dir +``` +If not specified, `config` will default to [examples/configs/sft.yaml](../../examples/configs/sft.yaml) and `output-dir` will default to `./outputs`. + +## Configuration + +Reinforcer allows users to configure experiments using `yaml` config files. An example SFT configuration file can be found [here](../../examples/configs/sft.yaml). + +To override a value in the config, either update the value in the `yaml` file directly, or pass the override via the command line. For example: + +```bash +python examples/run_sft.py \ + data.max_input_seq_length=8192 \ + logger.wandb.name="sft-dev-sl-8192" +``` + +## Datasets + +SFT datasets in Reinforcer are encapsulated using classes. Each SFT data class is expected to have the following attributes: + - `formatted_ds`: The dictionary of formatted datasets. This dictionary should contain `train` and `validation` splits, and each split should conform to the format described below. + - `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset as well as the `custom_template` for this dataset. More on custom templates below. + +SFT datasets are expected to follow the HuggingFace chat format. Refer to the [chat dataset document](../design_docs/chat_datasets.md) for details. If your data is not in the correct format, simply write a preprocessing script to convert the data into this format. [data/hf_datasets/squad.py](../../nemo_reinforcer/data/hf_datasets/squad.py) has an example: + +```python +def format_squad(data): + return { + "messages": [ + { + "role": "system", + "content": data["context"], + }, + { + "role": "user", + "content": data["question"], + }, + { + "role": "assistant", + "content": data["answers"]["text"][0], + }, + ] + } +``` + +Reinforcer SFT uses HuggingFace chat templates to format the individual examples. If you would like to use a custom template, create a string template in [jinja format](https://huggingface.co/docs/transformers/v4.34.0/en/chat_templating#how-do-i-create-a-chat-template) and pass it to the dataset's `TaskDataSpec`. For example, + +```python +custom_template = ( + "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer: '}}{%- elif message['role'] == 'assistant' %}{{message['content'].strip()}}{%- endif %}{% endfor %}" +) +task_spec = TaskDataSpec( + task_name="squad", + custom_template=custom_template, +) +``` + +By default, NeMo-Reinforcer has support for `Squad` and `OpenAssistant` datasets. If you would like to use a custom dataset, create a new dataset class with the expected attributes. diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000000..221b9d31a2 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,49 @@ +```{include} ../README.md +:relative-docs: docs/ +``` + +```{toctree} +:caption: πŸ–₯️ Environment Start +:hidden: + +local_workstation.md +cluster.md + +``` + +```{toctree} +:caption: πŸ“š Guides +:hidden: + +guides/sft.md +guides/grpo.md +``` + +```{toctree} +:caption: 🐳 Containers +:hidden: + +docker.md +``` + +```{toctree} +:caption: πŸ› οΈ Development +:hidden: + +adding_new_models.md +testing.md +documentation.md +apidocs/index.rst +``` + +```{toctree} +:caption: πŸ“ Design Docs +:hidden: + +design_docs/design_and_philosophy.md +design_docs/padding.md +design_docs/logger.md +design_docs/uv.md +design_docs/chat_datasets.md +design_docs/generation.md +``` diff --git a/docs/local_workstation.md b/docs/local_workstation.md new file mode 100644 index 0000000000..3e252694a0 --- /dev/null +++ b/docs/local_workstation.md @@ -0,0 +1,25 @@ +# Local Workstation + +## Launching Locally + +When launching examples locally with `uv`, {py:class}`init_ray() ` will first attempt to connect to an existing cluster. If none is found, it will start a local one and connect to it using all available GPU and CPU resources on your node. + +To launch a job outside of a container, simply run: + +```sh +uv run examples/run_grpo.py +``` + +In the logs, you will see that Ray has started a local cluster instance, along with details on the resources made available to it: +``` +2025-03-17 13:37:45,360 INFO worker.py:1841 -- Started a local Ray instance. +... +INFO:nemo_reinforcer.distributed.virtual_cluster:Started local cluster with: {'node:__internal_head__': 1.0, 'CPU': 24.0, 'object_store_memory': 80448493977.0, 'accelerator_type:RTX': 1.0, 'memory': 177713152615.0, 'GPU': 1.0, 'node:10.0.0.1': 1.0} +``` + +To control the GPUs ray uses locally more granularly, please use `CUDA_VISIBLE_DEVICES`: + +```sh +# Use the 0th and 3rd indexed GPU (for a total of 2 GPUs) +CUDA_VISIBLE_DEVICES=0,3 uv run examples/run_grpo.py +``` diff --git a/docs/testing.md b/docs/testing.md new file mode 100644 index 0000000000..ec184ab0c0 --- /dev/null +++ b/docs/testing.md @@ -0,0 +1,24 @@ +# Testing NeMo-Reinforcer + +## Unit Tests + +```sh +uv pip install -e '.[test]' +uv run bash tests/run_unit.sh +``` + +### Run Unit Tests Hermetic + +If your local environment does not have all the necessary dependencies (e.g., `gcc`, `nvcc`) +or there is concern that something in your environment may be misconfigured, you can also run +the tests in docker with this script: + +```sh +CONTAINER=... bash tests/run_unit_in_docker.sh +``` + +The `CONTAINER` can be built by following the instructions [here](docker.md). + +## Functional tests + +TBD diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/examples/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/configs/base.yaml b/examples/configs/base.yaml new file mode 100644 index 0000000000..ded5e8255c --- /dev/null +++ b/examples/configs/base.yaml @@ -0,0 +1,51 @@ +# Base configuration with common settings +policy: + model_name: "meta-llama/Llama-3.2-1B-Instruct" + train_global_batch_size: 32 + train_micro_batch_size: 4 + generation_batch_size: 32 + learning_rate: 5.0e-6 + logprob_batch_size: 4 + max_total_sequence_length: 8192 + + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.1 + end_factor: 1.0 + total_iters: 50 + - name: "torch.optim.lr_scheduler.ConstantLR" + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: [50] + + generation: + backend: "vllm" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + top_p: 1.0 + top_k: null + vllm_cfg: + tensor_parallel_size: 1 + gpu_memory_utilization: 0.7 + max_model_len: ${policy.max_total_sequence_length} + +data: + max_input_seq_length: ${policy.max_total_sequence_length} + prompt_file: "examples/prompts/cot.txt" + system_prompt_file: null + +logger: + log_dir: "logs" # Base directory for all logs + num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal + wandb_enabled: false + tensorboard_enabled: false + wandb: + project: "grpo-dev" + name: "grpo-dev-logger" + tensorboard: {} + +cluster: + gpus_per_node: 2 + num_nodes: 1 \ No newline at end of file diff --git a/examples/configs/grpo.yaml b/examples/configs/grpo.yaml new file mode 100644 index 0000000000..4436421009 --- /dev/null +++ b/examples/configs/grpo.yaml @@ -0,0 +1,56 @@ +# GRPO Algorithm Configuration +defaults: "base.yaml" + +grpo: + num_prompts_per_step: 8 + num_generations_per_prompt: 8 + num_steps: 100 + normalize_rewards: true + use_leave_one_out_baseline: true + val_period: 10 + val_at_start: true + max_val_samples: 16 + val_batch_size: 16 + +loss_fn: + reference_policy_kl_penalty: 0.01 + ratio_eps: 0.2 + +checkpointing: + enabled: true + checkpoint_dir: "results/grpo" + metric_name: "val_reward" + higher_is_better: true + keep_top_k: 3 + save_period: 10 + +policy: + model_name: "meta-llama/Llama-3.2-1B-Instruct" + train_global_batch_size: 32 + train_micro_batch_size: 4 + generation_batch_size: 32 + logprob_batch_size: 4 + max_total_sequence_length: 1024 + + generation: + backend: "vllm" # "vllm" or "hf"(to use the hf training framework's generation) + max_new_tokens: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len below + temperature: 1.0 + # Don't change since vllm logprobs in V0 runtime are after sampling and in V1 runtime are before sampling. + top_p: 1.0 + top_k: null # disable + vllm_cfg: + tensor_parallel_size: 1 + gpu_memory_utilization: 0.7 + max_model_len: ${policy.max_total_sequence_length} + +data: + max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len + prompt_file: "examples/prompts/cot.txt" + system_prompt_file: null + dataset_name: "datasets/Eurus-2-RL-Data/Eurus-2-RL-Data-math_train.jsonl" + val_dataset_name: "datasets/Eurus-2-RL-Data/Eurus-2-RL-Data-math_val.jsonl" + +env: + math: + num_workers: 8 diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml new file mode 100644 index 0000000000..abe1b7ed98 --- /dev/null +++ b/examples/configs/sft.yaml @@ -0,0 +1,42 @@ +# SFT Algorithm Configuration +sft: + num_steps: 100 + #val_period: 10 + #val_at_start: true + #checkpoint_dir: "results/sft" + +policy: + model_name: "meta-llama/Llama-3.2-1B-Instruct" + train_global_batch_size: 8 + train_micro_batch_size: 2 + learning_rate: 5.0e-6 + max_total_sequence_length: 1024 + + scheduler: + - name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.1 + end_factor: 1.0 + total_iters: 100 + - name: "torch.optim.lr_scheduler.CosineAnnealingLR" + kwargs: + T_max: 100 + - milestones: [50] + +data: + max_input_seq_length: ${policy.max_total_sequence_length} + dataset_name: "open_assistant" + +logger: + log_dir: "logs" # Base directory for all logs + wandb_enabled: true + tensorboard_enabled: false + wandb: + project: "sft-dev" + name: "sft-dev-logger" + tensorboard: + log_dir: "tb_logs" + +cluster: + gpus_per_node: 1 + num_nodes: 1 diff --git a/examples/prompts/cot.txt b/examples/prompts/cot.txt new file mode 100644 index 0000000000..c5e97ff50a --- /dev/null +++ b/examples/prompts/cot.txt @@ -0,0 +1,4 @@ +Think step-by-step to solve the following problem. Output your answer inside of \\boxed{{}} tags.: +{} + +Let's think step-by-step \ No newline at end of file diff --git a/examples/run_grpo_math.py b/examples/run_grpo_math.py new file mode 100644 index 0000000000..b3fadbb563 --- /dev/null +++ b/examples/run_grpo_math.py @@ -0,0 +1,228 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import pprint + +from omegaconf import OmegaConf +from typing import Dict, Any + +from datasets import load_dataset +from transformers import AutoTokenizer +from collections import defaultdict + +from nemo_reinforcer.algorithms.grpo import MasterConfig, grpo_train, setup +from nemo_reinforcer.distributed.virtual_cluster import init_ray +from nemo_reinforcer.utils.config import load_config +from nemo_reinforcer.utils.logger import get_next_experiment_dir +from nemo_reinforcer.data.interfaces import TaskDataSpec, DatumSpec, LLMMessageLogType +from nemo_reinforcer.data import DataConfig +from nemo_reinforcer.models.policy import PolicyConfig +from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, rl_collate_fn +from nemo_reinforcer.environments.math_environment import MathEnvironment + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Run GRPO training with configuration") + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + + # Parse known args for the script + args, remaining = parser.parse_known_args() + + # Convert remaining args to OmegaConf format + overrides = OmegaConf.from_dotlist(remaining) + + return args, overrides + + +# =============================================================================== +# Math Data Processor +# =============================================================================== + + +# this processor expects the datum_dict to have a 'problem' key and an 'expected_answer' key +def math_data_processor( + datum_dict: Dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer, + max_seq_length: int, + idx: int, +) -> DatumSpec: + """Process a datum dictionary (directly loaded from dataset) into a DatumSpec for the Math Environment.""" + problem = datum_dict["problem"] + solution = str(datum_dict["expected_answer"]) + extra_env_info = {"ground_truth": solution} + + template = task_data_spec.custom_template + message_log: LLMMessageLogType = [] + if task_data_spec.system_prompt: + sys_message = {"role": "system", "content": task_data_spec.system_prompt} + message = tokenizer.apply_chat_template( + [sys_message], + chat_template=template, + tokenize=False, + add_generation_prompt=False, + add_special_tokens=False, + ) + sys_message["token_ids"] = tokenizer(message, return_tensors="pt")["input_ids"][ + 0 + ] + message_log.append(sys_message) + user_message = { + "role": "user", + "content": task_data_spec.prompt.format(problem), + } + message = tokenizer.apply_chat_template( + [user_message], + chat_template=template, + tokenize=False, + add_generation_prompt=True, + add_special_tokens=False, + ) + user_message["token_ids"] = tokenizer(message, return_tensors="pt")["input_ids"][0] + user_message["content"] = message + message_log.append(user_message) + + length = sum(len(m["token_ids"]) for m in message_log) + + loss_multiplier = 1.0 + if length > max_seq_length: + # make smaller and mask out + for message in message_log: + message["token_ids"] = message["token_ids"][ + : min(4, max_seq_length // len(message_log)) + ] + loss_multiplier = 0.0 + + output = { + "message_log": message_log, + "length": length, + "extra_env_info": extra_env_info, + "loss_multiplier": loss_multiplier, + "idx": idx, + "task_name": datum_dict["task_name"], + } + return output + + +def setup_data(data_config: DataConfig, policy_config: PolicyConfig, env_configs): + print("\nβ–Ά Setting up data...") + math_task_spec = TaskDataSpec( + task_name="math", + prompt_file=data_config["prompt_file"], + system_prompt_file=data_config["system_prompt_file"], + ) + + base_dataset = load_dataset("json", data_files=data_config["dataset_name"])["train"] + tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"]) + + task_data_processors = defaultdict(lambda: (math_task_spec, math_data_processor)) + task_data_processors["math"] = (math_task_spec, math_data_processor) + + math_env = MathEnvironment.options( + runtime_env={"py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE} + ).remote(env_configs["math"]) + dataset = AllTaskProcessedDataset( + base_dataset, + tokenizer, + math_task_spec, + task_data_processors, + max_seq_length=data_config["max_input_seq_length"], + ) + + if "val_dataset_name" in data_config and data_config["val_dataset_name"]: + val_dataset = load_dataset("json", data_files=data_config["val_dataset_name"])[ + "train" + ] + val_dataset = AllTaskProcessedDataset( + val_dataset, + tokenizer, + math_task_spec, + task_data_processors, + max_seq_length=data_config["max_input_seq_length"], + ) + else: + val_dataset = None + + task_to_env = defaultdict(lambda: math_env) + task_to_env["math"] = math_env + return dataset, val_dataset, task_to_env, task_to_env, tokenizer + + +def main(): + """Main entry point.""" + # Parse arguments + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join(os.path.dirname(__file__), "configs", "grpo.yaml") + + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + print(f"Overrides: {overrides}") + config = OmegaConf.merge(config, overrides) + + config: MasterConfig = OmegaConf.to_container(config, resolve=True) + print("Applied CLI overrides") + + # Print config + print("Final config:") + pprint.pprint(config) + + # Get the next experiment directory with incremented ID + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + print(f"πŸ“Š Using log directory: {config['logger']['log_dir']}") + + init_ray() + + # setup data + dataset, val_dataset, task_to_env, val_task_to_env, tokenizer = setup_data( + config["data"], config["policy"], config["env"] + ) + ( + policy, + policy_generation, + cluster, + dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + grpo_state, + master_config, + ) = setup(config, dataset, val_dataset) + grpo_train( + policy, + policy_generation, + dataloader, + val_dataloader, + tokenizer, + loss_fn, + task_to_env, + val_task_to_env, + logger, + checkpointer, + grpo_state, + master_config, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/run_sft.py b/examples/run_sft.py new file mode 100644 index 0000000000..f8649a9484 --- /dev/null +++ b/examples/run_sft.py @@ -0,0 +1,91 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import pprint + +from omegaconf import OmegaConf + +from nemo_reinforcer.algorithms.sft import MasterConfig, sft_train, setup +from nemo_reinforcer.distributed.virtual_cluster import init_ray +from nemo_reinforcer.utils.logger import get_next_experiment_dir + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Run SFT training with configuration") + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + + # Parse known args for the script + args, remaining = parser.parse_known_args() + + # Convert remaining args to OmegaConf format + overrides = OmegaConf.from_dotlist(remaining) + + return args, overrides + + +def main(): + """Main entry point.""" + # Parse arguments + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join(os.path.dirname(__file__), "configs", "sft.yaml") + + config = OmegaConf.load(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + override_conf = OmegaConf.from_cli() + print(f"Overrides: {override_conf}") + config = OmegaConf.merge(config, override_conf) + + config: MasterConfig = OmegaConf.to_container(config, resolve=True) + print("Applied CLI overrides") + + # Print config + print("Final config:") + pprint.pprint(config) + + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + print(f"πŸ“Š Using log directory: {config['logger']['log_dir']}") + + init_ray() + ( + policy, + cluster, + dataloader, + tokenizer, + loss_fn, + master_config, + logger, + sft_task_spec, + ) = setup(config) + sft_train( + policy, + dataloader, + tokenizer, + loss_fn, + master_config, + logger, + sft_task_spec, + ) + + +if __name__ == "__main__": + main() diff --git a/nemo_reinforcer/__init__.py b/nemo_reinforcer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_reinforcer/algorithms/__init__.py b/nemo_reinforcer/algorithms/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py new file mode 100644 index 0000000000..0f120330fa --- /dev/null +++ b/nemo_reinforcer/algorithms/grpo.py @@ -0,0 +1,767 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Tuple, TypedDict, Iterable, Optional, List + +import os +from pathlib import Path +import numpy as np +import ray +import torch +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import AutoTokenizer + +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.algorithms.utils import calculate_baseline_and_std_per_prompt + +from nemo_reinforcer.environments.interfaces import EnvironmentInterface +from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster +from nemo_reinforcer.data.interfaces import ( + DatumSpec, + LLMMessageLogType, + FlatMessagesType, +) +from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, rl_collate_fn +from nemo_reinforcer.models.policy.hf_policy import HfPolicy +from nemo_reinforcer.models.generation.vllm import VllmGeneration +from nemo_reinforcer.algorithms.loss_functions import ( + ClippedPGLossConfig, + ClippedPGLossDataDict, + ClippedPGLossFn, +) +from nemo_reinforcer.algorithms.interfaces import LossFunction +from nemo_reinforcer.data import DataConfig +from nemo_reinforcer.data.llm_message_utils import ( + get_keys_from_message_log, + batched_message_log_to_flat_message, +) +from nemo_reinforcer.utils.logger import ( + print_message_log_samples, +) +from nemo_reinforcer.distributed.virtual_cluster import ClusterConfig +from nemo_reinforcer.environments.math_environment import MathEnvConfig +from nemo_reinforcer.models.generation.interfaces import ( + GenerationInterface, + GenerationDatumSpec, +) +from nemo_reinforcer.models.interfaces import PolicyInterface +from nemo_reinforcer.models.policy import PolicyConfig +from nemo_reinforcer.utils.logger import Logger, LoggerConfig +from nemo_reinforcer.utils.timer import Timer +from nemo_reinforcer.utils.checkpoint import CheckpointManager, CheckpointingConfig + + +# =============================================================================== +# Configuration +# =============================================================================== + + +class GRPOConfig(TypedDict): + num_prompts_per_step: int + num_generations_per_prompt: int + num_steps: int + normalize_rewards: bool + use_leave_one_out_baseline: bool + val_period: int + val_at_start: bool + checkpoint_dir: str + + +class GRPOSaveState(TypedDict): + step: int + val_reward: float + consumed_samples: int + + +def _default_grpo_save_state() -> GRPOSaveState: + return { + "step": 0, + "val_reward": -99999999.0, + "consumed_samples": 0, + } + + +class MasterConfig(TypedDict): + policy: PolicyConfig + loss_fn: ClippedPGLossConfig + math_env: MathEnvConfig + data: DataConfig + grpo: GRPOConfig + logger: LoggerConfig + cluster: ClusterConfig + checkpointing: CheckpointingConfig + + +# =============================================================================== +# Setup & Initialization +# =============================================================================== + + +def setup( + master_config: MasterConfig, + dataset: AllTaskProcessedDataset, + val_dataset: Optional[AllTaskProcessedDataset], +) -> Tuple[ + PolicyInterface, + GenerationInterface, + RayVirtualCluster, + StatefulDataLoader, + Optional[StatefulDataLoader], + ClippedPGLossFn, + Logger, + CheckpointManager, + GRPOSaveState, + MasterConfig, +]: + """Main entry point for running GRPO algorithm. + + Returns: + Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, logger, master_config, val_dataloader + """ + # Extract individual configs for easier access + policy_config = master_config["policy"] + generation_config = master_config["policy"]["generation"] + loss_config = master_config["loss_fn"] + data_config = master_config["data"] + grpo_config = master_config["grpo"] + logger_config = master_config["logger"] + cluster_config = master_config["cluster"] + + # ========================== + # Checkpointing + # ========================== + checkpointer = CheckpointManager(master_config["checkpointing"]) + last_checkpoint_path = checkpointer.get_latest_checkpoint_path() + grpo_save_state: Optional[GRPOSaveState] = checkpointer.load_training_info( + last_checkpoint_path + ) + if grpo_save_state is None: + grpo_save_state = _default_grpo_save_state() + + # config validation checks + if master_config["checkpointing"]["enabled"]: + assert master_config["checkpointing"]["save_period"] > 0 + assert ( + master_config["checkpointing"]["save_period"] + % master_config["grpo"]["val_period"] + == 0 + ), ( + f"Checkpointing save period {master_config['checkpointing']['save_period']} " + f"must be a multiple of validation period {master_config['grpo']['val_period']}" + f", or we won't know what metric to save!" + ) + + # ========================== + # Data + # ========================== + dataloader = StatefulDataLoader( + dataset, + batch_size=grpo_config["num_prompts_per_step"], + shuffle=False, + collate_fn=rl_collate_fn, + ) + if last_checkpoint_path is not None: + dataloader_state_dict = torch.load( + os.path.join(last_checkpoint_path, "train_dataloader.pt") + ) + dataloader.load_state_dict(dataloader_state_dict) + + print(f" βœ“ Training dataloader loaded with {len(dataset)} samples") + + # Load validation dataset if provided + val_dataloader = None + if "val_dataset_name" in data_config and data_config["val_dataset_name"]: + val_dataloader = StatefulDataLoader( + val_dataset, + batch_size=grpo_config["val_batch_size"], + shuffle=False, + collate_fn=rl_collate_fn, + ) + print(f" βœ“ Validation dataloader loaded with {len(val_dataset)} samples") + + # ========================== + # Cluster + # ========================== + print("\nβ–Ά Setting up compute cluster...") + colocated_inference = generation_config["backend"] != "hf" + cluster = RayVirtualCluster( + name="grpo_policy_cluster", + bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] + * cluster_config["num_nodes"], + use_gpus=True, + num_gpus_per_node=cluster_config["gpus_per_node"], + max_colocated_worker_groups=2 if colocated_inference else 1, + ) + print(f" βœ“ Ray cluster initialized with {cluster_config['num_nodes']} nodes") + + # ========================== + # Training and Inference + # ========================== + print("\nβ–Ά Setting up model and training...") + + # vllm model loading prefers clean environment, initialize policy_generation before policy (#52 will fix this) + backend = generation_config["backend"] + generation_config["model_name"] = policy_config["model_name"] # Needed for vLLM + if backend == "hf": + policy_generation = None + print(f" βœ“ Using HF backend for generation with {policy_config['model_name']}") + elif backend == "vllm": + policy_generation = VllmGeneration(cluster=cluster, config=generation_config) + # Worker groups are not initialized until the first call to run something on workergroups. + # vllm 0.8 fails in initialization if its called in the first training step since it has no clean view of the GPU memory (HF is sharing the same memory). + policy_generation.finish_generation() + print( + f" βœ“ Using vLLM backend for generation with {policy_config['model_name']}" + ) + + policy = HfPolicy( + cluster=cluster, + config=policy_config, + weights_path=Path(last_checkpoint_path) / "policy.pt" + if last_checkpoint_path + else None, + optimizer_path=Path(last_checkpoint_path) / "policy_optimizer.pt" + if last_checkpoint_path + else None, + init_optimizer=True, + ) + + loss_fn = ClippedPGLossFn(loss_config) + logger = Logger(logger_config) + logger.log_hyperparams(master_config) + + print("\n" + "=" * 60) + print(" " * 18 + "SETUP COMPLETE") + print("=" * 60 + "\n") + + return ( + policy, + policy_generation, + cluster, + dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + grpo_save_state, + master_config, + ) + + +# =============================================================================== +# Core Algorithm Functions +# =============================================================================== + + +def refit_policy_generation( + policy: PolicyInterface, + policy_generation: GenerationInterface, +): + """Refit the policy generation interface with the latest policy weights.""" + policy.offload_before_refit() + ipc_handles = policy.get_weights_ipc_handles() + policy_generation.prepare_for_generation() + policy_generation.update_weights(ipc_handles) + policy.offload_after_refit() + + +def generate_responses( + policy_generation: GenerationInterface, + generation_input_data: BatchedDataDict[GenerationDatumSpec], + batch: BatchedDataDict[DatumSpec], + tokenizer, + input_lengths: torch.Tensor, + include_logprobs: bool = True, +) -> Tuple[List[torch.Tensor], List[str], torch.Tensor]: + """Generate responses from policy.""" + # Generate responses + generation_outputs = policy_generation.generate(generation_input_data) + + # Extract generated tokens + generated_ids = [] + unpadded_sequence_lengths = generation_outputs["unpadded_sequence_lengths"] + for output_ids, input_length, total_length in zip( + generation_outputs["output_ids"], input_lengths, unpadded_sequence_lengths + ): + generated_ids.append(output_ids[input_length:total_length]) + + generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + + # Append to message log + for i, (text, input_length, total_length) in enumerate( + zip(generated_texts, input_lengths, unpadded_sequence_lengths) + ): + message = { + "role": "assistant", + "content": text, + "token_ids": generation_outputs["output_ids"][i, input_length:total_length], + } + + if include_logprobs and "logprobs" in generation_outputs: + message["generation_logprobs"] = generation_outputs["logprobs"][ + i, input_length:total_length + ] + + batch["message_log"][i].append(message) + + metrics = { + "mean_generation_length": ( + torch.sum(unpadded_sequence_lengths) - torch.sum(input_lengths) + ).item() + / len(unpadded_sequence_lengths), + "max_seqlen": torch.max(unpadded_sequence_lengths).item(), + } + + return batch, generated_ids, metrics + + +def calculate_rewards( + batch: BatchedDataDict[DatumSpec], + task_to_env: Dict[str, EnvironmentInterface], +) -> Tuple[torch.Tensor, List[LLMMessageLogType]]: + """Calculate rewards for generated responses. + + Args: + batch: Batch containing message_log (LLMMessageLogType) with generated responses + task_to_env: Dictionary mapping task names to their corresponding environments + + Returns: + rewards: Tensor of rewards + to_env: Simplified message logs sent to environment (LLMMessageLogType format) + """ + # Extract message logs for environment + to_env = [ + get_keys_from_message_log(batch["message_log"][i], ["role", "content"]) + for i in range(len(batch["message_log"])) + ] + task_names = [batch["task_name"][i] for i in range(len(batch["task_name"]))] + + # Group messages by task type + task_groups = {} + for i, task_name in enumerate(task_names): + if task_name not in task_groups: + task_groups[task_name] = [] + task_groups[task_name].append((i, to_env[i])) + + # Calculate rewards for each task group concurrently + futures = [] + future_to_indices = {} # Map future to its corresponding indices + for task_name, group in task_groups.items(): + if task_name not in task_to_env: + raise ValueError(f"No environment found for task type: {task_name}") + + # Extract indices and messages for this group + indices = [idx for idx, _ in group] + messages = [msg for _, msg in group] + + # Get corresponding environment info + env_info = [batch["extra_env_info"][i] for i in indices] + + # Submit task to environment and store future + future = task_to_env[task_name].step.remote(messages, env_info) + futures.append(future) + future_to_indices[future] = indices + + results = ray.get(futures) + all_rewards = [] + for future, result in zip(futures, results): + indices = future_to_indices[future] + _, _, task_rewards, _ = result + + # Store results with their original indices + for idx, reward in zip(indices, task_rewards): + all_rewards.append((idx, reward)) + + # Sort results by original index to maintain order + all_rewards.sort(key=lambda x: x[0]) + rewards = torch.tensor([reward for _, reward in all_rewards]) + + return rewards, to_env + + +# =============================================================================== +# Training & Validation +# =============================================================================== + + +def grpo_train( + policy: PolicyInterface, + policy_generation: Optional[GenerationInterface], + dataloader: StatefulDataLoader, + val_dataloader: Optional[StatefulDataLoader], + tokenizer, + loss_fn: LossFunction, + task_to_env: Dict[str, EnvironmentInterface], + val_task_to_env: Optional[Dict[str, EnvironmentInterface]], + logger: Logger, + checkpointer: CheckpointManager, + grpo_save_state: Optional[GRPOSaveState], + master_config: MasterConfig, +): + """Run GRPO training algorithm.""" + timer = Timer() + NEED_REFIT = True + # If policy_generation is None, use the policy as the generation interface (hf framework backend) + if policy_generation is None: + policy_generation = policy + NEED_REFIT = False + POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running + + # common config/state itmes + step = grpo_save_state["step"] + consumed_samples = grpo_save_state["consumed_samples"] + val_period = master_config["grpo"]["val_period"] + val_at_start = master_config["grpo"]["val_at_start"] + + # Run validation at the start if configured + if val_at_start and step == 0: + print("\nπŸ” Running initial validation...") + if NEED_REFIT and POLICY_GENERATION_STALE: + refit_policy_generation(policy, policy_generation) + POLICY_GENERATION_STALE = False + else: + policy_generation.prepare_for_generation() + val_metrics, validation_timings = validate( + policy_generation, + val_dataloader, + tokenizer, + val_task_to_env, + step=0, + master_config=master_config, + ) + policy_generation.finish_generation() + logger.log_metrics(val_metrics, step, prefix="validation") + logger.log_metrics(validation_timings, step, prefix="timing/validation") + + # Run grpo training (single-turn) + for batch in dataloader: + print(f"\n{'=' * 25} Step {step + 1}/{len(dataloader)} {'=' * 25}") + + with timer.time("total_step_time"): + # Prepare batch + print("β–Ά Preparing batch...") + with timer.time("data_processing"): + # Repeat batch items + repeated_batch = batch.repeat_interleave( + master_config["grpo"]["num_generations_per_prompt"] + ) + # Convert LLMMessageLogType to FlatMessagesType for generation + batched_flat, input_lengths = batched_message_log_to_flat_message( + repeated_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.eos_token_id}, + ) + input_ids = batched_flat["token_ids"] + # Create generation-specific input structure + generation_input_data = BatchedDataDict[GenerationDatumSpec]( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + } + ) + + # Generate responses - this updates the LLMMessageLogType in repeated_batch + print(f"β–Ά Generating responses for batch of size {len(input_ids)}...") + with timer.time("prepare_for_generation"): + if NEED_REFIT and POLICY_GENERATION_STALE: + refit_policy_generation(policy, policy_generation) + POLICY_GENERATION_STALE = False + else: + policy_generation.prepare_for_generation() + with timer.time("generation"): + repeated_batch, _, gen_metrics = generate_responses( + policy_generation, + generation_input_data, + repeated_batch, + tokenizer, + input_lengths, + ) + policy_generation.finish_generation() + + # Calculate rewards & advantages based on the updated LLMMessageLogType + print("β–Ά Calculating rewards...") + with timer.time("reward_calculation"): + rewards, _ = calculate_rewards(repeated_batch, task_to_env) + + print("β–Ά Computing advantages...") + baseline, std = calculate_baseline_and_std_per_prompt( + input_ids, + rewards, + torch.ones_like(rewards), + leave_one_out_baseline=master_config["grpo"][ + "use_leave_one_out_baseline" + ], + ) + advantages = (rewards - baseline).unsqueeze(-1) + + if master_config["grpo"]["normalize_rewards"]: + # don't sharpen the ones with no variation + zero_std_mask = std > 0 + advantages[zero_std_mask] = ( + advantages[zero_std_mask] / std.unsqueeze(-1)[zero_std_mask] + ) + + with timer.time("data_processing"): + # Add loss mask and advantages to each message in LLMMessageLogType + for i, message_log in enumerate(repeated_batch["message_log"]): + for j, message in enumerate(message_log): + if message["role"] == "assistant": + message["token_loss_mask"] = torch.ones_like( + message["token_ids"] + ) + else: + message["token_loss_mask"] = torch.zeros_like( + message["token_ids"] + ) + if "generation_logprobs" not in message: + message["generation_logprobs"] = torch.zeros_like( + message["token_ids"], dtype=torch.float32 + ) + message["advantages"] = advantages[i].expand( + message["token_ids"].shape + ) + + # Convert updated LLMMessageLogType to FlatMessagesType for training + flat_messages, input_lengths = batched_message_log_to_flat_message( + repeated_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.eos_token_id}, + ) + + # Create training data from flattened messages + train_data = BatchedDataDict[ClippedPGLossDataDict]( + { + "input_ids": flat_messages["token_ids"], + "input_lengths": input_lengths, + "advantages": flat_messages["advantages"], + "generation_logprobs": flat_messages["generation_logprobs"], + "token_mask": flat_messages["token_loss_mask"], + "sample_mask": repeated_batch["loss_multiplier"], + } + ) + train_data.to("cpu") + + print("β–Ά Preparing for logprob inference...") + with timer.time("logprob_inference_prep"): + policy.prepare_for_lp_inference() + + print("β–Ά Computing logprobs...") + with timer.time("policy_and_reference_logprobs"): + fprop_logprobs = policy.get_logprobs(train_data)["logprobs"] + reference_logprobs = policy.get_reference_policy_logprobs(train_data)[ + "reference_logprobs" + ] + train_data["prev_logprobs"] = fprop_logprobs + train_data["reference_policy_logprobs"] = reference_logprobs + + print("β–Ά Preparing for training...") + with timer.time("training_prep"): + policy.prepare_for_training() # set model train and reload optim to GPU + POLICY_GENERATION_STALE = True + + print("β–Ά Training policy...") + with timer.time("policy_training"): + train_results = policy.train(train_data, loss_fn) + + # Run validation if it's a validation step + if val_period > 0 and (step + 1) % val_period == 0: + if NEED_REFIT and POLICY_GENERATION_STALE: + refit_policy_generation(policy, policy_generation) + POLICY_GENERATION_STALE = False + else: + policy_generation.prepare_for_generation() + val_metrics, validation_timings = validate( + policy_generation, + val_dataloader, + tokenizer, + val_task_to_env, + step=step + 1, + master_config=master_config, + ) + policy_generation.finish_generation() + logger.log_metrics( + validation_timings, step + 1, prefix="timing/validation" + ) + logger.log_metrics(val_metrics, step + 1, prefix="validation") + + ## Checkpointing + consumed_samples += master_config["grpo"]["num_prompts_per_step"] + if ( + master_config["checkpointing"]["enabled"] + and (step + 1) % master_config["checkpointing"]["save_period"] == 0 + ): # +1 because step is 0-indexed + policy.prepare_for_training() + grpo_save_state["step"] = step + 1 + grpo_save_state["val_reward"] = val_metrics["accuracy"] + grpo_save_state["consumed_samples"] = consumed_samples + with timer.time("checkpointing"): + print(f"Saving checkpoint for step {step + 1}...") + checkpoint_path = checkpointer.init_tmp_checkpoint( + step + 1, grpo_save_state, master_config + ) + policy.save_checkpoint( + os.path.join(checkpoint_path, "policy.pt"), + os.path.join(checkpoint_path, "policy_optimizer.pt"), + ) + torch.save( + dataloader.state_dict(), + os.path.join(checkpoint_path, "train_dataloader.pt"), + ) + checkpointer.finalize_checkpoint(checkpoint_path) + policy.offload_after_refit() + + # Logging + print("\nπŸ“Š Training Results:") + metrics = { + "loss": train_results["loss"].numpy(), + "reward": rewards.numpy(), + } + metrics.update(train_results["all_mb_metrics"]) + metrics = {k: np.mean(v).item() for k, v in metrics.items()} + metrics.update(gen_metrics) + + timing_metrics = timer.get_timing_metrics(reduction_op="sum") + + print(f" β€’ Loss: {metrics['loss']:.4f}") + print(f" β€’ Avg Reward: {np.mean(rewards.numpy()):.4f}") + print( + f" β€’ Mean Generation Length: {gen_metrics['mean_generation_length']:.4f}" + ) + + print("\n⏱️ Timing:") + # Display total time first, separately + total_time = timing_metrics.get("total_step_time", 0) + print(f" β€’ Total step time: {total_time:.2f}s") + + # Display all other timing metrics + for k, v in sorted( + timing_metrics.items(), key=lambda item: item[1], reverse=True + ): + if k != "total_step_time": + percent = (v / total_time * 100) if total_time > 0 else 0 + print(f" β€’ {k}: {v:.2f}s ({percent:.1f}%)") + + logger.log_metrics(metrics, step + 1, prefix="train") + logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") + + timer.reset() + step += 1 + if step >= master_config["grpo"]["num_steps"]: + break + + +def validate( + policy_generation: GenerationInterface, + val_dataloader: StatefulDataLoader, + tokenizer, + val_task_to_env: Dict[str, EnvironmentInterface], + step: int, + master_config: MasterConfig, +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Run validation on the validation dataset.""" + if val_dataloader is None: + print(" ⚠️ No validation dataloader provided, skipping validation") + return + + timer = Timer() + with timer.time("total_validation_time"): + print(f"β–Ά Starting validation at step {step}...") + + total_rewards = [] + total_lengths = [] + all_message_logs = [] # Collect all message logs + + max_batches = ( + master_config["grpo"]["max_val_samples"] + // master_config["grpo"]["val_batch_size"] + ) + for batch_idx, val_batch in enumerate(val_dataloader): + if batch_idx >= max_batches: + break + + # Convert LLMMessageLogType to FlatMessagesType for generation + batched_flat, input_lengths = batched_message_log_to_flat_message( + val_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.eos_token_id}, + ) + # Extract input IDs + input_ids = batched_flat["token_ids"] + # Create generation-specific input structure + generation_input_data = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + } + ) + + # Generate responses (updates the LLMMessageLogType in batch_with_msg_logs) + val_batch, generated_ids, gen_metrics = generate_responses( + policy_generation, + generation_input_data, + val_batch, + tokenizer, + input_lengths, + include_logprobs=False, + ) + + # Calculate rewards based on the updated LLMMessageLogType + with timer.time("reward_calculation"): + rewards, to_env = calculate_rewards(val_batch, val_task_to_env) + + total_rewards.extend(rewards.tolist()) + total_lengths.extend([len(ids) for ids in generated_ids]) + + # Collect message logs for later display + all_message_logs.extend(to_env) + + # Calculate validation metrics + accuracy = sum(total_rewards) / len(total_rewards) + avg_length = sum(total_lengths) / len(total_lengths) + + val_metrics = { + "accuracy": accuracy, + "avg_length": avg_length, + } + + # Print sample conversations only once at the end of validation + try: + print_message_log_samples( + all_message_logs, + total_rewards, + num_samples=min( + master_config["logger"]["num_val_samples_to_print"], + len(all_message_logs), + ), + step=step, + ) + except Exception as e: + print(f"\n ⚠️ Error displaying message samples: {str(e)}") + print(" ⚠️ Continuing validation without displaying samples...") + + # Get timing metrics + timing_metrics = timer.get_timing_metrics(reduction_op="sum") + validation_time = timing_metrics.get("total_validation_time", 0) + + # Print summary of validation results + print("\nπŸ“Š Validation Results:") + print(f" β€’ Accuracy: {accuracy:.4f}") + print(f" β€’ Average response length: {avg_length:.1f} tokens") + print(f" β€’ Samples processed: {len(total_rewards)}") + + # Print timing information + print("\n ⏱️ Validation Timing:") + validation_time = timing_metrics.get("total_validation_time", 0) + print(f" β€’ Total validation time: {validation_time:.2f}s") + + # Make sure to reset the timer after validation + timer.reset() + + return val_metrics, timing_metrics diff --git a/nemo_reinforcer/algorithms/interfaces.py b/nemo_reinforcer/algorithms/interfaces.py new file mode 100644 index 0000000000..b0290ce17f --- /dev/null +++ b/nemo_reinforcer/algorithms/interfaces.py @@ -0,0 +1,50 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Protocol, Tuple + +import torch + +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict + + +class LossFunction(Protocol): + """Signature for loss functions used in reinforcement learning algorithms. + + Loss functions compute a scalar loss value and associated metrics from + model logprobs and other data contained in a BatchedDataDict. + """ + + def __call__( + self, next_token_logits: torch.Tensor, data: BatchedDataDict + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Compute loss and metrics from logprobs and other data. + + Args: + next_token_logits: Logits from the model, typically with shape [batch_size, seq_len, vocab_size]. + For each position (b, i), contains the logit distribution over the entire vocabulary + for predicting the next token (at position i+1). For example, if processing "The cat sat on", + then next_token_logits[b, 3] would contain the logits for predicting the word + that follows "on". + data: Dictionary containing all relevant data for loss computation + such as rewards, values, actions, advantages, masks, and other + algorithm-specific information needed for the particular loss calculation. + + Returns: + tuple: (loss, metrics) + - loss: A scalar tensor representing the loss value to be minimized during training + - metrics: A dictionary of metrics related to the loss computation, which may include + component losses, statistics about gradients/rewards, and other diagnostic information + """ + pass diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py new file mode 100644 index 0000000000..0518e3f0c0 --- /dev/null +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -0,0 +1,165 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Tuple, TypedDict + +import torch + +from nemo_reinforcer.algorithms.interfaces import LossFunction +from nemo_reinforcer.algorithms.utils import ( + calculate_kl_penalty_joschu2020, + masked_mean, +) +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict + + +class ClippedPGLossConfig(TypedDict): + reference_policy_kl_penalty: float + ratio_eps: float + + +class ClippedPGLossDataDict(TypedDict): + """Required keys for the Clipped Policy Gradient loss function.""" + + input_ids: torch.Tensor + advantages: torch.Tensor + prev_logprobs: torch.Tensor + generation_logprobs: torch.Tensor + reference_policy_logprobs: torch.Tensor + token_mask: torch.Tensor + sample_mask: torch.Tensor + __extra__: Any + + +class ClippedPGLossFn(LossFunction): + """Generalized Clipped Policy Gradient loss function w/ KL regularization. + + This implements: + + - PPO (Clipped) - https://arxiv.org/abs/1707.06347 + - GRPO - https://arxiv.org/abs/2402.03300 + - REINFORCE/RLOO (set disable_ppo_ratio = True and ignores ratio_eps) - https://arxiv.org/abs/2402.14740 + + Formula: + L(ΞΈ) = E_t [ min(r_t(ΞΈ) * A_t, clip(r_t(ΞΈ), 1-Ξ΅, 1+Ξ΅) * A_t) ] - Ξ² * KL(Ο€_ΞΈ || Ο€_ref) + + where: + - r_t(ΞΈ) = Ο€_ΞΈ(a_t|s_t) / Ο€_ΞΈ_old(a_t|s_t) is the probability ratio + - A_t is the advantage estimate + - Ξ΅ is the clip parameter (ratio_eps) + - Ξ² is the KL penalty coefficient (reference_policy_kl_penalty) + - KL(Ο€_ΞΈ || Ο€_ref) is the KL divergence between the current policy and reference policy (Schulman Approx.) + + For REINFORCE/RLOO (when disable_ppo_ratio=True), the formula simplifies to: + L(ΞΈ) = E_t [ Ο€_ΞΈ(a_t|s_t) * A_t ] - Ξ² * KL(Ο€_ΞΈ || Ο€_ref) + """ + + def __init__(self, cfg: ClippedPGLossConfig): + self.ratio_eps = cfg["ratio_eps"] + self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] + self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False) + + def __call__( + self, + next_token_logits: torch.Tensor, + data: BatchedDataDict[ClippedPGLossDataDict], + ) -> Tuple[torch.Tensor, dict]: + """Clipped Policy Gradient RL loss function.""" + token_mask = data["token_mask"][:, 1:] + sample_mask = data["sample_mask"] + advantages = data["advantages"][:, 1:] + prev_logprobs = data["prev_logprobs"][:, 1:] + generation_logprobs = data["generation_logprobs"][:, 1:] + reference_policy_logprobs = data["reference_policy_logprobs"][:, 1:] + + mask = token_mask * sample_mask.unsqueeze(-1) + + lp_error = torch.abs(generation_logprobs - prev_logprobs) # noqa: F841 (precommit ignore for now) + mult_prob_error = ((torch.exp(lp_error) * mask).sum() / mask.sum()).item() + + next_token_logits = next_token_logits[:, :-1] # Remove last position's logits + next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) + + next_tokens = data["input_ids"][:, 1:] # Skip first token + curr_logprobs = next_token_logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) + + # Calculate KL regularization. + if self.reference_policy_kl_penalty != 0: + kl = self.reference_policy_kl_penalty * calculate_kl_penalty_joschu2020( + logprobs_policy=curr_logprobs, + logprobs_reference=reference_policy_logprobs, + ) + kl = masked_mean(kl, mask) + else: + kl = 0 + + # Calculate clipped loss function if ppo ratio is enabled. + if not self.disable_ppo_ratio: + ratios = (curr_logprobs - prev_logprobs).exp() + ratios_clamped = ratios.clamp(1.0 - self.ratio_eps, 1.0 + self.ratio_eps) + else: + ratios = curr_logprobs + ratios_clamped = curr_logprobs + + loss1 = -advantages * ratios + loss2 = -advantages * ratios_clamped + + if mask.sum() > 0: + actor_loss = masked_mean(torch.max(loss1, loss2), mask) + loss = actor_loss + kl + else: + # disable this update since there are no valid tokens + loss = loss1.view(-1)[0] * 0 + + with torch.no_grad(): + probs_ratio = masked_mean(ratios.detach(), mask).item() + probs_ratio_clamped = masked_mean(ratios_clamped.detach(), mask).item() + + return ( + loss, + { + "loss": loss.item(), + "probs_ratio": probs_ratio, + "probs_ratio_clamped": probs_ratio_clamped, + "kl_penalty": kl.item() / self.reference_policy_kl_penalty if kl else 0, + "token_mult_prob_error": mult_prob_error, + }, + ) + + +class NLLLoss(LossFunction): + def __call__( + self, next_token_logits: torch.Tensor, data: BatchedDataDict + ) -> Tuple[torch.Tensor, dict]: + # logits shape: [batch_size, seq_len, vocab_size] + # Get the next token logits for each position + token_mask = data["token_mask"][:, 1:] + sample_mask = data["sample_mask"] + mask = token_mask * sample_mask.unsqueeze(-1) + + next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token + next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) + logprobs = next_token_logprobs[:, :-1] # Remove last position's logits + + # Gather the logprobs for the actual next tokens + token_logprobs = logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) + + # Only compute loss on generated tokens (not input tokens) + # by applying the token_loss_mask (shifted by 1 since we're predicting next tokens) + loss = -torch.sum(token_logprobs * mask) + + return loss, {"loss": loss.item()} diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py new file mode 100644 index 0000000000..1f948b1b5c --- /dev/null +++ b/nemo_reinforcer/algorithms/sft.py @@ -0,0 +1,210 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Tuple, TypedDict + +from torch.utils.data import DataLoader +from transformers import AutoTokenizer + +from nemo_reinforcer.algorithms.loss_functions import ( + NLLLoss, +) +from nemo_reinforcer.data import DataConfig, hf_datasets +from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, rl_collate_fn +from nemo_reinforcer.data.interfaces import TaskDataSpec, DatumSpec +from nemo_reinforcer.data.llm_message_utils import ( + add_loss_mask_to_message_log, + batched_message_log_to_flat_message, + get_formatted_message_log, +) +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster +from nemo_reinforcer.models.policy.hf_policy import HfPolicy +from nemo_reinforcer.models.policy import PolicyConfig +from nemo_reinforcer.utils.logger import Logger, LoggerConfig +from nemo_reinforcer.utils.timer import Timer + + +class SFTConfig(TypedDict): + num_steps: int + + +class MasterConfig(TypedDict): + policy: PolicyConfig + data: DataConfig + sft: SFTConfig + logger: LoggerConfig + cluster: ClusterConfig + + +def sft_preprocessor( + datum_dict: Dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer, + max_seq_length: int, + idx: int, +) -> DatumSpec: + """Process a datum dictionary for SFT training.""" + message_log = get_formatted_message_log( + datum_dict["messages"], tokenizer, task_data_spec + ) + + length = sum(len(m["token_ids"]) for m in message_log) + + loss_multiplier = 1.0 + if length > max_seq_length: + # make smaller and mask out + for message in message_log: + message["token_ids"] = message["token_ids"][ + : min(4, max_seq_length // len(message_log)) + ] + loss_multiplier = 0.0 + + output = { + "message_log": message_log, + "length": length, + "extra_env_info": None, + "loss_multiplier": loss_multiplier, + "idx": idx, + } + return output + + +def setup( + master_config: MasterConfig, +) -> Tuple[ + HfPolicy, + RayVirtualCluster, + DataLoader, + AutoTokenizer, + NLLLoss, + MasterConfig, + Logger, +]: + """Main entry point for running SFT algorithm. + + Returns: + Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger + """ + # Extract individual configs for easier access + policy_config = master_config["policy"] + data_config = master_config["data"] + logger_config = master_config["logger"] + cluster_config = master_config["cluster"] + + ## TODO: unify this with grpo + data_cls = data_config["dataset_name"] + if data_cls == "open_assistant": + data = hf_datasets.OasstDataset(output_dir="/tmp/open_assistant") + elif data_cls == "squad": + data = hf_datasets.SquadDataset() + else: + raise ValueError(f"Unknown dataset class: {data_cls}") + + base_dataset = data.formatted_ds["train"] + sft_task_spec = data.task_spec + + tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"]) + + dataset = AllTaskProcessedDataset( + base_dataset, + tokenizer, + sft_task_spec, + sft_preprocessor, + max_seq_length=data_config["max_input_seq_length"], + ) + + dataloader = DataLoader( + dataset, + batch_size=policy_config["train_global_batch_size"], + shuffle=False, + collate_fn=rl_collate_fn, ## TODO: change this for sft! or make it more general + ) + + cluster = RayVirtualCluster( + name="sft_cluster", + bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] + * cluster_config["num_nodes"], + use_gpus=True, + num_gpus_per_node=cluster_config["gpus_per_node"], + max_colocated_worker_groups=1, + ) + + policy = HfPolicy(cluster=cluster, config=policy_config) + loss_fn = NLLLoss() + + logger = Logger(logger_config) + + return ( + policy, + cluster, + dataloader, + tokenizer, + loss_fn, + master_config, + logger, + sft_task_spec, + ) + + +def sft_train( + policy, dataloader, tokenizer, loss_fn, master_config, logger, sft_task_spec +): + # Run basic sft training + timer = Timer() + + policy.prepare_for_training() + + for step, batch in enumerate(dataloader): + timer.start("sft_train_step") + + timer.start("data_processing") + ## add loss mask based on role to every message + add_loss_mask_to_message_log( + batch["message_log"], + roles_to_train_on=["assistant"], + ) + + cat_and_padded, input_lengths = batched_message_log_to_flat_message( + batch["message_log"], + pad_value_dict={"token_ids": tokenizer.eos_token_id}, + ) + + train_data: BatchedDataDict = BatchedDataDict( + { + "input_ids": cat_and_padded["token_ids"], + "input_lengths": input_lengths, + "token_mask": cat_and_padded["token_loss_mask"], + "sample_mask": batch["loss_multiplier"], + } + ) + timer.stop("data_processing") + + ## train_data.to("cpu") + train_results = policy.train(train_data, loss_fn) + timer.stop("sft_train_step") + losses = train_results["loss"] + timing_metrics = timer.get_timing_metrics(reduction_op="sum") + + print(f"Step {step} completed. Loss: {losses[-1].item()}") + + logger.log_metrics( + {"loss": losses[-1].item()}, + step, + prefix="train", + ) + logger.log_metrics(timing_metrics, step, prefix="timing/train") + timer.reset() + + if step >= master_config["sft"]["num_steps"] - 1: + break diff --git a/nemo_reinforcer/algorithms/utils.py b/nemo_reinforcer/algorithms/utils.py new file mode 100644 index 0000000000..a153bfb53e --- /dev/null +++ b/nemo_reinforcer/algorithms/utils.py @@ -0,0 +1,122 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings +from functools import wraps + +import torch +from torch.masked import as_masked_tensor + + +def calculate_kl_penalty_joschu2020( + logprobs_policy: torch.Tensor, logprobs_reference: torch.Tensor +): + """Calculates a per-token estimate of the KL Divergence between two log_probs. + + From Schulman 2020, always positive. + + logprobs_policy: torch.Tensor (b, s) + logprobs_reference: torch.Tensor (b, s) + """ + r = logprobs_reference - logprobs_policy + return torch.exp(r) - r - 1 + + +def calculate_baseline_and_std_per_prompt( + prompts: torch.Tensor, + rewards: torch.Tensor, + valid_mask: torch.Tensor, + leave_one_out_baseline: bool = True, +): + """Function to compute a baseline for each (prompt, response) pair in the batch. + + The same baseline is calculated for each prompt. Samples set to 0 in 'valid_mask' + are not included in the baseline calculation. + + prompts: tensor (b, s) Tensor of prompts the model used. May be on any device + rewards: tensor (b,) Float-valued rewards. May be on any device + valid_mask: tensor (b,) Vector of 0/1, where 0 is to ignore and 1 is to keep + leave_one_out_baseline: bool Compute an unbiased baseline by leaving out the sample that + the baseline is for (from RLOO https://arxiv.org/abs/2402.14740) + + Returns: + tensor (b,) of baselines on the same device as 'rewards' + """ + unique_prompts = torch.unique(prompts, dim=0) + + baseline = torch.zeros_like(rewards) + sq_baseline = torch.zeros_like(rewards) + reward_device = rewards.get_device() + if reward_device == -1: + reward_device = torch.device("cpu") + + for i in range(len(unique_prompts)): + is_matching_prompt = (prompts == unique_prompts[i]).all(1) + prompt_idx = torch.arange(len(prompts), device=reward_device)[ + is_matching_prompt + ] + + if leave_one_out_baseline: + baseline_mask_matrix = (1 - torch.eye(len(prompt_idx))).to(reward_device) + else: + baseline_mask_matrix = torch.ones((len(prompt_idx), len(prompt_idx))).to( + reward_device + ) + + if valid_mask[prompt_idx].sum() <= 1: + # Ignore sample: there are no valid responses, so set baseline equal to reward + # to ignore it in the loss computation + baseline[prompt_idx] = rewards[prompt_idx] + else: + num_valid = valid_mask[prompt_idx].float().sum() - int( + leave_one_out_baseline + ) + prompt_baseline = ( + torch.matmul( + baseline_mask_matrix, rewards[prompt_idx] * valid_mask[prompt_idx] + ) + / num_valid + ) + prompt_baseline_square = ( + torch.matmul( + baseline_mask_matrix, + (rewards[prompt_idx] ** 2) * valid_mask[prompt_idx], + ) + / num_valid + ) + + baseline[prompt_idx] = prompt_baseline + sq_baseline[prompt_idx] = prompt_baseline_square + + std = (sq_baseline - baseline.square()).sqrt().nan_to_num(0) + return baseline, std + + +def surpress_user_warnings(f): + @wraps(f) + def wrapper(*args, **kwargs): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + output = f(*args, **kwargs) + return output + + return wrapper + + +# need to surpress the masked tensor warnings from pytorch +@surpress_user_warnings +def masked_mean(values, mask, dim=None): + """Masks values with mask, and computes the mean of the values using the masked values.""" + if dim is None: + return values[mask.bool()].mean() + return as_masked_tensor(values, mask.bool()).mean(dim=dim).to_tensor(torch.nan) diff --git a/nemo_reinforcer/converters/__init__.py b/nemo_reinforcer/converters/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_reinforcer/converters/huggingface/__init__.py b/nemo_reinforcer/converters/huggingface/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_reinforcer/converters/huggingface/vllm_export.py b/nemo_reinforcer/converters/huggingface/vllm_export.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nemo_reinforcer/converters/huggingface/vllm_export.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_reinforcer/converters/megatron/__init__.py b/nemo_reinforcer/converters/megatron/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_reinforcer/converters/megatron/vllm_export.py b/nemo_reinforcer/converters/megatron/vllm_export.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nemo_reinforcer/converters/megatron/vllm_export.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_reinforcer/data/__init__.py b/nemo_reinforcer/data/__init__.py new file mode 100644 index 0000000000..63aad516b2 --- /dev/null +++ b/nemo_reinforcer/data/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, TypedDict + + +class DataConfig(TypedDict): + max_input_seq_length: int + prompt_file: str + system_prompt_file: Optional[str] + dataset_name: str + val_dataset_name: Optional[str] diff --git a/nemo_reinforcer/data/datasets.py b/nemo_reinforcer/data/datasets.py new file mode 100644 index 0000000000..033cee05d7 --- /dev/null +++ b/nemo_reinforcer/data/datasets.py @@ -0,0 +1,132 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Union, Tuple + +import torch +from datasets import Dataset + +from nemo_reinforcer.data.interfaces import ( + TaskDataSpec, + TaskDataProcessFnCallable, + DatumSpec, +) +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict + + +# TODO @sahilj handle too-long prompts and masking them out throughout the whole process and renormalizing on loss +class AllTaskProcessedDataset: + """Dataset for processing single or multi-task data with task-specific tokenization and processing. + + Args: + dataset: Input dataset containing raw data + tokenizer: Tokenizer for text processing + default_task_data_spec: Default task processing specifications. + In the case of single-task, this is the spec used for processing all entries. + In the case of multi-task, any values not specified in the task-specific specs will be taken from the default spec. + task_data_processors: Either a single TaskDataProcessFnCallable for single-task, + or a dict mapping task names to (TaskDataSpec, TaskDataProcessFnCallable) for multi-task + max_seq_length: Maximum sequence length for tokenized outputs + """ + + def __init__( + self, + dataset: Union[Dataset, Any], + tokenizer, + default_task_data_spec: TaskDataSpec, + task_data_processors: Union[ + Dict[str, Tuple[TaskDataSpec, TaskDataProcessFnCallable]], + TaskDataProcessFnCallable, + ], + max_seq_length=None, + ): + self.dataset = dataset + self.tokenizer = tokenizer + self.default_task_data_spec = default_task_data_spec + self.task_data_processors = task_data_processors + self.max_seq_length = max_seq_length + + if isinstance(task_data_processors, dict): + # apply defaults to all task data specs + for task_name, ( + task_data_spec, + task_data_processor, + ) in task_data_processors.items(): + task_data_spec.copy_defaults(self.default_task_data_spec) + + def __len__(self): + return len(self.dataset) + + def encode_single(self, text: Union[str, List[str]]) -> Tuple[List[int], int]: + """Takes either a single string or a list of strings that represent multiple turns for the same conversation. + + Returns a single (concatenated) list of tokenized ids and the length of the tokenized ids. + """ + if isinstance(text, str): + text_ids = self.tokenizer.text_to_ids(text) + return text_ids, len(text_ids) + elif isinstance(text, list): + text_ids = [self.tokenizer.text_to_ids(t) for t in text] + return torch.cat(text_ids), sum(len(t) for t in text_ids) + else: + raise ValueError( + f"text must be a string or a list of strings, got {type(text)}" + ) + + def __getitem__(self, idx: int) -> DatumSpec: + """Return a single prompt.""" + entry = self.dataset[idx] + + if isinstance(self.task_data_processors, dict): + task_name = entry["task_name"] + + assert task_name in self.task_data_processors, ( + f"task processor not provided for {task_name}. Provided processors: {self.task_data_processors.keys()}" + ) + task_data_spec, task_data_processor = self.task_data_processors[task_name] + else: + task_data_spec = self.default_task_data_spec + task_data_processor = self.task_data_processors + + datum_spec = task_data_processor( + entry, task_data_spec, self.tokenizer, self.max_seq_length, idx + ) + return datum_spec + + +def rl_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: + """Collate function for RL training.""" + message_log = [datum_spec["message_log"] for datum_spec in data_batch] + length = torch.tensor([datum_spec["length"] for datum_spec in data_batch]) + loss_multiplier = torch.tensor( + [datum_spec["loss_multiplier"] for datum_spec in data_batch] + ) + extra_env_info = [datum_spec["extra_env_info"] for datum_spec in data_batch] + + task_names = [] + for datum_spec in data_batch: + task_names.append(datum_spec.get("task_name", None)) + + idx = [datum_spec["idx"] for datum_spec in data_batch] + batch_max_length = torch.ones_like(length) * length.max() + + output = BatchedDataDict( + message_log=message_log, + length=length, + loss_multiplier=loss_multiplier, + extra_env_info=extra_env_info, + task_name=task_names, + idx=idx, + batch_max_length=batch_max_length, + ) + return output diff --git a/nemo_reinforcer/data/hf_datasets/__init__.py b/nemo_reinforcer/data/hf_datasets/__init__.py new file mode 100644 index 0000000000..df1227140d --- /dev/null +++ b/nemo_reinforcer/data/hf_datasets/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_reinforcer.data.hf_datasets.oasst import OasstDataset +from nemo_reinforcer.data.hf_datasets.squad import SquadDataset + +__all__ = ["OasstDataset", "SquadDataset"] diff --git a/nemo_reinforcer/data/hf_datasets/interfaces.py b/nemo_reinforcer/data/hf_datasets/interfaces.py new file mode 100644 index 0000000000..63be96c5a7 --- /dev/null +++ b/nemo_reinforcer/data/hf_datasets/interfaces.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Any, Optional +from nemo_reinforcer.data.interfaces import TaskDataSpec + + +class COMMON_CHAT_TEMPLATES: + ### simple template which prepends a role header to the content + simple_role_header = "{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" + + +class HfDataset: + """Interface for HuggingFace datasets.""" + + formatted_ds: Dict[str, Any] + + def __init__( + self, + dataset_name: str, + custom_template: Optional[ + str + ] = None, ## "None" means use HuggingFace's tokenizer's template + ): + self.task_spec = TaskDataSpec( + task_name=dataset_name, + custom_template=custom_template, + ) diff --git a/nemo_reinforcer/data/hf_datasets/oasst.py b/nemo_reinforcer/data/hf_datasets/oasst.py new file mode 100644 index 0000000000..4525c0fa92 --- /dev/null +++ b/nemo_reinforcer/data/hf_datasets/oasst.py @@ -0,0 +1,128 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import gzip +import os +import random +import requests +import copy +from dataclasses import dataclass +from typing import Optional +from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset, COMMON_CHAT_TEMPLATES + +SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n" + + +def parse_conversations(tree_obj, first=False): + """Recusive function that returns all the sub converstaions in a list starting from node tree_obj. + + Args: + tree_obj (obj): current conversation node + + Returns: + a list of sub conversation threads including the current conversation node + """ + turns = [] + if first: + turn = {"content": SYSTEM_PROMPT, "role": "system"} + turns.append(turn) + + if "prompt" in tree_obj: + prompt_obj = tree_obj["prompt"] + elif "text" in tree_obj and "role" in tree_obj: + prompt_obj = tree_obj + else: + return [[]] + if prompt_obj["role"] == "prompter": + role = "user" + elif prompt_obj["role"] == "assistant": + role = "assistant" + else: + raise ValueError(f"unknown role {prompt_obj['role']}") + turn = {"content": prompt_obj["text"], "role": role} + turns.append(turn) + + all_conversations = [] + multiple_sub_threads = [] + for next_obj in prompt_obj["replies"]: + multiple_threads = parse_conversations(next_obj) + multiple_sub_threads.extend(multiple_threads) + if len(multiple_sub_threads) != 0: + for sub_thread in multiple_sub_threads: + all_conversations.append(copy.deepcopy(turns) + sub_thread) + else: + all_conversations.append(copy.deepcopy(turns)) + return all_conversations + + +def get_data_records(objs): + ## TODO: old format was multi-conversation per example, but ours is single conversation + ## is this just because of the input data format? + output = [] + for obj in objs: + multi_conversations = parse_conversations(obj, first=True) + for conversations in multi_conversations: + if len(conversations) <= 2: + # remove single turn conversations + ## system prompt is always first turn + continue + + conversation_obj = { + "messages": conversations, + } + output.append(conversation_obj) + return output + + +def download_and_process_oasst(output_directory=".", seed=42, split_ratio=0.95): + os.makedirs(output_directory, exist_ok=True) + filename = f"{output_directory}/2023-04-12_oasst_all.trees.jsonl.gz" + + # only download if doesn't exist + if not os.path.isfile(filename): + url = "https://huggingface.co/datasets/OpenAssistant/oasst1/resolve/main/2023-04-12_oasst_all.trees.jsonl.gz" + response = requests.get(url) + with open(filename, mode="wb") as fw: + fw.write(response.content) + + with gzip.open(filename) as f: + file_content = f.readlines() + + all_objs = [json.loads(dp.decode("utf-8")) for dp in file_content] + + random.seed(seed) + random.shuffle(all_objs) + train_num = int(len(all_objs) * split_ratio) + train_objs = all_objs[:train_num] + val_objs = all_objs[train_num:] + train_records = get_data_records(train_objs) + val_records = get_data_records(val_objs) + + formatted_ds = { + "train": train_records, + "validation": val_records, + } + + return formatted_ds + + +@dataclass +class OasstDataset(HfDataset): + def __init__(self, output_dir: str = "."): + self.formatted_ds = download_and_process_oasst(output_dir) + super().__init__( + dataset_name="oasst", + custom_template=COMMON_CHAT_TEMPLATES.simple_role_header, + ) diff --git a/nemo_reinforcer/data/hf_datasets/squad.py b/nemo_reinforcer/data/hf_datasets/squad.py new file mode 100644 index 0000000000..a1378761d1 --- /dev/null +++ b/nemo_reinforcer/data/hf_datasets/squad.py @@ -0,0 +1,49 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +from datasets import load_dataset +from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset + + +def format_squad(data): + return { + "messages": [ + { + "role": "system", + "content": data["context"], + }, + { + "role": "user", + "content": data["question"], + }, + { + "role": "assistant", + "content": data["answers"]["text"][0], + }, + ] + } + + +class SquadDataset(HfDataset): + def __init__(self): + original_ds = load_dataset("rajpurkar/squad") + self.formatted_ds = original_ds.map(format_squad) + + custom_template = "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer: '}}{%- elif message['role'] == 'assistant' %}{{message['content'].strip()}}{%- endif %}{% endfor %}" + + super().__init__( + dataset_name="squad", + custom_template=custom_template, + ) diff --git a/nemo_reinforcer/data/interfaces.py b/nemo_reinforcer/data/interfaces.py new file mode 100644 index 0000000000..ade6f145e8 --- /dev/null +++ b/nemo_reinforcer/data/interfaces.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, TypedDict, Optional, Union, Protocol +import os + +import torch + +# OpenAI-API-like message log, but every messsage may contain associated tensors (i.e. tokenized strings and logprobs) in addition to the original "content" string +LLMMessageLogType = List[Dict[str, Union[str, torch.Tensor]]] + +# Flattened message log where all tensors and data are concatenated together for a conversation +# Converts a conversation from list-of-turns format to key-value format with concatenated tensors +FlatMessagesType = Dict[str, Union[List[str], torch.Tensor]] + + +class DatumSpec(TypedDict): + message_log: LLMMessageLogType + length: int # total (concatenated) length of the message tensors + extra_env_info: Dict[str, Any] + loss_multiplier: float # multiplier for the loss for this datum. 0 to mask out (say the sample is invalid) + idx: int + task_name: Optional[str] = "default" + __extra__: Any # This allows additional fields of any type + + +@dataclass +class TaskDataSpec: + task_name: Optional[str] = None + # prompt + prompt_file: Optional[os.PathLike] = None + + system_prompt_file: Optional[Union[str, os.PathLike]] = None + custom_template: Optional[Union[str, os.PathLike]] = None + + def __post_init__(self): + def load_prompt_file( + prompt_file: Optional[os.PathLike], + ) -> Optional[str]: + """Load prompt from file if it exists, otherwise return as is.""" + if prompt_file is None: + return None + if os.path.exists(prompt_file): + with open(prompt_file, "r", encoding="utf-8") as f: + return f.read() + else: + raise FileNotFoundError(f"Prompt file {prompt_file} not found") + + # Load prompts from files if they exist + self.system_prompt = load_prompt_file(self.system_prompt_file) + self.prompt = load_prompt_file(self.prompt_file) + + def copy_defaults(self, from_spec: "TaskDataSpec"): + """Apply default values from another Task instance for any None attributes.""" + default_attrs = { + "system_prompt": from_spec.system_prompt, + "prompt": from_spec.prompt, + "custom_template": from_spec.custom_template, + } + + for attr_name, default_value in default_attrs.items(): + if getattr(self, attr_name) is None: + setattr(self, attr_name, default_value) + + +class TaskDataProcessFnCallable(Protocol): + """A callable that processes a loaded datum dictionary into a DatumSpec.""" + + def __call__( + self, + datum_dict: Dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer, + max_seq_length: int, + idx: int, + ) -> DatumSpec: + raise NotImplementedError("Task data process not implemented") diff --git a/nemo_reinforcer/data/llm_message_utils.py b/nemo_reinforcer/data/llm_message_utils.py new file mode 100644 index 0000000000..43e24fc1ce --- /dev/null +++ b/nemo_reinforcer/data/llm_message_utils.py @@ -0,0 +1,392 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Union + + +import torch + +from nemo_reinforcer.data.interfaces import ( + LLMMessageLogType, + FlatMessagesType, + TaskDataSpec, +) +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict + + +def message_log_to_flat_messages( + message_log: LLMMessageLogType, +) -> FlatMessagesType: + """Converts a message log (sequence of message turns) into a flattened representation. + + This function takes a message log (list of dict messages with 'role', 'content', 'token_ids', etc.) + and converts it to a flat dictionary where all tensors of the same key are concatenated and + all strings of the same key are put into lists. + + Args: + message_log: List of message dictionaries with 'role', 'content', and potentially 'token_ids' + + Returns: + FlatMessagesType: Dictionary mapping keys to concatenated tensors and string lists + + Examples: + ```{doctest} + >>> import torch + >>> from nemo_reinforcer.data.llm_message_utils import message_log_to_flat_messages + >>> # Create a simple message log with two messages + >>> message_log = [ + ... {'role': 'user', 'content': 'Hello', 'token_ids': torch.tensor([1, 2, 3])}, + ... {'role': 'assistant', 'content': 'Hi there', 'token_ids': torch.tensor([4, 5, 6, 7])} + ... ] + >>> flat_msgs = message_log_to_flat_messages(message_log) + >>> flat_msgs['role'] + ['user', 'assistant'] + >>> flat_msgs['content'] + ['Hello', 'Hi there'] + >>> flat_msgs['token_ids'] + tensor([1, 2, 3, 4, 5, 6, 7]) + ``` + """ + result = {} + + if len(message_log) == 0: + return result + + # Get all unique keys across all messages + all_keys = set() + for msg in message_log: + all_keys.update(msg.keys()) + + # Initialize result with empty lists for each key + for key in all_keys: + result[key] = [] + + # Collect values for each key + for msg in message_log: + for key in all_keys: + if key in msg: + result[key].append(msg[key]) + + # Concatenate tensors for each key + for key in result: + if result[key] and isinstance(result[key][0], torch.Tensor): + try: + result[key] = torch.cat(result[key]) + except RuntimeError as e: + if "same number of dimensions" in str(e): + raise RuntimeError( + f"tensors for {key=} must have same number of dimensions: {[t.shape for t in result[key]]}" + ) from e + raise + + return result + + +def get_keys_from_message_log( + message_log: LLMMessageLogType, keys: List[str] +) -> LLMMessageLogType: + """Return a new LLMMessageLogType containing only the specified keys from each message. + + Args: + keys: List of keys to keep in each message + + Returns: + LLMMessageLogType: New list with only specified keys + """ + return [{k: msg[k] for k in keys if k in msg} for msg in message_log] + + +def add_loss_mask_to_message_log( + message_log: LLMMessageLogType, + roles_to_train_on: List[str] = ["assistant"], +) -> None: + """Add token-level loss masks to each message in a message log. + + Args: + message_log (LLMMessageLogType): List of message dictionaries containing token IDs and metadata + roles_to_train_on (List[str]): List of strings indicating which speakers to unmask. Default: ["assistant"] + """ + for i, role in enumerate(roles_to_train_on): + roles_to_train_on[i] = role.lower() + + for message in message_log: + for sentence in message: + if sentence["role"] in roles_to_train_on: + sentence["token_loss_mask"] = torch.ones_like(sentence["token_ids"]) + else: + sentence["token_loss_mask"] = torch.zeros_like(sentence["token_ids"]) + + +def _pad_tensor( + tensor: torch.Tensor, + max_len: int, + pad_side: str, + pad_value: int = 0, +) -> torch.Tensor: + """Pad a tensor to the specified length. + + Args: + tensor: Tensor to pad + max_len: Length to pad to + pad_side: Whether to pad on the 'left' or 'right' + pad_value: Value to use for padding + + Returns: + torch.Tensor: Padded tensor + """ + pad_len = max_len - tensor.size(0) + if pad_len <= 0: + return tensor + + padding = torch.full( + (pad_len, *tensor.shape[1:]), + pad_value, + dtype=tensor.dtype, + device=tensor.device, + ) + return torch.cat( + [padding, tensor] if pad_side == "left" else [tensor, padding], dim=0 + ) + + +def _validate_tensor_consistency(tensors: List[torch.Tensor]) -> None: + """Validate that all tensors have consistent dtypes and devices. + + Args: + tensors: List of tensors to validate + + Raises: + RuntimeError: If tensors have different dtypes or devices + """ + if not tensors: + return + + first = tensors[0] + if not all(t is None or t.dtype == first.dtype for t in tensors): + raise RuntimeError( + f"expected consistent types but got: {[t.dtype for t in tensors]}" + ) + if not all(t is None or t.device == first.device for t in tensors): + raise RuntimeError( + f"expected tensors on the same device but got: {[t.device for t in tensors]}" + ) + + +def batched_message_log_to_flat_message( + message_log_batch: List[LLMMessageLogType], + pad_value_dict: Dict[str, int] = None, +) -> tuple[BatchedDataDict[FlatMessagesType], torch.Tensor]: + """Process and pad a batch of message logs for model input. + + For each message log in the batch: + 1. Converts it to a flat representation using message_log_to_flat_messages + 2. Pads all resulting tensors to the same length for batching + 3. Returns a BatchedDataDict and sequence lengths tensor + + Padding is always applied to the right side of sequences. + + Args: + message_log_batch: List of LLMMessageLogType (each a conversation with multiple turns) + pad_value_dict: Dictionary mapping keys to padding values (default is 0) + + Returns: + BatchedDataDict[FlatMessagesType]: Dictionary containing padded stacked tensors + torch.Tensor: Input lengths tensor with shape [batch_size] (pre-padding lengths) + + Raises: + RuntimeError: If tensors have different dtypes or devices + + Examples: + ```{doctest} + >>> import torch + >>> from nemo_reinforcer.data.llm_message_utils import batched_message_log_to_flat_message + >>> from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict + >>> # Create a batch of two message logs with different lengths + >>> message_log_batch = [ + ... # First conversation + ... [ + ... {'role': 'user', 'content': 'What is 2+2?', 'token_ids': torch.tensor([1, 2, 3, 4, 5])}, + ... {'role': 'assistant', 'content': '4', 'token_ids': torch.tensor([6, 7])} + ... ], + ... # Second conversation + ... [ + ... {'role': 'user', 'content': 'Solve x+10=15', 'token_ids': torch.tensor([1, 8, 9, 10, 11, 12])}, + ... {'role': 'assistant', 'content': 'x=5', 'token_ids': torch.tensor([13, 14, 15])} + ... ] + ... ] + >>> pad_value_dict = {'token_ids': 0} + >>> batched_flat, input_lengths = batched_message_log_to_flat_message(message_log_batch, pad_value_dict) + >>> batched_flat['token_ids'][0].tolist() + [1, 2, 3, 4, 5, 6, 7, 0, 0] + >>> batched_flat['token_ids'][1].tolist() + [1, 8, 9, 10, 11, 12, 13, 14, 15] + >>> batched_flat['content'][0] + ['What is 2+2?', '4'] + >>> batched_flat['content'][1] + ['Solve x+10=15', 'x=5'] + >>> batched_flat['role'] + [['user', 'assistant'], ['user', 'assistant']] + >>> input_lengths + tensor([7, 9], dtype=torch.int32) + >>> + ``` + """ + if not message_log_batch: + return BatchedDataDict(), torch.empty(0) + + # Process each message log into a flat representation + sequenced_lists = [message_log_to_flat_messages(ml) for ml in message_log_batch] + all_keys = {k for seq in sequenced_lists for k in seq} + + # Find max length and identify tensor keys + max_len = 0 + tensor_keys = [] + for seq in sequenced_lists: + for key, value in seq.items(): + if isinstance(value, torch.Tensor): + tensor_keys.append(key) + max_len = max(max_len, value.size(0)) + + # Handle non-tensor case + if not tensor_keys: + result = BatchedDataDict( + { + k: [seq[k][0] if k in seq else None for seq in sequenced_lists] + for k in all_keys + } + ) + return result, torch.empty(0) + + # Create input_lengths tensor + input_lengths = [] + for seq in sequenced_lists: + seq_len = next( + (v.size(0) for v in seq.values() if isinstance(v, torch.Tensor)), 0 + ) + input_lengths.append(seq_len) + input_lengths_tensor = torch.tensor(input_lengths, dtype=torch.int32) + + # Process each key + result = BatchedDataDict() + for key in all_keys: + values = [seq.get(key) for seq in sequenced_lists] + if not values or not isinstance(values[0], torch.Tensor): + result[key] = values + continue + + # Filter out None values and validate consistency + tensors = [t for t in values if t is not None] + _validate_tensor_consistency(tensors) + + # Create zero tensors for None values + values = [ + ( + torch.zeros(0, dtype=tensors[0].dtype, device=tensors[0].device) + if v is None + else v + ) + for v in values + ] + + # Pad and stack tensors (always right padding) + pad_value = pad_value_dict.get(key, 0) if pad_value_dict else 0 + padded = [_pad_tensor(t, max_len, "right", pad_value) for t in values] + result[key] = torch.stack(padded) + + return result, input_lengths_tensor + + +def message_log_shape(message_log: LLMMessageLogType) -> List[Dict[str, List[int]]]: + """Get the shape of the tensors in the message log. + + This utility function examines each message in the message log and reports + the shape of tensor values or recursively processes list values. + + Args: + message_log: The message log to analyze + + Returns: + List of dictionaries containing tensor shapes for each key in messages + """ + shapes = [] + for message in message_log: + shape = {} + for k in message.keys(): + if isinstance(message[k], torch.Tensor): + shape[k] = message[k].shape + elif isinstance(message[k], list): + shape[k] = [message_log_shape(v) for v in message[k]] + shapes.append(shape) + return shapes + + +def get_first_index_that_differs(str1, str2): + """Get the first index that differs between two strings.""" + for i, (c1, c2) in enumerate(zip(str1, str2)): + if c1 != c2: + return i + return min(len(str1), len(str2)) + + +def get_formatted_message_log( + message_log: LLMMessageLogType, + tokenizer, + task_data_spec: TaskDataSpec, +) -> LLMMessageLogType: + """Format and tokenize chat messages using the specified template. + + Args: + message_log: List of message dicts with 'role' and 'content' keys + tokenizer: Tokenizer for converting text to token IDs + task_data_spec: Task spec for this dataset. + + Returns: + The message log with updated 'token_ids' and 'content' fields. + """ + cu_message = [] + prev_formatted_message = "" + template = task_data_spec.custom_template + + for i, message in enumerate(message_log): + cu_message.append(message.copy()) + formatted_message = tokenizer.apply_chat_template( + cu_message, + chat_template=template, + add_generation_prompt=False, + tokenize=False, + add_special_tokens=False, + ) + + ## get the length of the previous message, excluding the eos token (if present) + prev_message_len_no_eos = get_first_index_that_differs( + prev_formatted_message, + formatted_message, + ) + + ## pull out the chunk corresponding to the current message + message_chunk = formatted_message[prev_message_len_no_eos:] + + if i == 0 and not message_chunk.startswith(tokenizer.bos_token): + message_chunk = tokenizer.bos_token + message_chunk + + if i == len(message_log) - 1: + message_chunk = message_chunk.rstrip("\n") + if not message_chunk.endswith(tokenizer.eos_token): + message_chunk += tokenizer.eos_token + message["token_ids"] = tokenizer( + message_chunk, return_tensors="pt", add_special_tokens=False + )["input_ids"][0] + message["content"] = message_chunk + prev_formatted_message = formatted_message + + return message_log diff --git a/nemo_reinforcer/distributed/__init__.py b/nemo_reinforcer/distributed/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_reinforcer/distributed/batched_data_dict.py b/nemo_reinforcer/distributed/batched_data_dict.py new file mode 100644 index 0000000000..a1711ae8c2 --- /dev/null +++ b/nemo_reinforcer/distributed/batched_data_dict.py @@ -0,0 +1,297 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from copy import deepcopy +from collections import UserDict +from typing import List, Dict, Optional, Iterator, TypeVar, Any, Generic +from typing_extensions import Self + +import torch + +from nemo_reinforcer.distributed.collectives import ( + rebalance_nd_tensor, + gather_jagged_object_lists, +) + +DictT = TypeVar("DictT", bound=Dict[str, Any]) + + +class BatchedDataDict(UserDict, Generic[DictT]): + @classmethod + def from_batches( + cls: Self, + batches: List[Dict], + pad_value_dict: Optional[Dict[str, int]] = None, + ) -> Self: + """Given a list of batches, stack the tensors/lists within and put them in a single dictionary. + + Pad sequences to the max length in the batch using either 0(default) or a non-default value for a given key provided in pad_value_dict. + + Args: + batches (List[Dict]): A list of dictionaries, each containing a batch of data. + pad_value_dict (Optional[Dict[str, int]]): An optional dict mapping keys to non-default(0) padding values. + + Returns: + BatchedDataDict: A new BatchedDataDict containing the stacked data. + """ + stacked_dict = cls() + pad_value_dict = pad_value_dict or {} + + for k in sorted(batches[0]): + list_of_tensors = [item[k] for item in batches] + + if isinstance(list_of_tensors[0], list): + tensor = [item for sublist in list_of_tensors for item in sublist] + elif all(x.ndim == 1 for x in list_of_tensors): + tensor = torch.cat(list_of_tensors) + elif isinstance(list_of_tensors[0], torch.Tensor): + pad_value = pad_value_dict.get(k, 0) + + list_of_tensors = [ + row.flatten() for tensor in list_of_tensors for row in tensor + ] + # TODO: can we avoid padding locally then padding globally? + tensor = torch.nn.utils.rnn.pad_sequence( + list_of_tensors, batch_first=True, padding_value=pad_value + ) + else: + raise NotImplementedError( + ( + f"Attempted to stack for unsupported type {type(list_of_tensors[0])} with key {k}." + "Please provide either a tensor or a list of picklable objects." + ) + ) + stacked_dict[k] = tensor + + return stacked_dict + + def all_gather(self, group: torch.distributed.ProcessGroup) -> "BatchedDataDict": + """Gathers batches with possibly jagged leading dimensions across the DP ranks. + + If using reshard, it will treat PP as DP ranks. + Works with data that is either tensors or string lists. + """ + global_rollout_batch = type(self)() + + for k, value in self.data.items(): + if isinstance(value, torch.Tensor): + value = rebalance_nd_tensor(value, group=group) + global_rollout_batch[k] = value + elif isinstance(value, list): + value = gather_jagged_object_lists(value, group=group) + global_rollout_batch[k] = value + else: + raise NotImplementedError( + ( + f"Attempted to gather_and_balance_globally for unsupported type {type(value)} with key {k}." + "Please provide either a tensor or a list of picklable objects." + ) + ) + + return global_rollout_batch + + def chunk(self, rank: int, chunks: int) -> "SlicedDataDict": + """Chunks a global batch into 'chunks' splits and returns the 'rank'th split batch=[A A A B B B D D E], rank=2, chunks=3 -> [D D E]. + + Requires all leading dimensions of tensors and lengths of lists to be the same over the batch + and the chunks must divide batch size. + """ + chunked_batch = SlicedDataDict() + + batch_set = set() + for val in self.data.values(): + if isinstance(val, torch.Tensor): + batch_set.add(val.size(0)) + else: + batch_set.add(len(val)) + + assert len(batch_set) == 1, ( + "batch sizes are not the same across the rollout batch" + ) + B = batch_set.pop() + assert B % chunks == 0, ( + f"batch size ({B}) is not a multiple of chunks ({chunks})" + ) + assert B // chunks > rank, ( + f"index OOB: not enough splits for this rank. rollout_batch_size: {B}, chunks ({chunks}), rank_idx ({rank})" + ) + + indices = torch.arange(B).tensor_split(chunks)[rank] + + for k in self.data: + if torch.is_tensor(self.data[k]): + chunked_batch[k] = self.data[k][indices].clone() + else: + chunked_batch[k] = [self.data[k][i] for i in indices] + + return chunked_batch + + def shard_by_batch_size( + self, shards: int, batch_size: Optional[int] = None + ) -> List["SlicedDataDict"]: + """Shards a batch by first dividing it into chunks of size batch_size, then further dividing each chunk into shards equal parts. Finally aggregates the sub-shards by their position. + + If batch_size is None, there will be no chunking beforehand (will default to the total batch size). + + For example, with data [A A B B C C D D], batch_size=2, shards=2: + - Element 0: [A B C D] (first elements from each chunk) + - Element 1: [A B C D] (second elements from each chunk) + + Args: + shards (int): The number of shards to divide each batch_size chunk into. + batch_size (int): The size of each initial chunk. + + Returns: + List[BatchedDataDict]: A list of BatchedDataDicts, length equal to shards. + """ + # Get the total batch size + batch_sizes = set() + for val in self.data.values(): + if isinstance(val, torch.Tensor): + batch_sizes.add(val.size(0)) + else: + batch_sizes.add(len(val)) + + assert len(batch_sizes) == 1, ( + "Batch sizes are not the same across the rollout batch" + ) + total_batch_size = batch_sizes.pop() + if batch_size is None: + batch_size = total_batch_size + + # Validate that our batch size parameters are compatible with the data dimensions + assert total_batch_size % batch_size == 0, ( + f"Total batch size ({total_batch_size}) is not a multiple of batch_size ({batch_size})" + ) + assert batch_size % shards == 0, ( + f"Batch size ({batch_size}) is not a multiple of shards ({shards})" + ) + + num_chunks = total_batch_size // batch_size + shard_size = batch_size // shards + # Create one BatchedDataDict per shard position + aggregated_shards = [SlicedDataDict() for _ in range(shards)] + + # Group data by shard position across all chunks + for shard_idx in range(shards): + for chunk_idx in range(num_chunks): + # Calculate indices for this particular sub-shard within the chunk + chunk_start = chunk_idx * batch_size + shard_start = chunk_start + shard_idx * shard_size + shard_end = chunk_start + (shard_idx + 1) * shard_size + indices = torch.arange(shard_start, shard_end) + + for k in self.data: + if k not in aggregated_shards[shard_idx]: + # First time seeing this key for this shard, initialize it + if torch.is_tensor(self.data[k]): + aggregated_shards[shard_idx][k] = self.data[k][ + indices + ].clone() + else: + aggregated_shards[shard_idx][k] = [ + self.data[k][i] for i in indices + ] + else: + # Append to existing data - concatenate tensors or extend lists + if torch.is_tensor(self.data[k]): + aggregated_shards[shard_idx][k] = torch.cat( + [ + aggregated_shards[shard_idx][k], + self.data[k][indices].clone(), + ] + ) + else: + aggregated_shards[shard_idx][k].extend( + [self.data[k][i] for i in indices] + ) + + return aggregated_shards + + def slice(self, start: int, end: int) -> "SlicedDataDict": + """Slices the batch from start to end. + + Args: + start: Starting index (inclusive) + end: Ending index (exclusive) + + Returns: + BatchedDataDict: A new BatchedDataDict containing the sliced data + """ + sliced_batch = SlicedDataDict() + for k in self.data: + sliced_batch[k] = self.data[k][start:end] + return sliced_batch + + def repeat_interleave(self, num_repeats: int) -> "BatchedDataDict": + """Repeats the batch num_repeats times. + + For each element in the batch, repeat each value num_repeats times. + i.e: + {"key": torch.tensor([1, 2, 3]), "other_key": [1, 2, 3]} -> {"key": torch.tensor([1, 1, 2, 2, 3, 3]), "other_key": [1, 1, 2, 2, 3, 3]} + """ + repeated_batch = BatchedDataDict() + for k, v in self.data.items(): + if torch.is_tensor(v): + # For tensors, use repeat_interleave to repeat each element + repeated_batch[k] = v.repeat_interleave(num_repeats, dim=0) + else: + # For lists or other sequences, use a list comprehension to repeat each element + repeated_batch[k] = [ + deepcopy(item) for item in v for _ in range(num_repeats) + ] + return repeated_batch + + def make_microbatch_iterator( + self, microbatch_size: int + ) -> Iterator["SlicedDataDict"]: + """Make an iterator over the batch that yields microbatches of size microbatch_size.""" + bsize = self.size + assert bsize % microbatch_size == 0, ( + f"Data dict size ({bsize}) is not a multiple of the provided microbatch size ({microbatch_size})" + ) + for i in range(0, bsize, microbatch_size): + yield self.slice(i, i + microbatch_size) + + @property + def size(self) -> int: + """Get the batch size of the batch.""" + # Get the first key and use its size as the batch size + # This assumes all keys have the same batch size + key = next(iter(self.data)) + if not self.data: + return 0 + if not torch.is_tensor(self.data[key]): + return len(self.data[key]) + return self.data[key].shape[0] + + def to(self, device: torch.device) -> "BatchedDataDict": + """Move all tensors in the batch to a specific device.""" + for k in self.data: + if torch.is_tensor(self.data[k]): + self.data[k] = self.data[k].to(device) + return self + + def get_dict(self) -> dict: + """Get the underlying data dictionary.""" + return self.data + + +class SlicedDataDict(BatchedDataDict): + """A specialized subclass of BatchedDataDict that represents a slice or shard of a larger batch. + + This class provides a distinct type to differentiate between full batches and sliced/sharded batches, which can be helpful for + type checking. + """ + + pass diff --git a/nemo_reinforcer/distributed/collectives.py b/nemo_reinforcer/distributed/collectives.py new file mode 100644 index 0000000000..fddf3fe881 --- /dev/null +++ b/nemo_reinforcer/distributed/collectives.py @@ -0,0 +1,96 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch + + +def rebalance_nd_tensor( + tensor: torch.Tensor, group: Optional[torch.distributed.ProcessGroup] = None +): + """Takes tensors with variable leading sizes (at dim=0) and stacks them into a single tensor. + + This function handles the case where different GPUs have tensors with different batch sizes + and combines them into a single balanced tensor across all ranks. + + For example, with 3 GPUs: + GPU0: tensor of shape [3, D] + GPU1: tensor of shape [5, D] + GPU2: tensor of shape [2, D] + + After rebalancing: + All GPUs will have the same tensor of shape [10, D] (3+5+2=10) + + NOTE: assumes all other (i.e., non-zero) dimensions are equal. + """ + num_samples = torch.as_tensor( + tensor.size(0), dtype=torch.int64, device=torch.cuda.current_device() + ) + batch_num_per_rank = torch.zeros( + torch.distributed.get_world_size(group), + dtype=torch.int64, + device=torch.cuda.current_device(), + ) + torch.distributed.all_gather_into_tensor( + batch_num_per_rank, num_samples, group=group + ) + + B = batch_num_per_rank.sum() + other_dims = tensor.shape[1:] + + indices = batch_num_per_rank.cumsum(dim=0) + output_tensor = torch.zeros( + B, *other_dims, dtype=tensor.dtype, device=torch.cuda.current_device() + ) + + # tensor_split is a view we can copy into + output_tensor.tensor_split(indices[0:-1].cpu())[ + torch.distributed.get_rank(group=group) + ].copy_(tensor) + torch.distributed.all_reduce(output_tensor, group=group) + return output_tensor + + +def gather_jagged_object_lists( + local_objects: list, group: Optional[torch.distributed.ProcessGroup] = None +): + """Gathers jagged lists of picklable objects from all ranks and flattens them into a single list. + + This function handles the case where different GPUs have lists of different lengths + and combines them into a single list containing all objects from all ranks. + + For example, with 3 GPUs: + GPU0: [obj0, obj1] + GPU1: [obj2, obj3, obj4] + GPU2: [obj5] + + After gathering: + All GPUs will have: [obj0, obj1, obj2, obj3, obj4, obj5] + + WARNING: synchronous + + Args: + local_objects: List of objects to gather from current rank + group: Optional process group + + Returns: + Flattened list of all objects from all ranks in order [rank0, rank1, ...] + """ + # Gather all lists across ranks + world_size = torch.distributed.get_world_size(group=group) + gathered_lists = [None] * world_size + torch.distributed.all_gather_object(gathered_lists, local_objects, group=group) + + # Flatten into single list while preserving order + return [obj for sublist in gathered_lists for obj in sublist] diff --git a/nemo_reinforcer/distributed/virtual_cluster.py b/nemo_reinforcer/distributed/virtual_cluster.py new file mode 100644 index 0000000000..3c45697345 --- /dev/null +++ b/nemo_reinforcer/distributed/virtual_cluster.py @@ -0,0 +1,499 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, TypedDict, Optional + +from copy import deepcopy +import os +import ray +import logging +from ray.util.placement_group import placement_group, remove_placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class ClusterConfig(TypedDict): + gpus_per_node: int + num_nodes: int + + +# Get the directory path of the current module and the root of the package +dir_path = os.path.dirname(os.path.abspath(__file__)) +git_root = os.path.abspath(os.path.join(dir_path, "../..")) + +UV_CACHE_DIR = os.environ.get("UV_CACHE_DIR", None) +uv_cache_flag = f"--cache-dir {UV_CACHE_DIR}" if UV_CACHE_DIR else "" + +if "VIRTUAL_ENV" not in os.environ: + raise EnvironmentError( + "VIRTUAL_ENV environment variable not found. This variable is required and can be set by:\n" + "1. (HIGHLY RECOMMENDED) Running with 'uv run'\n" + "2. Activating a virtual environment with 'source .venv/bin/activate'\n" + "3. Setting it manually (e.g., export VIRTUAL_ENV=/path/to/venv)\n\n" + "If set manually, we will look for the Python binary at $VIRTUAL_ENV/bin/python" + ) + + +# --with-editable .: speeds up the install slightly since editable installs don't require full copies +# --cache-dir $UV_CACHE_DIR: caching isn't propagated by default. This will set it if the user has set it. +class PY_EXECUTABLES: + # This uses the .venv created by `uv`. This is the fastest option, but provides no isolation between workers. + DEFAULT_VENV = f"{os.environ['VIRTUAL_ENV']}/bin/python" + + # TODO: Debug why run-to-run variance is so high with these options + # Use NeMo-Reinforcer direct dependencies and nothing from system + DEFAULT = f"uv run --isolated --with-editable . {uv_cache_flag}" + # Use none of NeMo-Reinforcer's dependencies or the system. Useful for workers that only need standard python packages. + BARE_BONES = f"uv run --isolated --no-project --with-editable . {uv_cache_flag}" + + +@ray.remote +def _get_node_ip_and_free_port(): + import socket + + # Get the IP address of the current node + # Use socket.gethostbyname(socket.gethostname()) as a fallback + node_ip = socket.gethostbyname(socket.gethostname()) + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # Bind to port 0 to get a random free port + s.listen(1) + port = s.getsockname()[1] + return node_ip, port + + +def init_ray(log_dir: Optional[str] = None): + """Initialize Ray and connect to an existing Ray cluster or fall back and start a local one. Should be called before any ray API is called. + + This function: + 1. Gathers common environment variables needed for distributed training + 2. Sets up the working directory and Python executable + 3. Connects to an existing Ray cluster + """ + if "UV_CACHE_DIR" not in os.environ: + logging.warning("UV_CACHE_DIR is not set, using default cache dir") + + # Set up runtime environment + runtime_env = { + "env_vars": dict(os.environ), # Pass thru all user environment variables + "working_dir": git_root, + "py_executable": PY_EXECUTABLES.DEFAULT_VENV, + } + + # Initialize Ray connection + try: + # Try to connect to an existing cluster first. + ray.init( + address="auto", + log_to_driver=True, + include_dashboard=False, + runtime_env=runtime_env, + _temp_dir=os.path.abspath(log_dir) if log_dir else None, + ) + logger.info(f"Connected to existing Ray cluster: {ray.cluster_resources()}") + except ConnectionError: + # If no existing cluster, start a new one with local resources + ray.init( + log_to_driver=True, + include_dashboard=False, + runtime_env=runtime_env, + _temp_dir=os.path.abspath(log_dir) if log_dir else None, + ) + logger.info(f"Started local cluster with: {ray.cluster_resources()}") + + +class RayVirtualCluster: + """Creates a virtual distributed cluster using Ray placement groups. + + This class simplifies distributed training setup by: + - Creating placement groups that represent logical compute nodes + - Allocating GPU and CPU resources for distributed workers + - Managing communication between distributed processes + + - Bundle: A resource allocation unit (ex: 4 GPUs on a single node) + - Worker: A process that performs computation (model training/inference) + - Node: A physical or virtual machine containing multiple bundles + """ + + def __init__( + self, + bundle_ct_per_node_list: List[int], + use_gpus: bool = True, + max_colocated_worker_groups: int = 1, + num_gpus_per_node: int = 8, + name: str = "", + placement_group_strategy: str = "STRICT_PACK", + ): + """Initialize a virtual cluster using Ray placement groups. + + Args: + bundle_ct_per_node_list: List specifying GPU bundles per node + (e.g., [2,2] creates 2 nodes with 2 GPU bundles each) + use_gpus: Whether to allocate GPU resources + max_colocated_worker_groups: Maximum number of worker groups that can be colocated + num_gpus_per_node: Number of GPUs per node + name: Name prefix for placement groups + placement_group_strategy: Ray placement group strategy ("STRICT_PACK", "PACK", or "SPREAD") + """ + self._bundle_ct_per_node_list = bundle_ct_per_node_list + self._world_size = sum(self._bundle_ct_per_node_list) + self._node_placement_groups = None + + self.num_gpus_per_node = num_gpus_per_node + self.use_gpus = use_gpus + if use_gpus: + assert num_gpus_per_node > 0, ( + "num_gpus_per_node must be greater than 0 if using GPUs" + ) + self.max_colocated_worker_groups = max_colocated_worker_groups + self.name = name + self._init_placement_groups(placement_group_strategy) + + def _init_placement_groups(self, strategy: str): + """Creates placement groups for each node in the cluster. Has empty groups for nodes that don't have any bundles. + + Args: + strategy: Ray placement group strategy + + Returns: + List of placement groups, one per node + """ + if self._node_placement_groups is not None: + return self._node_placement_groups + + num_cpus_per_bundle = self.max_colocated_worker_groups + # num_gpus_per_bundle == 1 indicates that there is 1 GPU per process + num_gpus_per_bundle = 1 if self.use_gpus else 0 + + resources = [ + [ + {"CPU": num_cpus_per_bundle, "GPU": num_gpus_per_bundle} + for _ in range(bundle_count) + ] + for bundle_count in self._bundle_ct_per_node_list + ] + + self._node_placement_groups = [ + placement_group( + bundles=bundles, strategy=strategy, name=f"{self.name}-node-{i}" + ) + for i, bundles in enumerate(resources) + ] + + ray.get([pg.ready() for pg in self._node_placement_groups]) + return self._node_placement_groups + + def get_placement_groups(self): + """Returns a list of placement groups that have at least one bundle, filtering out empty nodes. + + This represents the "virtual cluster" - only nodes that are actually being used. + + Returns: + List of placement groups that have at least one bundle + """ + return [pg for pg in self._node_placement_groups if pg.bundle_specs] + + def world_size(self): + return self._world_size + + def node_count(self): + return len(self.get_placement_groups()) + + def get_master_address_and_port(self): + """Gets the master address and port for the distributed training setup. + + Returns: + Tuple of (address, port) + """ + # Get placement groups if not already created + if not self._node_placement_groups: + self.get_placement_groups() + + # Find first non-empty placement group + pg = self.get_placement_groups()[0] + if pg.bundle_specs: + # Launch port finder on the first bundle of this placement group + addr, port = ray.get( + _get_node_ip_and_free_port.options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, placement_group_bundle_index=0 + ), + # Need to explicitly set to 0 since it's possible for this to be unschedulable if all CPUs are already in use. + num_cpus=0, + ).remote() + ) + return addr, port + + raise RuntimeError("No valid placement groups found to get master address") + + def shutdown(self): + """Cleans up and releases all resources associated with this virtual cluster. + + This includes removing all placement groups and resetting the internal state. + + This method is idempotent and can be safely called multiple times. + """ + if self._node_placement_groups is not None: + # Remove all placement groups + for pg in self._node_placement_groups: + try: + remove_placement_group(pg) + except Exception as e: + # Log but continue if a placement group can't be removed + print(f"Error removing placement group {pg.id}: {e}") + + # Reset internal state + self._node_placement_groups = None + + return True + + def _create_visualization_grid(self, worker_groups=None, is_global_view=False): + """Create a visualization grid for the cluster with optional worker groups. + + Args: + worker_groups: Single worker group, list of worker groups, or None + is_global_view: Whether this is a global view (multiple worker groups) or single view + + Returns: + dict: A dictionary containing the grid data for display + """ + # Convert single worker group to list for uniform processing + if worker_groups is not None and not isinstance(worker_groups, list): + worker_groups = [worker_groups] + elif worker_groups is None: + worker_groups = [] + + # Find the maximum number of GPUs per node for grid layout + max_gpus_per_node = ( + max(self._bundle_ct_per_node_list) if self._bundle_ct_per_node_list else 0 + ) + if max_gpus_per_node == 0: + return {"empty": True} + + # Number of nodes with GPUs + active_nodes = sum(1 for count in self._bundle_ct_per_node_list if count > 0) + + # Determine cell width based on view type + cell_width = 12 if is_global_view else 7 + + # Create horizontal divider based on max GPUs per node + h_divider = "+" + "+".join(["-" * cell_width] * max_gpus_per_node) + "+" + + # Build the grid data + grid_data = { + "active_nodes": active_nodes, + "total_gpus": self.world_size(), + "worker_groups": worker_groups, + "max_gpus_per_node": max_gpus_per_node, + "cell_width": cell_width, + "h_divider": h_divider, + "is_global_view": is_global_view, + "rows": [], + } + + # For each node, create its row in the grid + for node_idx, bundle_count in enumerate(self._bundle_ct_per_node_list): + if bundle_count == 0: + continue + + # Initialize row data + node_row = { + "node_idx": node_idx, + "bundle_count": bundle_count, + "gpu_cells": [], + "worker_cells": [], + } + + # Initialize worker cells arrays (one per worker group) + for i in range(len(worker_groups)): + node_row["worker_cells"].append([]) + + # Process each GPU position in the row + for gpu_idx in range(max_gpus_per_node): + if gpu_idx < bundle_count: + # This is a real GPU + gpu_cell = f" {node_idx}.{gpu_idx} " + + # Process worker assignments for this GPU + worker_cells = self._get_worker_cells( + node_idx, gpu_idx, worker_groups, cell_width, is_global_view + ) + else: + # Empty cell (no GPU) + gpu_cell = " " * cell_width + worker_cells = [" " * cell_width] * len(worker_groups) + + # Add cells to the row + node_row["gpu_cells"].append(gpu_cell) + for i, cell in enumerate(worker_cells): + if i < len(node_row["worker_cells"]): + node_row["worker_cells"][i].append(cell) + + # Add the completed row to the grid + grid_data["rows"].append(node_row) + + return grid_data + + def _get_worker_cells( + self, node_idx, gpu_idx, worker_groups, cell_width, is_global_view + ): + """Get the worker cell content for each worker group at a specific GPU location. + + Args: + node_idx: The node index + gpu_idx: The GPU index within the node + worker_groups: List of worker groups to check + cell_width: Width of each cell for formatting + is_global_view: Whether this is a global view with multiple worker groups + + Returns: + list: List of formatted worker cells, one per worker group + """ + worker_cells = [] + + for wg_idx, worker_group in enumerate(worker_groups): + # Default empty worker cell + worker_cell = " " * cell_width + + # Find workers from this group assigned to this GPU + for worker_id, metadata in enumerate(worker_group.worker_metadata): + if ( + metadata["node_idx"] == node_idx + and metadata["local_rank"] == gpu_idx + ): + if is_global_view: + # Use group numbering in global view + worker_cell = f" G{wg_idx}:W{worker_id:<2d} " + else: + # Use simple worker IDs in single group view + worker_cell = f" W {worker_id:<2d} " + break + + worker_cells.append(worker_cell) + + return worker_cells + + def _print_visualization(self, grid_data): + """Print the visualization based on the grid data. + + Args: + grid_data: The grid data generated by _create_visualization_grid + """ + if grid_data.get("empty", False): + print("\nEmpty Ray Cluster (no GPUs)") + return + + # Print header + if grid_data["is_global_view"]: + # Global view header + wg_summary = "" + if grid_data["worker_groups"]: + wg_summary = f", Worker Groups: {len(grid_data['worker_groups'])}" + + print( + f"\nRay Cluster Global View: {grid_data['active_nodes']} nodes, {grid_data['total_gpus']} GPUs{wg_summary}" + ) + else: + # Single view header + wg_info = "" + if grid_data["worker_groups"]: + worker_group = grid_data["worker_groups"][0] + wg_name = getattr(worker_group, "name_prefix", "Default") or "Default" + wg_info = ( + f", Worker Group: {wg_name} ({worker_group.world_size} workers)" + ) + + print( + f"\nRay Cluster: {grid_data['active_nodes']} nodes, {grid_data['total_gpus']} GPUs{wg_info}" + ) + + # Print the top border + print(grid_data["h_divider"]) + + # Print each row of the grid + for row in grid_data["rows"]: + # Print GPU row + gpu_row = ["|"] + for cell in row["gpu_cells"]: + gpu_row.append(cell.ljust(grid_data["cell_width"])) + gpu_row.append("|") + print("".join(gpu_row)) + + # Print worker rows + for wg_idx, worker_cells in enumerate(row["worker_cells"]): + worker_row = ["|"] + for cell in worker_cells: + worker_row.append(cell.ljust(grid_data["cell_width"])) + worker_row.append("|") + print("".join(worker_row)) + + # Print divider between nodes + print(grid_data["h_divider"]) + + # Print legend + self._print_legend(grid_data) + + def _print_legend(self, grid_data): + """Print the legend for the visualization.""" + if grid_data["is_global_view"]: + # Legend for global view + if grid_data["worker_groups"]: + print("Legend:") + for wg_idx, wg in enumerate(grid_data["worker_groups"]): + wg_name = getattr(wg, "name_prefix", "unnamed") or "unnamed" + wg_count = wg.world_size + print(f"G{wg_idx}: {wg_name} ({wg_count} workers)") + print("W##: Worker ID within its group") + else: + # Legend for single worker group view + if grid_data["worker_groups"]: + wg_name = ( + getattr(grid_data["worker_groups"][0], "name_prefix", "") or "" + ) + print(f"W## = Worker ID in '{wg_name}' worker group") + + print("#.#: Node.GPU identifier") + + def print_cluster_grid(self, worker_group=None): + """Prints a compact grid visualization of the virtual cluster, similar to JAX's visualize_array_sharding. + + If a worker_group is provided, it will also show worker assignments on each device. + + Args: + worker_group: Optional RayWorkerGroup instance to visualize worker assignments + """ + grid_data = self._create_visualization_grid(worker_group, is_global_view=False) + self._print_visualization(grid_data) + + def print_all_worker_groups(self, worker_groups=None): + """Prints a visualization showing all worker groups in the cluster. + + This provides a global view of all workers across all worker groups. + + Args: + worker_groups: List of RayWorkerGroup instances to visualize. If None, + no worker assignments will be shown. + """ + grid_data = self._create_visualization_grid(worker_groups, is_global_view=True) + self._print_visualization(grid_data) + + def __del__(self): + """Shutsdown the virtual cluster when the object is deleted or is garbage collected. + + This is an extra safety net in case the user forgets to call shutdown and the pointer to + the cluster is lost due to leaving a function scope. It's always recommended that the + user calls shutdown(). + """ + self.shutdown() diff --git a/nemo_reinforcer/distributed/worker_groups.py b/nemo_reinforcer/distributed/worker_groups.py new file mode 100644 index 0000000000..9cc7798631 --- /dev/null +++ b/nemo_reinforcer/distributed/worker_groups.py @@ -0,0 +1,541 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Union, Dict, Any +import warnings +from dataclasses import dataclass + +import os +import ray +from copy import deepcopy +from ray.util.placement_group import PlacementGroup +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster +from nemo_reinforcer.distributed.batched_data_dict import SlicedDataDict + + +@dataclass +class MultiWorkerFuture: + """Container for Ray futures with associated worker information.""" + + futures: List[ray.ObjectRef] + used_workers: List[int] + respect_tied_workers: bool = True + + def get_results(self, worker_group): + """Get results from the futures, optionally respecting tied workers. + + When respect_tied_workers is True, this method deduplicates results by returning + only one result per tied worker group. + + The method uses worker_group.worker_to_tied_group_index to identify which tied + worker group each worker belongs to, then selects only the first result from each group. + + Args: + worker_group: The RayWorkerGroup that created this bundle + + Returns: + List of results, deduplicated by tied workers if respect_tied_workers is True + """ + # Basic case: Get all results + all_results = ray.get(self.futures) + + # If we don't need to deduplicate by tied workers, return all results + if not self.respect_tied_workers: + return all_results + + if not self.used_workers: + return all_results + + # Create tied worker sets based on used workers + active_tied_workers = {} + for i, worker_idx in enumerate(self.used_workers): + tied_worker_idx = worker_group.worker_to_tied_group_index.get(worker_idx) + if tied_worker_idx is None: + continue + + if tied_worker_idx not in active_tied_workers: + active_tied_workers[tied_worker_idx] = [] + active_tied_workers[tied_worker_idx].append(i) + + # Take the first result from each tied worker group + tied_worker_results = [] + for tied_worker_idx in sorted(active_tied_workers.keys()): + if active_tied_workers[tied_worker_idx]: + result_idx = active_tied_workers[tied_worker_idx][0] + tied_worker_results.append(all_results[result_idx]) + + return tied_worker_results + + +class RayWorkerBuilder: + def __init__(self, ray_actor_class: type, *args, **kwargs): + self.ray_actor_class = ray_actor_class + self.args = args + self.kwargs = kwargs + + def __call__( + self, + placement_group: PlacementGroup, + placement_group_bundle_index: int, + num_gpus: int, + bundle_indices: Optional[list] = None, + **extra_options: Dict[str, Any], + ): + """Create a Ray worker with the specified configuration. + + Order of precedence for worker options configuration (from lowest to highest): + 1. Options passed by the user to __call__ (extra_options) + 2. Options required by the worker via configure_worker (may override user options with warning) + 3. Options set by the RayWorkerBuilder.__call__ (specifically scheduling strategy) + + If the worker needs to override user-provided options, it should log a warning + to inform the user about the change and the reason for it. + + Args: + placement_group: Ray placement group for resource allocation + placement_group_bundle_index: Index of the bundle in the placement group + num_gpus: Number of GPUs to allocate to this worker + bundle_indices: List of bundle indices for tensor parallelism (if applicable) + extra_options: Additional options to pass to the Ray actor (may be overridden by actor's configure_worker(...) method) + + Returns: + A Ray actor reference to the created worker + """ + # Set up worker arguments and resources + worker_class = self.ray_actor_class + worker_kwargs = dict(self.kwargs) + options = deepcopy(extra_options) + + # Use the worker's configuration interface if available + if hasattr(worker_class, "configure_worker"): + # Get complete worker configuration from the worker class + resources, env_vars, init_kwargs = worker_class.configure_worker( + num_gpus=num_gpus, + bundle_indices=bundle_indices, + ) + + # Apply resource configuration + if resources and "num_gpus" in resources: + num_gpus = resources["num_gpus"] + + # Apply environment variables if provided + if env_vars: + if "runtime_env" not in options: + options["runtime_env"] = {} + options["runtime_env"]["env_vars"] = env_vars + + # Apply initialization parameters + if init_kwargs: + worker_kwargs.update(init_kwargs) + + # Create options for Ray actor + options["scheduling_strategy"] = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_bundle_index=placement_group_bundle_index, + placement_group_capture_child_tasks=True, + ) + options["num_gpus"] = num_gpus + # If the user hasn't specified a py_executable, use the worker class's default + if not options.get("runtime_env", {}).get("py_executable", None) and hasattr( + worker_class, "DEFAULT_PY_EXECUTABLE" + ): + if "runtime_env" not in options: + options["runtime_env"] = {} + options["runtime_env"]["py_executable"] = worker_class.DEFAULT_PY_EXECUTABLE + + # Create and return the worker + return worker_class.options(**options).remote(*self.args, **worker_kwargs) + + +class RayWorkerGroup: + """Manages a group of distributed Ray worker/actor processes that execute tasks in parallel. + + This class creates and manages Ray actor instances that run on resources + allocated by a RayVirtualCluster. It handles: + - Worker creation and placement on specific GPU resources + - Setting up distributed training environment variables (rank, world size, etc.) + - Executing methods across all workers in parallel + - Collecting and aggregating results + - Support for tied worker groups where multiple workers process the same data + """ + + def __init__( + self, + cluster: RayVirtualCluster, + remote_worker_builder: RayWorkerBuilder, + workers_per_node: Optional[Union[int, List[int]]] = None, + name_prefix: str = "", + bundle_indices_list: Optional[List[tuple]] = None, + ): + """Initialize a group of distributed Ray workers. + + Args: + cluster: RayVirtualCluster + remote_worker_builder: Callable that launches a ray worker and has updatable options + workers_per_node: Defaults to launch one worker per bundle in the cluster. + Alternatively specify an int or list to launch a different number of workers per node. + name_prefix: Optional prefix for the names of the workers + bundle_indices_list: Explicit list of (node_idx, [local_bundle_indices]) tuples. + Each tuple defines a tied group of workers placed on the same node. + If provided, workers_per_node is ignored. + """ + self._workers = [] + self._worker_metadata = [] + self.cluster = cluster + self.name_prefix = name_prefix + self.tied_workers_groups = [] + + # Maps worker indices to their corresponding tied group index + # For example, if worker with index 3 belongs to tied worker group 1, + # then worker_to_tied_group_index[3] = 1 + self.worker_to_tied_group_index = {} + + # If explicit bundle indices are provided, use those + if bundle_indices_list is None: + # Create bundle_indices_list from workers_per_node specification + # In this case, each worker is its own group (no tied workers) + bundle_indices_list = [] + + # Determine how many workers per node + if workers_per_node is None: + workers_per_node = [ + pg.bundle_count for pg in self.cluster.get_placement_groups() + ] + elif isinstance(workers_per_node, int): + workers_per_node = [workers_per_node] * self.cluster.node_count() + elif not isinstance(workers_per_node, list): + raise ValueError( + "workers_per_node must be None(for default node distribution), an int, or a list" + ) + + # Validate workers_per_node + assert len(workers_per_node) == self.cluster.node_count(), ( + "workers_per_node_list must be the same length as the number of nodes in the virtual cluster" + ) + assert all( + [ + workers_per_node[i] <= pg.bundle_count + for i, pg in enumerate(self.cluster.get_placement_groups()) + ] + ), ( + "workers_per_node must be less than or equal to the number of bundles in the placement groups" + ) + + # Create bundle_indices_list where each worker is its own group + for node_idx, worker_count in enumerate(workers_per_node): + for local_idx in range(worker_count): + # Each worker is its own single-element group + bundle_indices_list.append((node_idx, [local_idx])) + + # Create workers based on the bundle_indices_list + self._create_workers_from_bundle_indices( + remote_worker_builder, bundle_indices_list + ) + + def _create_workers_from_bundle_indices( + self, remote_worker_builder, bundle_indices_list + ): + """Create workers based on explicit bundle indices for tied worker groups. + + Args: + remote_worker_builder: Builder function for Ray actors + bundle_indices_list: List of (node_idx, local_bundle_indices) tuples, where each tuple + specifies a tied group with its node and local bundle indices. + """ + self.master_address, self.master_port = ( + self.cluster.get_master_address_and_port() + ) + + # Count total workers + self.world_size = sum(len(indices) for _, indices in bundle_indices_list) + global_rank = 0 + + for group_idx, (node_idx, local_bundle_indices) in enumerate( + bundle_indices_list + ): + current_group = [] + + # Get the placement group for this node + pg = self.cluster.get_placement_groups()[node_idx] + is_tp_group = len(local_bundle_indices) > 1 + + for local_rank, bundle_idx in enumerate(local_bundle_indices): + # Set up basic distributed environment variables + env_vars = dict( + os.environ + ) # Pass thru all user environment variables (at the lowest precendence) + env_vars.update( + { + "RANK": str(global_rank), + "LOCAL_RANK": str(bundle_idx), + "WORLD_SIZE": str(self.world_size), + "MASTER_ADDR": self.master_address, + "MASTER_PORT": str(self.master_port), + "NODE_RANK": str(node_idx), + } + ) + + # For tensor parallel groups, only the first worker gets bundle_indices + worker_bundle_indices = ( + local_bundle_indices if local_rank == 0 else None + ) + + # Create a descriptive name based on group structure + name = ( + f"{self.name_prefix}-grp{group_idx}-{local_rank}" + if is_tp_group + else f"{self.name_prefix}-{node_idx}-{bundle_idx}" + ) + + # Calculate GPU resources + num_gpus = ( + 1 / self.cluster.max_colocated_worker_groups + if self.cluster.use_gpus + else 0 + ) + + # Pass these options to the remote_worker_builder + runtime_env = {"env_vars": env_vars} + extra_options = {"runtime_env": runtime_env, "name": name} + + # Create the worker + worker = remote_worker_builder( + placement_group=pg, + placement_group_bundle_index=bundle_idx, + num_gpus=num_gpus, + bundle_indices=worker_bundle_indices, + **extra_options, + ) + + # Store worker metadata + worker_idx = len(self._workers) + current_group.append(worker_idx) + self.worker_to_tied_group_index[worker_idx] = group_idx + self._workers.append(worker) + self._worker_metadata.append( + { + "node_idx": node_idx, + "local_rank": local_rank, + "global_rank": global_rank, + "name": name, + "bundle_indices": worker_bundle_indices, + "tied_group_idx": group_idx, + } + ) + + global_rank += 1 + + # Add this tied group to our list + self.tied_workers_groups.append(current_group) + + @property + def workers(self): + return self._workers + + @property + def worker_metadata(self): + return self._worker_metadata + + @property + def group_count(self): + """Number of tied worker groups.""" + return len(self.tied_workers_groups) + + def run_all_workers_multiple_data( + self, + method_name: str, + data: List[SlicedDataDict], + common_kwargs: Optional[Dict[str, Any]] = None, + respect_tied_workers: bool = True, + ): + """Run a method on all workers in parallel with different data. + + Args: + method_name: Name of the method to call on each worker + data: List of data slices to pass to workers/groups + common_kwargs: Additional keyword arguments to pass to all workers + respect_tied_workers: If True, only the leader (first worker) of each tied worker group + receives a data slice. If False, each worker gets its own data slice + regardless of tied worker groups. + + Returns: + MultiWorkerFuture: Object containing futures and their associated worker information + """ + # Verify that the data is a list of SlicedDataDict objects + if not all(isinstance(d, SlicedDataDict) for d in data): + warnings.warn( + f"Expected all elements in 'data' to be of type SlicedDataDict, but got " + f"{[type(d).__name__ for d in data]}. This may cause unexpected behavior. " + f"Please use make sure you're passing in Sharded Data to this function (and not replicated data)", + UserWarning, + ) + + if common_kwargs is None: + common_kwargs = {} + + futures = [] + used_workers = [] + + # Handle tied worker groups if requested + if respect_tied_workers: + # If there are fewer data slices than tied worker groups, use only the first N tied worker groups + active_tied_worker_count = min(len(data), len(self.tied_workers_groups)) + if active_tied_worker_count < len(self.tied_workers_groups): + print( + f"Warning: Using only {active_tied_worker_count} of {len(self.tied_workers_groups)} tied worker groups due to limited data slices" + ) + + # For each tied worker group, all workers in the group get the same data slice + for tied_worker_idx in range(active_tied_worker_count): + tied_worker_group = self.tied_workers_groups[tied_worker_idx] + tied_worker_data = data[tied_worker_idx] + + # Running only on the leader of the tied worker group for vllm case + futures.append( + getattr(self._workers[tied_worker_group[0]], method_name).remote( + tied_worker_data, **common_kwargs + ) + ) + used_workers.append(tied_worker_group[0]) + # for worker_idx in tied_worker_group: + # worker = self._workers[worker_idx] + # method = getattr(worker, method_name) + # futures.append(method.remote(tied_worker_data, **common_kwargs)) + # used_workers.append(worker_idx) + else: + # Regular case - each worker gets its own data slice + for worker_id, worker in enumerate(self.workers): + if worker_id >= len(data): + break + method = getattr(worker, method_name) + futures.append(method.remote(data[worker_id], **common_kwargs)) + used_workers.append(worker_id) + + # Return a MultiWorkerFuture containing both futures and worker information + return MultiWorkerFuture( + futures=futures, + used_workers=used_workers, + respect_tied_workers=respect_tied_workers, + ) + + def run_all_workers_single_data( + self, method_name: str, *args, respect_tied_workers: bool = True, **kwargs + ): + """Run a method on all workers in parallel with the same data. + + Args: + method_name: Name of the method to call on each worker + respect_tied_workers: If True, only the leader (first worker) of each tied worker group + receives the call. If False, all workers receive the call. + *args, **kwargs: Arguments to pass to the method + + Returns: + List[ray.ObjectRef]: A list of ray futures + """ + futures = [] + if respect_tied_workers: + for tied_worker_group in self.tied_workers_groups: + futures.append( + getattr(self._workers[tied_worker_group[0]], method_name).remote( + *args, **kwargs + ) + ) + else: + for worker in self.workers: + method = getattr(worker, method_name) + futures.append(method.remote(*args, **kwargs)) + + return futures + + def get_all_worker_results(self, future_bundle): + """Get results from all workers, optionally filtering to get just one result per tied worker group. + + Args: + future_bundle: MultiWorkerFuture containing futures and worker information. + When future_bundle.respect_tied_workers is True, only results from + the leaders of tied worker groups are returned. + + Returns: + List of results, deduplicated as specified in the future_bundle + """ + return future_bundle.get_results(self) + + def shutdown( + self, + cleanup_method: Optional[str] = None, + timeout: Optional[float] = 30.0, + force: bool = False, + ): + """Shutdown all workers in the worker group. + + Args: + cleanup_method: Optional method name to call on each worker before termination. + If provided, this method will be called on each worker to allow + for graceful cleanup. + timeout: Timeout in seconds for graceful shutdown. Only applicable if cleanup_method is provided. + If None, wait indefinitely for workers to complete their cleanup. + force: If True, forcefully terminate workers with ray.kill() even if cleanup_method is provided. + If cleanup_method is None, workers are always forcefully terminated. + + Returns: + bool: True if all workers were successfully shut down + """ + if not self._workers: + return True + + success = True + + # First attempt graceful shutdown if cleanup method is provided and force=False + if cleanup_method is not None and not force: + try: + # Call cleanup method on all workers + futures = self.run_all_workers_single_data(cleanup_method) + + # Wait for all cleanup operations to complete with timeout + if timeout is not None: + ray.get(futures, timeout=timeout) + else: + ray.get(futures) + + except (ray.exceptions.RayTaskError, ray.exceptions.GetTimeoutError) as e: + success = False + print( + f"Error during graceful shutdown: {e}. Falling back to force termination." + ) + force = True + + # Force kill any remaining workers + if force or cleanup_method is None: + for worker in self._workers: + try: + ray.kill(worker) + except Exception as e: + success = False + print(f"Error killing worker: {e}") + + # Clear worker lists + self._workers = [] + self._worker_metadata = [] + self.tied_workers_groups = [] + self.worker_to_tied_group_index = {} + + return success + + def print_worker_layout(self): + """Prints a visual representation of the worker layout across the virtual cluster. + + This shows which workers are assigned to which nodes and GPUs. + """ + self.cluster.print_cluster_grid(self) diff --git a/nemo_reinforcer/environments/__init__.py b/nemo_reinforcer/environments/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_reinforcer/environments/interfaces.py b/nemo_reinforcer/environments/interfaces.py new file mode 100644 index 0000000000..40986f4f19 --- /dev/null +++ b/nemo_reinforcer/environments/interfaces.py @@ -0,0 +1,64 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +from typing import Dict, List, Tuple + +from torch import Tensor + +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict + +EnvironmentReturn = Tuple[List[List[Dict[str, str]]], List[Dict], Tensor, Tensor] + + +class EnvironmentInterface(abc.ABC): + @abc.abstractmethod + def step( + self, + message_log_batch: List[List[Dict[str, str]]], + metadata: List[Dict], + *args, + **kwargs, + ) -> EnvironmentReturn: + """Runs a step in the environment. Allows for asynchrony with remote servers, but it's not required (this function is a ray remote). + + message_log_batch: batch of OpenAI-API-like message logs that represent interactions with the LLM. + For example, if this were a Math Environment, then the message log + would be + [ + {"role": "user", "content": "problem"}, + {"role": "assistant", "content": "response"}, + ] + but if this were a code environment + with feedback, it would be: + [ + {"role": "user", "content": "problem"}, + {"role": "assistant", "content": "response"}, + {"role": "user", "content": "code result"}, + {"role": "assistant", "content": "model response"}, + ] + metadata: batch of whatever the environment needs to keep track of. I.e. + math solutions, code unit tests, or agent states. + + Returns: + - List[Dict[str, str]]: An observation/response batch in an OpenAI-API-like message format that is the result of the step. + - List[Dict]: An updated batch of metadata. + - Tensor: A tensor of rewards. + - Tensor: A tensor of done flags. + """ + + @abc.abstractmethod + def global_post_process_and_metrics( + self, batch: BatchedDataDict + ) -> Tuple[BatchedDataDict, dict]: + """Post processing function after all rollouts are done for the batch and returns metrics.""" diff --git a/nemo_reinforcer/environments/math_environment.py b/nemo_reinforcer/environments/math_environment.py new file mode 100644 index 0000000000..1619ca576c --- /dev/null +++ b/nemo_reinforcer/environments/math_environment.py @@ -0,0 +1,180 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from itertools import tee +from typing import Dict, List, Tuple, TypedDict + +import ray +import torch +from math_verify import parse, verify + +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.environments.interfaces import EnvironmentInterface +from nemo_reinforcer.environments.metrics import ( + calculate_pass_rate_per_prompt, +) +from nemo_reinforcer.environments.utils import chunk_list_to_workers +from nemo_reinforcer.distributed.virtual_cluster import PY_EXECUTABLES + + +class MathEnvConfig(TypedDict): + num_workers: int + + +@ray.remote +class HFVerifyWorker: + # TODO: Slim down the dependencies to just math_verify + DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.DEFAULT_VENV + + def verify( + self, pred_responses: List[str], ground_truths: List[str] + ) -> List[float]: + """Verify the correctness of the predicted responses against the ground truth. + + Args: + pred_responses: List[str]. The predicted responses from the LLM. + ground_truths: List[str]. The ground truth responses. + + Returns: + List[float]. The rewards for each predicted response. + """ + results = [] + for response, ground_truth in zip(pred_responses, ground_truths): + try: + gold = parse(ground_truth) + pred = parse(response[-100:]) # avoid looking at the whole string + results.append(float(verify(gold, pred))) + except Exception: + results.append(0) + return results + + +class MathEnvironmentMetadata(TypedDict): + ground_truth: str + + +@ray.remote +class MathEnvironment(EnvironmentInterface): + # TODO: Slim down the dependencies to just math_verify + DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.DEFAULT_VENV + + def __init__(self, cfg: Dict): + self.num_workers = cfg["num_workers"] + self.workers = [ + HFVerifyWorker.options( + runtime_env={"py_executable": HFVerifyWorker.DEFAULT_PY_EXECUTABLE} + ).remote() + for _ in range(self.num_workers) + ] + + def shutdown(self): + # shutdown all workers + for worker in self.workers: + ray.kill(worker) + + def step( + self, + message_log_batch: List[List[Dict[str, str]]], + metadata: List[MathEnvironmentMetadata], + ): + """Runs a step in the math environment. + + Args: + message_log: List[List[Dict[str, str]]]. A batch of OpenAI-API-like message logs that represent interactions with the LLM. + metadata: List[MathEnvironmentMetadata]. The grader will use the 'ground_truth' key to evaluate correctness. + + Returns: + EnvironmentReturn: A tuple containing: + - List[Dict[str, str]]: Observations/responses batch + - List[Dict]: Updated metadata + - Tensor: Rewards tensor + - Tensor: Done flags tensor + """ + # Extract the assistant's responses from the message history + # Each message list should have at least one assistant response + assistant_response_batch = [] + for conversation in message_log_batch: + assistant_responses = [ + interaction["content"] + for interaction in conversation + if interaction["role"] == "assistant" + ] + assistant_response_batch.append("".join(assistant_responses)) + + ground_truths = [g["ground_truth"] for g in metadata] + + chunked_assistant_response_batch = chunk_list_to_workers( + assistant_response_batch, self.num_workers + ) + chunked_ground_truths = chunk_list_to_workers(ground_truths, self.num_workers) + + # # Process each chunk in parallel + futures = [ + self.workers[i].verify.remote(chunk, ground_truth_chunk) + for i, (chunk, ground_truth_chunk) in enumerate( + zip(chunked_assistant_response_batch, chunked_ground_truths) + ) + ] + + results = ray.get(futures) + + # flatten the results + results = [item for sublist in results for item in sublist] + observations = [ + {"role": "user", "content": "correct" if result else "incorrect"} + for result in results + ] + + # create a tensor of rewards and done flags + rewards = torch.tensor(results).cpu() + done = torch.ones_like(rewards).cpu() + + return observations, metadata, rewards, done + + def global_post_process_and_metrics( + self, batch: BatchedDataDict + ) -> Tuple[BatchedDataDict, dict]: + """Computes metrics for this environment given a global rollout batch. + + Every rank will run this function, so you're free to use distributed + calculations if you'd prefer for heavy metrics. + """ + batch["rewards"] = ( + batch["rewards"] * batch["is_end"] + ) # set a reward of 0 for any incorrectly ended sequences + if (batch["rewards"] == 1).float().sum() > 0: + correct_solution_generation_lengths = ( + (batch["generation_lengths"] - batch["prompt_lengths"])[ + batch["rewards"] == 1 + ] + .float() + .mean() + .item() + ) + else: + correct_solution_generation_lengths = 0 + + metrics = { + # "table": table, TODO @sahilj WIP + "accuracy": batch["rewards"].mean().item(), + "pass@samples_per_prompt": calculate_pass_rate_per_prompt( + batch["text"], batch["rewards"] + ), + "fraction_of_samples_properly_ended": batch["is_end"].float().mean().item(), + "num_problems_in_batch": batch["is_end"].shape[0], + "generation_lengths": batch["generation_lengths"].float().mean().item(), + "prompt_lengths": batch["prompt_lengths"].float().mean().item(), + "correct_solution_generation_lengths": correct_solution_generation_lengths, + } + + return batch, metrics diff --git a/nemo_reinforcer/environments/metrics.py b/nemo_reinforcer/environments/metrics.py new file mode 100644 index 0000000000..eee81d7993 --- /dev/null +++ b/nemo_reinforcer/environments/metrics.py @@ -0,0 +1,34 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +def calculate_pass_rate_per_prompt(prompts, is_correct): + """Function to compute fraction of prompts that have at least one correct answer (reward > 0). + + prompts: tensor (b, s) Tensor of prompts the model used. May be on any device + is_correct: tensor (b,) bool-valued label. May be on any device + + Returns: + pass rate: float + """ + unique_prompts = torch.unique(prompts, dim=0) + + correct_prompt_ct = 0 + for i in range(len(unique_prompts)): + is_matching_prompt = (prompts == unique_prompts[i]).all(1) + if torch.any(is_correct[is_matching_prompt] > 0): + correct_prompt_ct += 1 + + return correct_prompt_ct / len(unique_prompts) diff --git a/nemo_reinforcer/environments/utils.py b/nemo_reinforcer/environments/utils.py new file mode 100644 index 0000000000..75c912cf26 --- /dev/null +++ b/nemo_reinforcer/environments/utils.py @@ -0,0 +1,61 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Any + + +def chunk_list_to_workers(to_chunk: List[Any], num_workers: int) -> List[List[Any]]: + """Chunk a list into a list of lists, where each sublist is assigned to a worker. Keeps ordering of elements. + + If the list is not divisible by the number of workers, the last worker may have fewer elements. + If there are more workers than elements, the first len(list) workers will have a single element each, + and the remaining workers will have empty lists. + + Args: + list: The list to be chunked. + num_workers: The number of workers to distribute the list to. + + Returns: + A list of lists, where each sublist contains elements assigned to a worker. + + Examples: + ```{doctest} + >>> from nemo_reinforcer.environments.utils import chunk_list_to_workers + >>> chunk_list_to_workers([1, 2, 3, 4, 5], 3) + [[1, 2], [3, 4], [5]] + ``` + """ + if not to_chunk: + return [[] for _ in range(num_workers)] + + # Handle case where we have more workers than elements + if len(to_chunk) <= num_workers: + result = [[item] for item in to_chunk] + # Add empty lists for remaining workers + result.extend([[] for _ in range(num_workers - len(to_chunk))]) + return result + + # Calculate chunk size (ceiling division to ensure all elements are covered) + chunk_size = (len(to_chunk) + num_workers - 1) // num_workers + + # Create chunks + chunks = [] + for i in range(0, len(to_chunk), chunk_size): + chunks.append(to_chunk[i : i + chunk_size]) + + # If we somehow ended up with more chunks than workers (shouldn't happen with ceiling division) + # merge the last chunks + if len(chunks) > num_workers: + chunks[num_workers - 1 :] = [sum(chunks[num_workers - 1 :], [])] + + return chunks diff --git a/nemo_reinforcer/evals/__init__.py b/nemo_reinforcer/evals/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_reinforcer/evals/run_env_eval.py b/nemo_reinforcer/evals/run_env_eval.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nemo_reinforcer/evals/run_env_eval.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_reinforcer/experience/__init__.py b/nemo_reinforcer/experience/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_reinforcer/metrics/__init__.py b/nemo_reinforcer/metrics/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_reinforcer/metrics/metrics_utils.py b/nemo_reinforcer/metrics/metrics_utils.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nemo_reinforcer/metrics/metrics_utils.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_reinforcer/models/__init__.py b/nemo_reinforcer/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_reinforcer/models/generation/__init__.py b/nemo_reinforcer/models/generation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_reinforcer/models/generation/interfaces.py b/nemo_reinforcer/models/generation/interfaces.py new file mode 100644 index 0000000000..8ffa1d2945 --- /dev/null +++ b/nemo_reinforcer/models/generation/interfaces.py @@ -0,0 +1,200 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Any, TypedDict, Union, Tuple + +import torch +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict + + +def verify_right_padding( + data: Union[ + BatchedDataDict["GenerationDatumSpec"], BatchedDataDict["GenerationOutputSpec"] + ], + pad_value: int = 0, + raise_error: bool = True, +) -> Tuple[bool, Union[str, None]]: + """Verify that a tensor is right-padded according to the provided lengths. + + Arguments: + data: The BatchedDataDict to check, containing either: + - For GenerationDatumSpec: input_ids and input_lengths + - For GenerationOutputSpec: output_ids and unpadded_sequence_lengths + pad_value: The expected padding value (default: 0) + raise_error: Whether to raise an error if wrong padding is detected + + Returns: + Tuple of (is_right_padded, error_message) + - is_right_padded: True if right padding confirmed, False otherwise + - error_message: None if properly padded, otherwise a description of the issue + """ + # Extract tensors from the BatchedDataDict + assert isinstance(data, BatchedDataDict), ( + f"data must be a BatchedDataDict, got type: {type(data)}" + ) + + # Determine which type of data we're dealing with + if "input_ids" in data and "input_lengths" in data: + # GenerationDatumSpec + tensor = data["input_ids"] + lengths = data["input_lengths"] + elif "output_ids" in data and "unpadded_sequence_lengths" in data: + # GenerationOutputSpec + tensor = data["output_ids"] + lengths = data["unpadded_sequence_lengths"] + else: + msg = f"Could not find the required pairs of fields. Expected either (input_ids, input_lengths) or (output_ids, unpadded_sequence_lengths). Got keys: {data.keys()}" + if raise_error: + raise ValueError(msg) + return False, msg + + if tensor.ndim != 2: + msg = f"Expected 2D tensor for padding check, got shape {tensor.shape}" + if raise_error: + raise ValueError(msg) + return False, msg + + batch_size, seq_len = tensor.shape + if lengths.shape[0] != batch_size: + msg = f"Mismatch between tensor batch size ({batch_size}) and lengths tensor size ({lengths.shape[0]})" + if raise_error: + raise ValueError(msg) + return False, msg + + # Check each sequence to verify zero padding on the right + for i in range(batch_size): + length = lengths[i].item() + if length > seq_len: + msg = f"Length {length} at index {i} exceeds tensor sequence dimension {seq_len}" + if raise_error: + raise ValueError(msg) + return False, msg + + # Check that all positions after length are pad_value + if length < seq_len and not torch.all(tensor[i, length:] == pad_value): + non_pad_indices = torch.where(tensor[i, length:] != pad_value)[0] + length + msg = f"Non-padding values found after specified length at index {i}: positions {non_pad_indices.tolist()}" + if raise_error: + raise ValueError(msg) + return False, msg + + return True, None + + +class GenerationConfig(TypedDict): + """Configuration for generation.""" + + backend: str + max_new_tokens: int + temperature: float + top_p: float + top_k: int + model_name: str + + +class GenerationDatumSpec(TypedDict): + """Specification for input data required by generation models. + + - input_ids: Tensor of token IDs representing the input sequences (right padded) + - input_lengths: Tensor containing the actual length of each sequence (without padding) + - __extra__: Additional model-specific data fields + + Example of a batch with 4 entries with different sequence lengths: + ``` + # Batch of 4 sequences with lengths [3, 5, 2, 4] + + input_ids (padded): + [ + [101, 2054, 2003, 0, 0], # Length 3 + [101, 2054, 2003, 2001, 1996], # Length 5 + [101, 2054, 0, 0, 0], # Length 2 + [101, 2054, 2003, 2001, 0], # Length 4 + ] + + input_lengths: + [3, 5, 2, 4] + ``` + + All functions receiving or returning GenerationDatumSpec should ensure + right padding is maintained. Use verify_right_padding() to check. + """ + + input_ids: torch.Tensor + input_lengths: torch.Tensor + __extra__: Any + + +class GenerationOutputSpec(TypedDict): + """Specification for output data returned by generation models. + + - output_ids: Tensor of token IDs representing the generated sequences (right padded) + - generation_lengths: Tensor containing the actual length of each generated sequence + - unpadded_sequence_lengths: Tensor containing the actual length of each input + generated sequence (without padding) + - logprobs: Tensor of log probabilities for each generated token (right padded with zeros) + - __extra__: Additional model-specific data fields + + Example of a batch with 2 sequences: + ``` + # Sample batch with 2 examples + # - Example 1: Input length 3, generated response length 4 + # - Example 2: Input length 5, generated response length 2 + + output_ids (right-padded): + [ + [101, 2054, 2003, 2023, 2003, 1037, 2200, 0], # 7 valid tokens (3 input + 4 output) + [101, 2054, 2003, 2001, 1996, 3014, 2005, 0], # 7 valid tokens (5 input + 2 output) + ] + + generation_lengths: + [4, 2] # Length of just the generated response part + + unpadded_sequence_lengths: + [7, 7] # Length of full valid sequence (input + generated response) + + logprobs (right-padded with zeros): + [ + [0.0, 0.0, 0.0, -1.2, -0.8, -2.1, -1.5, 0.0], # First 3 are 0 (input tokens), next 4 are actual logprobs + [0.0, 0.0, 0.0, 0.0, 0.0, -0.9, -1.7, 0.0], # First 5 are 0 (input tokens), next 2 are actual logprobs + ] + ``` + + All functions receiving or returning GenerationOutputSpec should ensure + right padding is maintained. Use verify_right_padding() to check. + """ + + output_ids: torch.Tensor + generation_lengths: torch.Tensor # Length of just the generated response part + unpadded_sequence_lengths: ( + torch.Tensor + ) # Length of full valid sequence (input + generated response) + logprobs: torch.Tensor + __extra__: Any + + +class GenerationInterface(ABC): + """Abstract base class defining the interface for RL policies.""" + + @abstractmethod + def generate( + self, data: BatchedDataDict["GenerationDatumSpec"], greedy: bool + ) -> BatchedDataDict["GenerationOutputSpec"]: + pass + + @abstractmethod + def prepare_for_generation(self, *args, **kwargs): + pass + + @abstractmethod + def finish_generation(self, *args, **kwargs): + pass diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py new file mode 100644 index 0000000000..eab407c0a3 --- /dev/null +++ b/nemo_reinforcer/models/generation/vllm.py @@ -0,0 +1,596 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union, List, TypedDict +import gc +import warnings + +import ray +import torch +from transformers import AutoTokenizer + +from nemo_reinforcer.models.generation.interfaces import ( + GenerationInterface, + GenerationDatumSpec, + GenerationOutputSpec, + verify_right_padding, + GenerationConfig, +) +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.distributed.virtual_cluster import ( + RayVirtualCluster, + PY_EXECUTABLES, +) +from nemo_reinforcer.distributed.worker_groups import RayWorkerGroup, RayWorkerBuilder + + +class VllmSpecificArgs(TypedDict): + tensor_parallel_size: int + gpu_memory_utilization: float + max_model_len: int + + +class VllmConfig(GenerationConfig): + vllm_cfg: VllmSpecificArgs + + +@ray.remote +class VllmGenerationWorker: + # This is the default py_executable for vLLM workers + DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.DEFAULT_VENV + + def __repr__(self): + """Customizes the actor's prefix in the Ray logs. + + This makes it easier to identify which worker is producing specific log messages. + """ + return f"{self.__class__.__name__}" + + @staticmethod + def configure_worker( + num_gpus: int | float, bundle_indices: Optional[list] = None + ) -> tuple[dict, dict, dict]: + """Provides complete worker configuration for vLLM tensor parallelism. + + This method configures the worker based on its role in tensor parallelism, + which is determined directly from the bundle_indices parameter. + + Args: + num_gpus: Original GPU allocation for this worker based on the placement group + bundle_indices: Bundle indices for tensor parallelism (if applicable) + + Returns: + tuple with complete worker configuration: + - 'resources': Resource allocation (e.g., num_gpus) + - 'env_vars': Environment variables for this worker + - 'init_kwargs': Parameters to pass to __init__ of the worker + """ + # Initialize configuration + resources = {"num_gpus": num_gpus} + init_kwargs = {} + env_vars = {} + + init_kwargs["bundle_indices"] = bundle_indices + + is_part_of_tp_workers = ( + bundle_indices is not None and len(bundle_indices) > 1 + ) or bundle_indices is None + if is_part_of_tp_workers: + # Ray + vllm likes to manage GPU assignment internally + resources["num_gpus"] = 0 + env_vars["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" + init_kwargs["fraction_of_gpus"] = num_gpus + + # Force vllm to use v0 runtime (will be enabled by default in #51) + env_vars["VLLM_USE_V1"] = "0" + return resources, env_vars, init_kwargs + + def __init__( + self, + config: VllmConfig, + bundle_indices: Optional[list] = None, + fraction_of_gpus: float = 1.0, + ): + """Initialize a vLLM worker for distributed inference. + + Args: + config: Configuration dictionary for the policy + bundle_indices: List of local bundle indices within a node for tensor parallelism. + Only needed for the first worker in each tied worker group. + """ + self.cfg = config + self.model_name = self.cfg["model_name"] + self.tensor_parallel_size = self.cfg["vllm_cfg"]["tensor_parallel_size"] + self.gpu_memory_utilization = self.cfg["vllm_cfg"]["gpu_memory_utilization"] + self.fraction_of_gpus = fraction_of_gpus + self.is_model_owner = bundle_indices is not None + + # Skip model loading if we're not the model owner + if not self.is_model_owner: + self.llm = None + self.tokenizer = None + self.rank = 0 + self.world_size = 1 + return + + # In Ray+vLLM setup, each worker process considers itself rank 0 + # vLLM handles the tensor parallelism internally through Ray + self.rank = 0 + self.world_size = 1 + + try: + from vllm import LLM, SamplingParams + from nemo_reinforcer.models.generation.vllm_backend import ( + UpdatableVllmInternalWorker, + ) + + self.SamplingParams = SamplingParams + except ImportError: + raise ImportError( + "vLLM is not installed. Please install it with `pip install nemo-reinforcer[vllm]` " + "or `pip install vllm --no-build-isolation` separately." + ) + vllm_kwargs = self.cfg.get("vllm_kwargs", {}).copy() + + # Special handling for tensor parallel case + if self.tensor_parallel_size > 1: + # Configure vLLM for tensor parallelism within Ray + import os + + # Reset CUDA_VISIBLE_DEVICES to allow vLLM to manage GPU assignment + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + + os.environ["VLLM_RAY_PER_WORKER_GPUS"] = str( + self.fraction_of_gpus / self.tensor_parallel_size + ) + + # Set bundle indices for tensor parallelism workers + os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices)) + + # Use Ray for distributed execution in TP mode + vllm_kwargs["distributed_executor_backend"] = "ray" + else: + # For non-TP mode, explicitly set executor to None to avoid Ray issues + vllm_kwargs["distributed_executor_backend"] = None + + self.llm = LLM( + model=self.model_name, + tensor_parallel_size=self.tensor_parallel_size, + gpu_memory_utilization=self.gpu_memory_utilization, + enable_prefix_caching=True, + dtype="auto", + enforce_eager=True, + max_model_len=self.cfg["vllm_cfg"]["max_model_len"], + trust_remote_code=True, + worker_cls=UpdatableVllmInternalWorker, + enable_sleep_mode=True, + **vllm_kwargs, + ) + + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + def llm(self): + return self.llm + + def is_alive(self): + """Check if the worker is alive.""" + return True + + def generate( + self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + ) -> BatchedDataDict[GenerationOutputSpec]: + """Generate a batch of data using vLLM generation. + + Args: + data: BatchedDataDict containing input_ids and input_lengths tensors + + Returns: + BatchedDataDict conforming to GenerationOutputSpec: + - output_ids: input + generated token IDs with proper padding + - logprobs: Log probabilities for tokens + - generation_lengths: Lengths of each response + - unpadded_sequence_lengths: Lengths of each input + generated sequence + """ + # Verify input is right padded + assert isinstance(data, BatchedDataDict), ( + f"data must be a BatchedDataDict, got type: {type(data)}" + ) + assert "input_ids" in data and "input_lengths" in data, ( + f"input_ids and input_lengths must be present in the BatchedDataDict, got keys: {data.keys()}" + ) + is_right_padded, error_msg = verify_right_padding( + data, pad_value=self.tokenizer.pad_token_id + ) + if not is_right_padded: + warnings.warn( + f"Input to vLLM worker is not properly right-padded: {error_msg}" + ) + + # Convert inputs to vLLM format + input_ids = data["input_ids"] + input_lengths = data["input_lengths"] + batch_size = input_ids.shape[0] + # Original input length with padding + padded_input_length = input_ids.size(1) + + # Prepare prompts for vLLM (removing padding) + prompts = [] + + for i in range(batch_size): + # Use input_lengths to get only valid tokens (not padding) + valid_length = input_lengths[i].item() + valid_ids = ( + input_ids[i, :valid_length] if valid_length > 0 else input_ids[i, :0] + ) + token_ids = valid_ids.tolist() + + prompts.append({"prompt_token_ids": token_ids}) + + # Read generation parameters from config + top_k = self.cfg["top_k"] if self.cfg["top_k"] is not None else -1 + sampling_params = self.SamplingParams( + temperature=self.cfg["temperature"], + top_p=self.cfg["top_p"], + top_k=top_k + if not greedy + else 1, # we use a default of -1 if unset so that 'null'/None is a common disable value + max_tokens=self.cfg["max_new_tokens"], + logprobs=0, # Return logprobs for the generated tokens + stop=None, + ) + + # Generate outputs + outputs = self.llm.generate(prompts, sampling_params) + + # Process the outputs - but preserve the original input padding structure + output_ids_list = [] + logprobs_list = [] + generation_lengths = [] + unpadded_sequence_lengths = [] + max_length = 0 + for output in outputs: + max_length = max(max_length, len(output.outputs[0].token_ids)) + + for i, output in enumerate(outputs): + # Extract generated tokens + sequence_length = input_lengths[i] + generation = output.outputs[0] + generated_tokens = list(generation.token_ids) + + # Calculate total sequence length (original input length + generated tokens) + total_length = padded_input_length + max_length + + # Create a new tensor with the right size and fill with padding token + full_output = torch.full( + (total_length,), self.tokenizer.pad_token_id, dtype=input_ids.dtype + ) + + # Copy original input (with padding) into the beginning + full_output[:sequence_length] = input_ids[i][:sequence_length] + + # Add generated tokens after the original input + full_output[sequence_length : sequence_length + len(generated_tokens)] = ( + torch.tensor(generated_tokens) + ) + + output_ids_list.append(full_output) + full_logprobs = torch.zeros(total_length, dtype=torch.float32) + if hasattr(generation, "logprobs") and generation.logprobs: + try: + for idx, logprob_dict in enumerate(generation.logprobs): + if logprob_dict: + position = sequence_length + idx + full_logprobs[position] = next(iter(logprob_dict.items()))[ + 1 + ].logprob + except Exception: + import traceback + + traceback.print_exc() + + logprobs_list.append(full_logprobs) + + response_length = sequence_length + len(generated_tokens) + generation_lengths.append(len(generated_tokens)) + unpadded_sequence_lengths.append(response_length) + # Create return data conforming to GenerationOutputSpec + output_ids = torch.stack(output_ids_list) + logprobs = torch.stack(logprobs_list) + + return_data = BatchedDataDict[GenerationOutputSpec]( + { + "output_ids": output_ids, + "logprobs": logprobs, + "generation_lengths": torch.tensor( + generation_lengths, dtype=torch.long + ), + "unpadded_sequence_lengths": torch.tensor( + unpadded_sequence_lengths, dtype=torch.long + ), + } + ) + + return return_data + + def shutdown(self): + """Clean up vLLM resources.""" + try: + # Clear caches and free memory + self.llm = None + self.tokenizer = None + + # Force garbage collection + gc.collect() + torch.cuda.empty_cache() + + return True + except Exception as e: + print(f"Error during vLLM shutdown: {e}") + return False + + def report_device_id(self) -> str: + # from vllm.platforms import current_platform + # self.device_uuid = current_platform.get_device_uuid(self.rank) + # return self.device_uuid + return self.llm.collective_rpc("report_device_id", args=tuple())[0] + + def update_weights_from_ipc_handles(self, ipc_handles): + """Update weights from IPC handles by delegating to the vLLM Worker implementation. + + Args: + ipc_handles (dict): Dictionary mapping device UUIDs to parameter IPC handles. + + Returns: + bool: True if weights were successfully updated, False otherwise. + """ + try: + # Use collective_rpc to delegate to the UpdatableVllmInternalWorker implementation + self.llm.collective_rpc( + "update_weights_from_ipc_handles", args=(ipc_handles,) + ) + return True + except Exception as e: + print(f"Error updating weights: {e}") + return False + + def check_weights_changed(self): + """Check if the weights are updated to 0 by delegating to the vLLM Worker implementation. + + Returns: + bool: True if all weights have been zeroed, False otherwise. + """ + try: + result = self.llm.collective_rpc("check_weights_changed", args=tuple()) + # The collective_rpc returns a list of results, one per worker + # Extract the boolean value from the first worker's result + return result[0] if isinstance(result, list) and len(result) > 0 else False + except Exception as e: + print(f"Error checking weights: {e}") + return False + + def sleep(self): + self.llm.sleep(level=1) + gc.collect() + torch.cuda.empty_cache() + + def wake_up(self): + self.llm.wake_up() + + +class VllmGeneration(GenerationInterface): + def __init__( + self, + cluster: RayVirtualCluster, + config: VllmConfig, + name_prefix: str = "vllm_policy", + workers_per_node: Optional[Union[int, List[int]]] = None, + ): + """Initialize a vLLM policy with distributed workers.""" + # Store config + self.cfg = config + self.tensor_parallel_size = self.cfg["vllm_cfg"]["tensor_parallel_size"] + + # Create worker builder for VllmGenerationWorker + worker_builder = RayWorkerBuilder(VllmGenerationWorker, config) + + if self.tensor_parallel_size > 1: + # For tensor parallelism, create node-aware worker groups + node_bundle_indices = self._get_tied_worker_bundle_indices(cluster) + + self.worker_group = RayWorkerGroup( + cluster, + worker_builder, + name_prefix=name_prefix, + bundle_indices_list=node_bundle_indices, + ) + else: + # Use standard worker group creation for non-TP case + self.worker_group = RayWorkerGroup( + cluster, + worker_builder, + name_prefix=name_prefix, + workers_per_node=workers_per_node, + ) + + # Number of data parallel groups is the number of tied worker groups + self.dp_size = self.worker_group.group_count + + def _get_tied_worker_bundle_indices(self, cluster): + """Calculate bundle indices for tensor parallel workers.""" + # Get the placement groups (nodes) from the cluster + placement_groups = cluster.get_placement_groups() + + tied_worker_groups = [] + + # For each node (placement group), create tied worker groups of size tensor_parallel_size + for node_idx, pg in enumerate(placement_groups): + # How many bundles (GPUs) are on this node + bundles_on_node = pg.bundle_count + tied_worker_groups_on_node = bundles_on_node // self.tensor_parallel_size + + if tied_worker_groups_on_node > 0: + for group_idx in range(tied_worker_groups_on_node): + # Local bundle indices for this tied worker group (consecutive GPUs on this node) + start_idx = group_idx * self.tensor_parallel_size + end_idx = start_idx + self.tensor_parallel_size + local_bundle_indices = list(range(start_idx, end_idx)) + tied_worker_groups.append((node_idx, local_bundle_indices)) + + if not tied_worker_groups: + raise ValueError( + f"Cannot create any tensor parallel tied worker groups with size {self.tensor_parallel_size}. " + f"Make sure each node has at least {self.tensor_parallel_size} GPUs." + ) + + return tied_worker_groups + + def _check_all_weights_changed(self): + """Check if weights have been updated across all workers or leaders. + + Returns: + bool: True if all checked weights have been updated, False otherwise. + """ + if not self.worker_group or not self.worker_group.workers: + return False + + try: + # Use run_all_workers_single_data for methods that don't need data + futures = self.worker_group.run_all_workers_single_data( + "check_weights_changed", respect_tied_workers=True + ) + # Wait for all futures to complete + results = ray.get(futures) + return all(result for result in results if result is not None) + except Exception as e: + print(f"Error checking weights: {e}") + return False + + def generate( + self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + ) -> BatchedDataDict[GenerationOutputSpec]: + """Generate a batch of data using vLLM.""" + assert isinstance(data, BatchedDataDict), ( + f"data must be a BatchedDataDict, got type: {type(data)}" + ) + assert "input_ids" in data and "input_lengths" in data, ( + "input_ids and input_lengths are required in data for vLLM generation" + ) + + batch_size = data["input_ids"].shape[0] + + # Shard the data across the tied worker groups + sharded_data = data.shard_by_batch_size(self.dp_size, batch_size=batch_size) + future_bundle = self.worker_group.run_all_workers_multiple_data( + "generate", + sharded_data, + common_kwargs={"greedy": greedy}, + respect_tied_workers=True, + ) + + # Get results from the workers, respecting tied worker groups (only one result per tied worker group) + results = self.worker_group.get_all_worker_results(future_bundle) + + # Combine results from all tied worker groups + combined = BatchedDataDict.from_batches(results) + + # Verify the output has all required fields + required_keys = [ + "output_ids", + "generation_lengths", + "unpadded_sequence_lengths", + "logprobs", + ] + missing_keys = [key for key in required_keys if key not in combined] + if missing_keys: + raise ValueError( + f"Missing required keys for GenerationOutputSpec: {missing_keys}" + ) + + return combined + + def prepare_for_generation(self, *args, **kwargs): + """Abstract method that must be implemented by subclasses.""" + try: + # Use run_all_workers_single_data for methods that don't need data + futures = self.worker_group.run_all_workers_single_data( + "wake_up", respect_tied_workers=True + ) + # Wait for all futures to complete + results = ray.get(futures) + return all(result for result in results if result is not None) + except Exception as e: + print(f"Error during policy preparation: {e}") + return False + + def finish_generation(self, *args, **kwargs): + """Abstract method that must be implemented by subclasses.""" + try: + # Use run_all_workers_single_data for methods that don't need data + futures = self.worker_group.run_all_workers_single_data( + "sleep", respect_tied_workers=True + ) + # Wait for all futures to complete + results = ray.get(futures) + return all(result for result in results if result is not None) + except Exception as e: + print(f"Error during policy preparation: {e}") + return False + + def shutdown(self) -> bool: + """Shut down all vLLM workers and clean up resources.""" + try: + # Use the worker group's shutdown method with the worker's cleanup method + return self.worker_group.shutdown(cleanup_method="shutdown") + except Exception as e: + print(f"Error during policy shutdown: {e}") + return False + + def update_weights(self, ipc_handles): + """Update weights of the policy using IPC handles, considering tensor parallelism. + + For tp > 1, only the leader in each tensor parallel tied worker group will update weights. + + Args: + ipc_handles (dict): Dictionary mapping device UUIDs to parameter IPC handles. + + Returns: + bool: True if weights were successfully updated, False otherwise. + """ + if not self.worker_group or not self.worker_group.workers: + return False + + try: + # Directly pass ipc_handles to the method + futures = self.worker_group.run_all_workers_single_data( + "update_weights_from_ipc_handles", + respect_tied_workers=True, + ipc_handles=ipc_handles, + ) + # Wait for all futures to complete + results = ray.get(futures) + return all(result for result in results if result is not None) + except Exception as e: + print(f"Error updating weights: {e}") + return False + + def __del__(self): + """Shuts down the worker groups when the object is deleted or is garbage collected. + + This is an extra safety net in case the user forgets to call shutdown() and the pointer to + the object is lost due to leaving a function scope. It's always recommended that the + user calls shutdown(). + """ + self.shutdown() diff --git a/nemo_reinforcer/models/generation/vllm_backend.py b/nemo_reinforcer/models/generation/vllm_backend.py new file mode 100644 index 0000000000..a8ded9b519 --- /dev/null +++ b/nemo_reinforcer/models/generation/vllm_backend.py @@ -0,0 +1,83 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +try: + from vllm.worker.worker import Worker +except ImportError: + raise ImportError( + "vLLM is not installed. Please install it with `pip install nemo-reinforcer[vllm]` " + "or `pip install vllm` separately. This issue may also occur if worker is using incorrect " + "py_executable." + ) + + +class UpdatableVllmInternalWorker(Worker): + def report_device_id(self) -> str: + from vllm.platforms import current_platform + + self.device_uuid = current_platform.get_device_uuid(self.device.index) + return self.device_uuid + + def update_weights_from_ipc_handles(self, ipc_handles): + """Update weights from IPC handles. + + Args: + ipc_handles (dict): Dictionary mapping device UUIDs to parameter IPC handles. + + Returns: + bool: True if weights were successfully updated. + """ + try: + # Get handles for this device + device_uuid = self.report_device_id() + handles = ipc_handles[device_uuid] + device_id = self.device.index + weights = [] + + # Process each handle to get the tensor + for name, handle in handles.items(): + func, args = handle + list_args = list(args) + # Update device ID to match the current device + list_args[6] = device_id + tensor = func(*list_args) + weights.append((name, tensor)) + + # Load weights into the model + self.model_runner.model.load_weights(weights=weights) + torch.cuda.synchronize() + return True + except Exception as e: + print( + f"Error in UpdatableVllmInternalWorker.update_weights_from_ipc_handles: {e}" + ) + return False + + def check_weights_changed(self): + """Check if the weights are updated to 0. + + Returns: + bool: True if all weights have been zeroed, False otherwise. + """ + try: + weights_updated = True + for name, p in self.model_runner.model.named_parameters(): + weights_updated = weights_updated and torch.allclose( + p, torch.zeros_like(p) + ) + return weights_updated + except Exception as e: + print(f"Error in UpdatableVllmInternalWorker.check_weights_changed: {e}") + return False diff --git a/nemo_reinforcer/models/huggingface/__init__.py b/nemo_reinforcer/models/huggingface/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_reinforcer/models/huggingface/common.py b/nemo_reinforcer/models/huggingface/common.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nemo_reinforcer/models/huggingface/common.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_reinforcer/models/interfaces.py b/nemo_reinforcer/models/interfaces.py new file mode 100644 index 0000000000..cb87d805ee --- /dev/null +++ b/nemo_reinforcer/models/interfaces.py @@ -0,0 +1,70 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Any, Dict + +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.algorithms.interfaces import LossFunction +from nemo_reinforcer.models.generation.interfaces import GenerationDatumSpec + + +class PolicyInterface(ABC): + """Abstract base class defining the interface for RL policies.""" + + @abstractmethod + def get_logprobs( + self, data: BatchedDataDict[GenerationDatumSpec] + ) -> BatchedDataDict: + """Get logprobs of actions from observations. + + Args: + data: BatchedDataDict containing rollouts (tokens) + + Returns: + BatchedDataDict containing: + - logprobs: Tensor of logprobs of actions + """ + pass + + @abstractmethod + def get_reference_policy_logprobs( + self, data: BatchedDataDict[GenerationDatumSpec] + ) -> BatchedDataDict: + """Get logprobs of actions from observations. + + Args: + data: BatchedDataDict containing rollouts (tokens) + + Returns: + BatchedDataDict containing: + - logprobs: Tensor of logprobs of actions + """ + pass + + @abstractmethod + def train(self, data: BatchedDataDict, loss_fn: LossFunction) -> Dict[str, Any]: + """Train the policy on a global batch of data. + + Args: + data: BatchedDataDict containing rollouts (tokens) + """ + pass + + @abstractmethod + def prepare_for_training(self, *args, **kwargs): + pass + + @abstractmethod + def finish_training(self, *args, **kwargs): + pass diff --git a/nemo_reinforcer/models/megatron/__init__.py b/nemo_reinforcer/models/megatron/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_reinforcer/models/megatron/common.py b/nemo_reinforcer/models/megatron/common.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/nemo_reinforcer/models/megatron/common.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo_reinforcer/models/policy/__init__.py b/nemo_reinforcer/models/policy/__init__.py new file mode 100644 index 0000000000..d28d7a91c0 --- /dev/null +++ b/nemo_reinforcer/models/policy/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TypedDict + +from nemo_reinforcer.models.generation.interfaces import GenerationConfig + + +class PolicyConfig(TypedDict): + model_name: str + train_global_batch_size: int + train_micro_batch_size: int + learning_rate: float + logprob_batch_size: int + generation: GenerationConfig diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py new file mode 100644 index 0000000000..627c221cee --- /dev/null +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -0,0 +1,1016 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gc +import warnings +import os +from collections import defaultdict +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Union + +import ray +import torch +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import ( + FullyShardedDataParallel, + FullStateDictConfig, + MixedPrecision, + StateDictType, +) +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy +from transformers import AutoModelForCausalLM, AutoTokenizer + +from nemo_reinforcer.algorithms.interfaces import LossFunction +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster +from nemo_reinforcer.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup +from nemo_reinforcer.models.generation.interfaces import ( + GenerationInterface, + GenerationDatumSpec, + GenerationOutputSpec, + verify_right_padding, +) +from nemo_reinforcer.models.interfaces import PolicyInterface +from nemo_reinforcer.models.policy import PolicyConfig +from nemo_reinforcer.models.policy.utils import import_class_from_path +from nemo_reinforcer.distributed.virtual_cluster import ( + PY_EXECUTABLES, +) + + +@ray.remote +class HfPolicyWorker: + DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.DEFAULT_VENV + + def __repr__(self): + """Customizes the actor's prefix in the Ray logs. + + This makes it easier to identify which worker is producing specific log messages. + """ + return f"{self.__class__.__name__}[rank={torch.distributed.get_rank()}]" + + def __init__( + self, + config: PolicyConfig, + weights_path: Optional[str] = None, + optimizer_path: Optional[str] = None, + init_optimizer: bool = True, + ): + self.cfg = config + # torch distributed init. Envars for rank, world_size, and master_addr and master_port are set from the ray remote call + torch.distributed.init_process_group(backend="nccl") + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + model_name = self.cfg["model_name"] + + print(f"[Rank {rank}] Loading model {model_name} on CPU...") + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="cpu", # load weights onto CPU initially + torch_dtype=torch.bfloat16, # use half precision to save memory + ) + self.reference_model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="cpu", # load weights onto CPU initially + torch_dtype=torch.bfloat16, # use half precision to save memory + ) + + self.tokenizer = tokenizer = AutoTokenizer.from_pretrained(model_name) + # If no pad token is defined, you might need: + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # ------------------------------------------------ + # 3) Move to GPU + Composable FSDP + # (Initialize device mesh, shard submodules, then shard entire model) + # ------------------------------------------------ + + def do_fsdp(model): + # Create a device mesh with 'world_size' GPUs in a 1D arrangement. + mesh = init_device_mesh("cuda", (world_size,)) + + # Mixed precision training + # https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision + param_dtype = torch.bfloat16 # use lower precision for model parameters + reduce_dtype = torch.float32 # use higher precision for gradient reduction + buffer_dtype = torch.float32 # use higher precision for optimizer states + + mp = MixedPrecision( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + buffer_dtype=buffer_dtype, + ) + + return FullyShardedDataParallel( + model, + device_mesh=mesh, + auto_wrap_policy=size_based_auto_wrap_policy, + mixed_precision=mp, + ) + + self.model.to("cuda") + self.model = do_fsdp(self.model) + self.model = self.move_to_cpu(self.model) + self.reference_model.to("cuda") + self.reference_model = do_fsdp(self.reference_model) + self.reference_model = self.move_to_cpu(self.reference_model) + self.model.to("cuda") + self._held_reference_model_params = None + # register_fsdp_forward_method(self.model, "generate") + if init_optimizer: + self.optimizer = torch.optim.AdamW( + self.model.parameters(), lr=self.cfg["learning_rate"] + ) + else: + self.optimizer = None + + # restore + if weights_path: + self.load_checkpoint(weights_path, optimizer_path) + else: + print( + "No weights path provided. Starting from scratch (default policy init)" + ) + + if "scheduler" in self.cfg: + if isinstance(self.cfg["scheduler"], dict): + scheduler_cls = import_class_from_path(self.cfg["scheduler"]["name"]) + self.scheduler = scheduler_cls( + self.optimizer, **self.cfg["scheduler"]["kwargs"] + ) + else: + schedulers = [] + for scheduler_cfg in self.cfg["scheduler"]: + if "name" in scheduler_cfg: + schedulers.append( + import_class_from_path(scheduler_cfg["name"])( + self.optimizer, **scheduler_cfg["kwargs"] + ) + ) + else: + assert "milestones" in scheduler_cfg, ( + "unknown scheduler config: ", + scheduler_cfg, + ) + milestones = scheduler_cfg["milestones"] + + self.scheduler = torch.optim.lr_scheduler.SequentialLR( + self.optimizer, schedulers, milestones + ) + + else: + ## default to a passthrough LR schedule + self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, lr_lambda=lambda epoch: 1 + ) + + def is_alive(self): + return True + + def get_gpu_info(self): + """Return information about the GPU being used by this worker.""" + import torch + + # Get distributed training info + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + + # Get device info from CUDA + device = torch.cuda.current_device() + device_name = torch.cuda.get_device_name(device) + device_count = torch.cuda.device_count() + memory_allocated = torch.cuda.memory_allocated(device) / (1024**2) # in MB + memory_reserved = torch.cuda.memory_reserved(device) / (1024**2) # in MB + + # Try to get the real global device ID (not the local one) + # In distributed training, each process only sees its assigned GPU as device 0 + local_device_id = device + global_device_id = local_device_id + + if "CUDA_VISIBLE_DEVICES" in os.environ: + cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + if local_rank < len(cuda_visible_devices): + global_device_id = int(cuda_visible_devices[local_rank]) + + # Get a parameter from the model to verify CUDA device placement + # This confirms tensors are actually on the appropriate device + param_info = {} + for module_name, module in self.model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + if param is not None and param.requires_grad: + full_name = f"{module_name}.{param_name}" + param_info[full_name] = { + "device": str(param.device), + "shape": list(param.shape), + "dtype": str(param.dtype), + } + # Just grab one parameter for verification + break + if param_info: + break + + return { + "rank": rank, + "world_size": world_size, + "local_rank": local_rank, + "local_device_id": local_device_id, + "global_device_id": global_device_id, + "device_count": device_count, + "device_name": device_name, + "memory_allocated_mb": memory_allocated, + "memory_reserved_mb": memory_reserved, + "parameter_sample": param_info, + "env_vars": { + k: v + for k, v in os.environ.items() + if k.startswith("CUDA") or k in ["LOCAL_RANK", "RANK", "WORLD_SIZE"] + }, + } + + def train(self, data: BatchedDataDict, loss_fn: LossFunction) -> Dict[str, Any]: + """Train the policy on a batch of data with a given loss function.""" + mbs = self.cfg["train_micro_batch_size"] + gbs = self.cfg["train_global_batch_size"] + local_gbs = gbs // torch.distributed.get_world_size() + dataset_size = data.get("input_ids").shape[0] + + # Ensure model is in training mode + self.model.train() + + # Get data from batch and move to device + data.to("cuda") + + losses = [] + all_mb_metrics = [] + for gb_start in range(0, dataset_size, local_gbs): + self.optimizer.zero_grad() + mb_losses = [] + for mb in data.slice( + gb_start, gb_start + local_gbs + ).make_microbatch_iterator(mbs): + input_ids = mb.get("input_ids") + + input_lengths = mb.get("input_lengths") + batch_size, seq_len = input_ids.shape + attention_mask = torch.ones( + (batch_size, seq_len), dtype=torch.long, device=input_ids.device + ) + for i, length in enumerate(input_lengths): + # For right-padded sequence, set 1s at the beginning of the sequence + attention_mask[i, :length] = 1 + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=False, + ) + # Get logprobs + if not hasattr(outputs, "logits"): + logits = self.model.lm_head(outputs.last_hidden_state) + else: + logits = outputs.logits + + loss, loss_metrics = loss_fn(logits, mb) + + # Backward pass + loss.backward() + mb_losses.append(loss.item()) + all_mb_metrics.append(loss_metrics) + + # Clip gradients + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + + # Update parameters + self.optimizer.step() + self.scheduler.step() + losses.append(torch.tensor(mb_losses).mean().item()) + + # Compute global loss across all ranks + with torch.no_grad(): + local_loss = torch.tensor(losses, device="cuda") + global_loss = torch.zeros_like(local_loss) + torch.distributed.all_reduce(local_loss) + global_loss = local_loss / torch.distributed.get_world_size() + + # Aggregate metrics across all microbatches + mb_metrics = defaultdict(list) + for m in all_mb_metrics: + for k, v in m.items(): + mb_metrics[k].append(v) + + metrics = { + "global_loss": global_loss.cpu(), + "local_loss": local_loss.cpu(), + "rank": torch.distributed.get_rank(), + "all_mb_metrics": dict(mb_metrics), + } + + return metrics + + def get_logprobs(self, data: BatchedDataDict) -> BatchedDataDict: + """Get the logprobs of the model for a batch of data. + + Uses the configured logprob_batch_size to do microbatching. + + Input data is assumed to be right-padded. The method internally converts to + left-padded format for computation, and returns outputs in right-padded format. + + Returns: + a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. + We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. + The logprob of input token i is specified at position i in the output logprobs tensor. + """ + logprob_batch_size = self.cfg["logprob_batch_size"] + all_log_probs = [] + self.model.eval() + + # Process in batches + with torch.no_grad(): + data.to("cuda") + for lp_batch in data.make_microbatch_iterator(logprob_batch_size): + input_ids = lp_batch.get("input_ids") + batch_size, seq_len = input_ids.shape + + # Create attention mask + input_lengths = lp_batch.get("input_lengths") + + # Create attention mask for right-padded data + attention_mask = torch.zeros( + (batch_size, seq_len), dtype=torch.long, device=input_ids.device + ) + for i, length in enumerate(input_lengths): + # For right-padded sequence, set 1s at the beginning of the sequence + attention_mask[i, :length] = 1 + + # Process with the model directly using right-padded inputs + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=False, + ) + log_probs = torch.nn.functional.log_softmax( + outputs.logits.to(torch.float32), dim=-1 + ) + + # Extract logprobs for each token in the sequence by gathering the logprob + # corresponding to the next token at each position + # Input shapes: + # log_probs: [batch_size, sequence_length, vocab_size] - logits for each position + # token_ids: [batch_size, sequence_length] - actual tokens + # Output shape: [batch_size, sequence_length] - logprob of each token given previous + # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length + token_ids = input_ids + next_tokens = token_ids[:, 1:] # Skip first token + log_probs = log_probs[:, :-1] # Remove last position's logits + token_logprobs = log_probs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) + + # Prepend 0 logprob for first token to maintain same sequence length as input + token_logprobs = torch.cat( + [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 + ) + + # Apply mask to zero out padding tokens logprobs + token_logprobs = token_logprobs * attention_mask + all_log_probs.append(token_logprobs) + + # Concatenate all batches + return_data = BatchedDataDict() + return_data["logprobs"] = torch.cat(all_log_probs, dim=0).cpu() + + return return_data + + @contextmanager + def use_reference_model(self): + """Context manager that temporarily swaps the reference model and active model. + + On entry: Moves model to CPU, moves reference_model to CUDA. Swaps the references + On exit: Restores original references and re-flips cuda/cpu + + """ + try: + # Save original references + original_model = self.model + original_reference_model = self.reference_model + + self.model = self.move_to_cpu(self.model) + self.reference_model = self.reference_model.to("cuda") + + # Swap the references + self.model, self.reference_model = self.reference_model, self.model + gc.collect() + torch.cuda.empty_cache() + + # - self.model is the original reference_model, now on CUDA + # - self.reference_model is the original model, now on CPU + yield + + finally: + # Restore original references and device placement + self.reference_model = self.move_to_cpu(original_reference_model) + self.model = original_model.to("cuda") + gc.collect() + torch.cuda.empty_cache() + + def get_reference_policy_logprobs(self, data: BatchedDataDict) -> BatchedDataDict: + """Get the logprobs from the reference policy for a batch of data. + + Returns: + a BatchedDataDict with key "reference_logprobs" and shape [batch_size, sequence_length]. + We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. + The logprob of input token i is specified at position i in the output logprobs tensor. + """ + with self.use_reference_model(): + reference_logprobs = self.get_logprobs(data) + + return_data = BatchedDataDict() + return_data["reference_logprobs"] = reference_logprobs["logprobs"].cpu() + return return_data + + def generate( + self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + ) -> BatchedDataDict[GenerationOutputSpec]: + """Generate a batch of data using huggingface framework generation. + + Args: + data: BatchedDataDict containing input_ids and input_lengths tensors + + Returns: + BatchedDataDict conforming to GenerationOutputSpec: + - output_ids: input + generated token IDs + - logprobs: Log probabilities for each token + - generation_lengths: Lengths of each response + """ + # Verify input is right padded + assert isinstance(data, BatchedDataDict), ( + f"data must be a BatchedDataDict, got type: {type(data)}" + ) + assert "input_ids" in data and "input_lengths" in data, ( + f"input_ids and input_lengths must be present in the BatchedDataDict, got keys: {data.keys()}" + ) + is_right_padded, error_msg = verify_right_padding( + data, pad_value=self.tokenizer.pad_token_id + ) + if not is_right_padded: + warnings.warn( + f"Input to vLLM worker is not properly right-padded: {error_msg}" + ) + + self.model.eval() + + # Right padded tokens are converted to left padded tokens for HF generate (https://huggingface.co/docs/transformers/main/en/llm_tutorial?padding=right+pad#padding-side) + with torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params( + self.model, recurse=False + ): + # Get generation config from self.cfg + generation_batch_size = self.cfg["generation_batch_size"] + gen_cfg = self.cfg["generation"] + + micro_batches = [] + + # Process in batches + max_length = 0 + for gen_batch in data.make_microbatch_iterator(generation_batch_size): + # Create attention mask from input_lengths if needed for the model + input_ids = gen_batch.get("input_ids").cuda() + input_lengths = gen_batch.get("input_lengths").cuda() + batch_size, seq_len = input_ids.shape + + # Convert right padding to left padding + left_padded_input_ids = torch.zeros_like(input_ids) + left_padded_attention_mask = torch.zeros( + (batch_size, seq_len), dtype=torch.long, device=input_ids.device + ) + + for i, length in enumerate(input_lengths): + # Move tokens to the end of the sequence (left padding) + left_padded_input_ids[i, seq_len - length :] = input_ids[i, :length] + # Set attention mask for the actual tokens (at the end for left padding) + left_padded_attention_mask[i, seq_len - length :] = 1 + + outputs = self.model.module.generate( + input_ids=left_padded_input_ids, + attention_mask=left_padded_attention_mask, + max_new_tokens=gen_cfg["max_new_tokens"], + do_sample=not greedy, + temperature=gen_cfg["temperature"], + top_p=gen_cfg["top_p"], + top_k=gen_cfg["top_k"], + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + return_dict_in_generate=True, + output_scores=True, + synced_gpus=True, + ) + # Get the generated sequences + max_length = max(max_length, outputs.sequences.size(1)) + + # Convert scores to log probabilities and extract the logprob of the chosen token + scores = torch.stack( + outputs.scores, dim=1 + ) # [batch_size, seq_len, vocab_size] + logprobs = torch.nn.functional.log_softmax(scores, dim=-1) + + # Get the logprobs of the actually generated tokens + # outputs.sequences[:, -scores.size(1):] gives us just the newly generated tokens + generated_tokens = outputs.sequences[:, -scores.size(1) :] + token_logprobs = logprobs.gather( + dim=-1, index=generated_tokens.unsqueeze(-1) + ).squeeze(-1) + + # Prepend zeros for input tokens based on original input lengths, not the padded length + mb = {} + mb["orig_input_lengths"] = input_lengths.clone() + mb["generation_logprobs"] = token_logprobs + mb["left_padded_output_ids"] = outputs.sequences + + micro_batches.append(mb) + + # Get lengths, pad, and concatenate all batches + return_data = BatchedDataDict.from_batches(micro_batches) + + # Calculate the lengths of generations for each sequence by finding stop tokens + generation_lengths = [] + unpadded_sequence_lengths = [] + input_length = data.get("input_ids").size(1) + + # Convert left-padded outputs back to right-padded format + batch_size = len(return_data["left_padded_output_ids"]) + max_seq_len = max( + [seq.size(0) for seq in return_data["left_padded_output_ids"]] + ) + right_padded_output_ids = torch.zeros( + (batch_size, max_seq_len), + dtype=return_data["left_padded_output_ids"][0].dtype, + device=return_data["left_padded_output_ids"][0].device, + ) + + for idx, seq in enumerate(return_data["left_padded_output_ids"]): + # Get only the generated part (excluding input) + original_length = return_data["orig_input_lengths"][idx].item() + seq_len = seq.size(0) + + # The generated content starts after the left-padded input + generated_part = seq[-(seq_len - input_length) :] + + eos_positions = (generated_part == self.tokenizer.eos_token_id).nonzero( + as_tuple=True + )[0] + # TODO @sahilj: handle different stopping criteria + # Calculate generation length + if len(eos_positions) > 0: + gen_length = ( + eos_positions[0].item() + 1 + ) # +1 to include the EOS token + else: + gen_length = len(generated_part) + + generation_lengths.append(gen_length) + + valid_length = original_length + gen_length + unpadded_sequence_lengths.append(valid_length) + + # Extract the original input tokens from the left-padded sequence + # For left-padded sequences, tokens are at the end of the input section + valid_input_part = ( + seq[input_length - original_length : input_length] + if original_length > 0 + else torch.tensor([], device=seq.device, dtype=seq.dtype) + ) + + # Combine with generated part + valid_generated_part = generated_part[:gen_length] + valid_tokens = torch.cat([valid_input_part, valid_generated_part]) + + # Place at the beginning of the right-padded sequence + right_padded_output_ids[idx, :valid_length] = valid_tokens + + # Store the right-padded outputs + return_data["output_ids"] = right_padded_output_ids + + # Align generation_logprobs with right-padded output format + batch_size = len(return_data["generation_logprobs"]) + right_padded_logprobs = torch.zeros( + (batch_size, max_seq_len), + dtype=return_data["generation_logprobs"][0].dtype, + device=return_data["generation_logprobs"][0].device, + ) + + for idx, logprob_seq in enumerate(return_data["generation_logprobs"]): + original_length = return_data["orig_input_lengths"][idx].item() + gen_length = generation_lengths[idx] + + # For right-padded format, we need: + # 1. Zeros for the original input tokens (at the beginning) + # 2. Actual logprobs for generated tokens (after the zeros) + # 3. Zeros padding at the end (if needed) + + right_padded_seq = torch.zeros( + max_seq_len, dtype=logprob_seq.dtype, device=logprob_seq.device + ) + right_padded_seq[original_length : original_length + gen_length] = ( + logprob_seq[:gen_length] + ) + right_padded_logprobs[idx] = right_padded_seq + valid_length = original_length + gen_length + + # Remove the temporary data we added + if "generation_logprobs" in return_data: + del return_data["generation_logprobs"] + if "orig_input_lengths" in return_data: + del return_data["orig_input_lengths"] + if "left_padded_output_ids" in return_data: + del return_data["left_padded_output_ids"] + + # Ensure consistent data types and device placement + return_data["output_ids"] = right_padded_output_ids + return_data["logprobs"] = right_padded_logprobs + return_data["generation_lengths"] = torch.tensor( + generation_lengths, dtype=torch.long + ) + return_data["unpadded_sequence_lengths"] = torch.tensor( + unpadded_sequence_lengths, dtype=torch.long + ) + + # Move everything to CPU before returning + return_data.to("cpu") + + return return_data + + def zero_out_weights(self): + """Zero out the weights of the model.""" + # TODO @sahilj: do this without a summon (maybe FSDP2) + with torch.distributed.fsdp.FullyShardedDataParallel.summon_full_params( + self.model, recurse=True + ): + for p in self.model.parameters(): + p.data.zero_() + torch.cuda.synchronize() + + def report_device_id(self) -> str: + from vllm.platforms import current_platform + + self.device_uuid = current_platform.get_device_uuid(torch.cuda.current_device()) + return self.device_uuid + + def get_weight_ipc_handles(self): + from torch.multiprocessing.reductions import reduce_tensor + + # TODO @sahilj: do this without an allgather (maybe FSDP2) + params = self.model.state_dict() + self._held_reference_model_params = params + data = {} + self.device_uuid = self.report_device_id() + for name, p in params.items(): + data[name] = reduce_tensor(p.detach()) + return {self.device_uuid: data} + + def prepare_for_lp_inference(self): + self.model.to("cuda") + self.model.eval() + self.offload_before_refit() + + def prepare_for_training(self, *args, **kwargs): + # onload models and optimizer state to cuda + self.model.to("cuda") + self.model.train() + + # Move optimizer state to CUDA if it exists + if hasattr(self, "optimizer") and self.optimizer is not None: + for state in self.optimizer.state.values(): + for k, v in state.items(): + if torch.is_tensor(v) and not v.is_cuda: + state[k] = v.to("cuda") + + torch.cuda.empty_cache() + + def offload_before_refit(self): + """Offload the optimizer and buffers to the CPU.""" + if hasattr(self, "optimizer") and self.optimizer is not None: + for state in self.optimizer.state.values(): + for k, v in state.items(): + if torch.is_tensor(v): + state[k] = v.to("cpu") + gc.collect() + torch.cuda.empty_cache() + + # Print memory stats after offloading + allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB + reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB + print( + f"GPU Memory after optimizer offload: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved" + ) + + def offload_after_refit(self): + # Offload as much as possible on the CPU + self.model = self.move_to_cpu(self.model) + self.model.eval() + self.offload_before_refit() # rerun the old offload function + + if self._held_reference_model_params is not None: + del self._held_reference_model_params + self._held_reference_model_params = None + + gc.collect() + torch.cuda.empty_cache() + + allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB + reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB + print( + f"GPU Memory after refit complete: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved" + ) + + def move_to_cpu(self, model): + for param in model.parameters(): + param.data = param.data.to("cpu") + + for buffer in model.buffers(): + buffer.data = buffer.data.to("cpu") + + if hasattr(model, "_fsdp_wrapped_module"): + model._fsdp_wrapped_module.to("cpu") + + return model + + def save_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None): + # Config to save full state dict on rank 0, offloaded to CPU + state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + + with FullyShardedDataParallel.state_dict_type( + self.model, + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=state_dict_config, + ): + # Save model state dict + model_state_dict = self.model.state_dict() + optim_state_dict = FullyShardedDataParallel.optim_state_dict( + self.model, self.optimizer + ) + + if torch.distributed.get_rank() == 0: + # check if weights_path dir exists + weights_dir = os.path.dirname(weights_path) + if not os.path.exists(weights_dir): + print( + f"Creating weights directory {weights_dir} DOESN'T EXIST SOMEHOW" + ) + os.makedirs(weights_dir) + torch.save(model_state_dict, weights_path) + if optimizer_path is not None: + torch.save(optim_state_dict, optimizer_path) + + def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None): + print(f"Loading Policy from {weights_path} and optimizer from {optimizer_path}") + state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + + state_dict = torch.load(weights_path) + if optimizer_path is not None: + optimizer_state_dict = torch.load(optimizer_path) + else: + optimizer_state_dict = None + + with FullyShardedDataParallel.state_dict_type( + self.model, + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=state_dict_config, + ): + # Load model weights + self.model.load_state_dict(state_dict if state_dict else None) + + # Load optimizer state + if optimizer_state_dict is not None: + optim_state_dict = FullyShardedDataParallel.shard_full_optim_state_dict( + optimizer_state_dict, self.model + ) + if self.optimizer is not None: + self.optimizer.load_state_dict(optim_state_dict) + else: + print("WARNING: initializing without optimizer") + else: + print("WARNING: No optimizer checkpoint provided") + + +class HfPolicy(PolicyInterface, GenerationInterface): + def __init__( + self, + cluster: RayVirtualCluster, + config: PolicyConfig, + name_prefix: str = "hf_policy", + workers_per_node: Optional[Union[int, List[int]]] = None, + init_optimizer: bool = True, + weights_path: Optional[str] = None, + optimizer_path: Optional[str] = None, + ): + worker_builder = RayWorkerBuilder( + HfPolicyWorker, + config, + init_optimizer=init_optimizer, + weights_path=weights_path, + optimizer_path=optimizer_path, + ) + self.worker_group = RayWorkerGroup( + cluster, + worker_builder, + name_prefix=name_prefix, + workers_per_node=workers_per_node, + ) + self.dp_size = self.worker_group.world_size + self.cfg = config + + def get_logprobs( + self, data: BatchedDataDict[GenerationDatumSpec] + ) -> BatchedDataDict: + """Get the logprobs of the model for a data dict. + + Returns: + a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. + We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. + The logprob of input token i is specified at position i in the output logprobs tensor. + """ + sharded_data = data.shard_by_batch_size(self.dp_size, batch_size=None) + futures = self.worker_group.run_all_workers_multiple_data( + "get_logprobs", sharded_data + ) + logprobs = BatchedDataDict.from_batches( + self.worker_group.get_all_worker_results(futures) + ) + return logprobs + + def get_reference_policy_logprobs( + self, data: BatchedDataDict[GenerationDatumSpec] + ) -> BatchedDataDict: + """Get the logprobs of the reference policy for a data dict. + + Returns: Identical to get_logprobs. + """ + sharded_data = data.shard_by_batch_size(self.dp_size, batch_size=None) + futures = self.worker_group.run_all_workers_multiple_data( + "get_reference_policy_logprobs", sharded_data + ) + logprobs = BatchedDataDict.from_batches( + self.worker_group.get_all_worker_results(futures) + ) + return logprobs + + def train(self, data: BatchedDataDict, loss_fn: LossFunction): + """Train the policy on a batch of data with a given loss function.""" + # Shard and replicate the batch + shards = self.dp_size + sharded_data = data.shard_by_batch_size( + shards, batch_size=self.cfg["train_global_batch_size"] + ) + + # Train each shard in parallel + futures = self.worker_group.run_all_workers_multiple_data( + "train", sharded_data, common_kwargs={"loss_fn": loss_fn} + ) + results = self.worker_group.get_all_worker_results(futures) + + # Aggregate the results + aggregated_results = {} + aggregated_results["loss"] = results[0]["global_loss"] + + # Aggregate metrics across all workers + all_mb_metrics = defaultdict(list) + for r in results: + for k, v in r["all_mb_metrics"].items(): + all_mb_metrics[k].extend(v) + aggregated_results["all_mb_metrics"] = dict(all_mb_metrics) + + return aggregated_results + + def generate( + self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + ) -> BatchedDataDict[GenerationOutputSpec]: + """Generate a batch of data using the policy.""" + # Verify input data is right-padded + assert isinstance(data, BatchedDataDict), ( + f"data must be a BatchedDataDict, got type: {type(data)}" + ) + assert "input_ids" in data and "input_lengths" in data, ( + "Missing required input fields" + ) + + sharded_data = data.shard_by_batch_size(self.dp_size, batch_size=None) + futures = self.worker_group.run_all_workers_multiple_data( + "generate", sharded_data, common_kwargs={"greedy": greedy} + ) + result = BatchedDataDict.from_batches( + self.worker_group.get_all_worker_results(futures) + ) + + # Verify the output has all required fields + required_keys = [ + "output_ids", + "generation_lengths", + "unpadded_sequence_lengths", + "logprobs", + ] + missing_keys = [key for key in required_keys if key not in result] + if missing_keys: + raise ValueError( + f"Missing required keys for GenerationOutputSpec: {missing_keys}" + ) + + return result + + def prepare_for_generation(self, *args, **kwargs): + # We don't need to do anything here + pass + + def prepare_for_training(self, *args, **kwargs): + # onload everything to the GPU + futures = self.worker_group.run_all_workers_single_data( + "prepare_for_training", respect_tied_workers=True + ) + ray.get(futures) + pass + + def prepare_for_lp_inference(self, *args, **kwargs): + futures = self.worker_group.run_all_workers_single_data( + "prepare_for_lp_inference", respect_tied_workers=True + ) + ray.get(futures) + + def finish_generation(self, *args, **kwargs): + # We don't need to do anything here + pass + + def finish_training(self, *args, **kwargs): + # Placeholder implementation + pass + + def get_weights_ipc_handles(self): + """Fetch weight IPC handles from all workers. + + Returns: + dict: A dictionary mapping device UUIDs to parameter IPC handles. + """ + # Collect IPC handles from all workers + worker_handles = ray.get( + [ + worker.get_weight_ipc_handles.remote() + for worker in self.worker_group.workers + ] + ) + + # Combine all worker handles into a single dictionary + all_handles = {} + for handle in worker_handles: + all_handles.update(handle) + + return all_handles + + def offload_before_refit(self): + """Offload the optimizer and buffers to the CPU.""" + futures = self.worker_group.run_all_workers_single_data( + "offload_before_refit", respect_tied_workers=True + ) + ray.get(futures) + + def offload_after_refit(self): + """Offload the optimizer and buffers to the CPU.""" + futures = self.worker_group.run_all_workers_single_data( + "offload_after_refit", respect_tied_workers=True + ) + ray.get(futures) + + def save_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None): + """Save a checkpoint of the model.""" + futures = self.worker_group.run_all_workers_single_data( + "save_checkpoint", + weights_path, + optimizer_path, + respect_tied_workers=True, + ) + ray.get(futures) + + def shutdown(self) -> bool: + """Shut down all HF workers and clean up resources.""" + try: + # Use the worker group's shutdown method with the worker's cleanup method + return self.worker_group.shutdown(cleanup_method="shutdown") + except Exception as e: + print(f"Error during policy shutdown: {e}") + return False + + def __del__(self): + """Shuts down the worker groups when the object is deleted or is garbage collected. + + This is an extra safety net in case the user forgets to call worker_group.shutdown() and the pointer to + the object is lost due to leaving a function scope. It's always recommended that the + user calls worker_group.shutdown(). + """ + self.worker_group.shutdown() diff --git a/nemo_reinforcer/models/policy/utils.py b/nemo_reinforcer/models/policy/utils.py new file mode 100644 index 0000000000..3035de5360 --- /dev/null +++ b/nemo_reinforcer/models/policy/utils.py @@ -0,0 +1,29 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib + + +def import_class_from_path(name): + """Import a class from a string path (e.g. 'torch.optim.AdamW'). + + Args: + full_path: Full path to class including module path and class name + + Returns: + The imported class object + """ + module_name, cls_name = name.rsplit(".", 1) + cls_instance = getattr(importlib.import_module(module_name), cls_name) + return cls_instance diff --git a/nemo_reinforcer/utils/__init__.py b/nemo_reinforcer/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_reinforcer/utils/checkpoint.py b/nemo_reinforcer/utils/checkpoint.py new file mode 100644 index 0000000000..80344b3cda --- /dev/null +++ b/nemo_reinforcer/utils/checkpoint.py @@ -0,0 +1,273 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Checkpoint management utilities for the rl algorithm loop. + +It handles logic at the algorithm level. Each RL Actor is expected to have its +own checkpoint saving function (called by the algorithm loop). +""" + +import os +import json +import glob +from typing import Dict, Any, Optional, List, Tuple, TypedDict +import shutil +from pathlib import Path +import torch +import numpy as np + + +class CheckpointingConfig(TypedDict): + """Configuration for checkpoint management. + + Attributes: + enabled (bool): Whether checkpointing is enabled. + checkpoint_dir (os.PathLike): Directory where checkpoints will be saved. + metric_name (str): Name of the metric to use for determining best checkpoints. + higher_is_better (bool): Whether higher values of the metric indicate better performance. + keep_top_k (Optional[int]): Number of best checkpoints to keep. If None, all checkpoints are kept. + """ + + enabled: bool + checkpoint_dir: os.PathLike + metric_name: str + higher_is_better: bool + keep_top_k: Optional[int] + + +class CheckpointManager: + """Manages model checkpoints during training. + + This class handles creating checkpoint dirs, saving training info, and + configurations. It also provides utilities for keeping just the top-k checkpoints. + The checkpointing structure looks like this: + ``` + checkpoint_dir/ + step_0/ + training_info.json + config.json + policy.py (up to the algorithm loop to save here) + policy_optimizer.py (up to the algorithm loop to save here) + ... + step_1/ + ... + ``` + + Attributes: Derived from the CheckpointingConfig. + """ + + def __init__(self, config: CheckpointingConfig): + """Initialize the checkpoint manager. + + Args: + config (CheckpointingConfig) + """ + self.checkpoint_dir = Path(config["checkpoint_dir"]) + self.metric_name = config["metric_name"] + self.higher_is_better = config["higher_is_better"] + self.keep_top_k = config["keep_top_k"] + + def init_tmp_checkpoint( + self, + step: int, + training_info: Dict[str, Any], + run_config: Optional[Dict[str, Any]] = None, + ) -> os.PathLike: + """Initialize a temporary checkpoint directory. + + Creates a temporary directory for a new checkpoint and saves training info + and configuration. The directory is named 'tmp_step_{step}' and will be renamed + to 'step_{step}' when the checkpoint is completed. + We do it this way to allow the algorithm loop to save any files it wants to save + in a safe, temporary directory. + + Args: + step (int): The training step number. + training_info (Dict[str, Any]): Dictionary containing training metrics and info. + run_config (Optional[Dict[str, Any]]): Optional configuration for the training run. + + Returns: + os.PathLike: Path to the temporary checkpoint directory. + """ + # create new step_{step} directory + save_dir = self.checkpoint_dir / f"tmp_step_{step}" + save_dir.mkdir(parents=True, exist_ok=True) + + # save training info + with open(save_dir / "training_info.json", "w") as f: + # make any numpy items serializable + for k, v in training_info.items(): + if isinstance(v, torch.Tensor) or isinstance(v, np.ndarray): + training_info[k] = v.item() + json.dump(training_info, f) + + # save config + if run_config is not None: + with open(save_dir / "config.json", "w") as f: + json.dump(run_config, f) + + return save_dir + + def finalize_checkpoint(self, checkpoint_path: os.PathLike) -> None: + """Complete a checkpoint by moving it from temporary to permanent location. + + If a checkpoint at the target location already exists (i.e when resuming training), + we override the old one. + Also triggers cleanup of old checkpoints based on the keep_top_k setting. + + Args: + checkpoint_path (os.PathLike): Path to the temporary checkpoint directory. + """ + # rename tmp_step_{step} to step_{step} + checkpoint_path = Path(checkpoint_path) + to_checkpoint_path = ( + checkpoint_path.parent / f"step_{checkpoint_path.name.split('_')[2]}" + ) + if to_checkpoint_path.exists(): + # if step_{step} exists, rename it to old_step_{step}, move tmp_step_{step} to step_{step}, then delete + # we do this trickery to have a 'pseudo-atomic' checkpoint save + old_checkpoint_path = ( + checkpoint_path.parent + / f"old_step_{checkpoint_path.name.split('_')[2]}" + ) + os.rename(to_checkpoint_path, old_checkpoint_path) + os.rename(checkpoint_path, to_checkpoint_path) + # delete old_step_{step} + if old_checkpoint_path.exists(): + shutil.rmtree(old_checkpoint_path) + else: + os.rename(checkpoint_path, to_checkpoint_path) + self.remove_old_checkpoints() + + def remove_old_checkpoints(self, exclude_latest: bool = True) -> None: + """Remove checkpoints that are not in the top-k or latest based on the metric. + + If keep_top_k is set, this method removes all checkpoints except the top-k + best ones based on the specified metric. The best checkpoints are determined + by the metric value and the higher_is_better setting. When multiple checkpoints + have the same metric value, more recent checkpoints (higher step numbers) are preferred. + + Args: + exclude_latest (bool): Whether to exclude the latest checkpoint from deletion. (may result in K+1 checkpoints) + """ + if self.keep_top_k is None: + return + checkpoint_history = _load_checkpoint_history(self.checkpoint_dir) + latest_step = ( + max([step for step, _, _ in checkpoint_history]) + if checkpoint_history + else None + ) + # sort by metric value first, then by step number (for equal metrics, prefer more recent) + if self.higher_is_better: + # For higher_is_better=True: higher metric values first, then higher step numbers + checkpoint_history.sort( + key=lambda x: (x[2][self.metric_name], x[0]), reverse=True + ) + else: + # For higher_is_better=False: lower metric values first, then higher step numbers for equal values + checkpoint_history.sort(key=lambda x: (x[2][self.metric_name], -x[0])) + + # remove checkpoints that are not in the top-k + for checkpoint in checkpoint_history[self.keep_top_k :]: + if exclude_latest and checkpoint[0] == latest_step: + continue + print( + f"Removing checkpoint {checkpoint[1]} due to being outside top-{self.keep_top_k}, metric: {checkpoint[2][self.metric_name]}" + ) + shutil.rmtree(checkpoint[1]) + + def get_best_checkpoint_path(self) -> Optional[str]: + """Get the path to the best checkpoint based on the metric. + + Returns the path to the checkpoint with the best metric value. If no checkpoints + exist, returns None. If the metric isn't found, we warn and return the latest checkpoint. + + Returns: + Optional[str]: Path to the best checkpoint, or None if no valid checkpoints exist. + """ + checkpoint_history = _load_checkpoint_history(self.checkpoint_dir) + if len(checkpoint_history) == 0: + return None + # sort by metric value + if self.metric_name not in checkpoint_history[0][2]: + print( + f"WARNING:Metric {self.metric_name} not found in checkpoint history. Returning last" + ) + return self.get_latest_checkpoint_path() + + checkpoint_history.sort( + key=lambda x: x[2][self.metric_name], reverse=self.higher_is_better + ) + return str(checkpoint_history[0][1]) + + def get_latest_checkpoint_path(self) -> str: + """Get the path to the latest checkpoint. + + Returns the path to the checkpoint with the highest step number. + + Returns: + str: Path to the latest checkpoint, or None if no checkpoints exist. + """ + # find checkpoint directory with highest step number + step_dirs = glob.glob(str(self.checkpoint_dir / "step_*")) + step_dirs.sort(key=lambda x: int(Path(x).name.split("_")[1])) + if len(step_dirs) == 0: + return None + return str(step_dirs[-1]) + + def load_training_info( + self, checkpoint_path: Optional[os.PathLike] = None + ) -> Dict[str, Any]: + """Load the training info from a checkpoint. + + Args: + checkpoint_path (Optional[os.PathLike]): Path to the checkpoint. If None, + returns None. + + Returns: + Dict[str, Any]: Dictionary containing the training info, or None if + checkpoint_path is None. + """ + if checkpoint_path is None: + return None + with open(Path(checkpoint_path) / "training_info.json", "r") as f: + return json.load(f) + + +def _load_checkpoint_history( + checkpoint_dir: Path, +) -> List[Tuple[int, os.PathLike, Dict[str, Any]]]: + """Load the history of checkpoints and their metrics. + + Args: + checkpoint_dir (Path): Directory containing the checkpoints. + + Returns: + List[Tuple[int, os.PathLike, Dict[str, Any]]]: List of tuples containing + (step_number, checkpoint_path, info) for each checkpoint. + """ + checkpoint_history: List[Tuple[int, os.PathLike, Dict[str, Any]]] = [] + + # Find all step directories + step_dirs = glob.glob(str(checkpoint_dir / "step_*")) + + for step_dir in step_dirs: + info_file = Path(step_dir) / "training_info.json" + if info_file.exists(): + with open(info_file) as f: + info = json.load(f) + step = int(Path(step_dir).name.split("_")[1]) + checkpoint_history.append((step, step_dir, info)) + + return checkpoint_history diff --git a/nemo_reinforcer/utils/config.py b/nemo_reinforcer/utils/config.py new file mode 100644 index 0000000000..51418e3b1b --- /dev/null +++ b/nemo_reinforcer/utils/config.py @@ -0,0 +1,132 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Optional, Union + +from omegaconf import DictConfig, ListConfig, OmegaConf + + +def resolve_path(base_path: Path, path: str) -> Path: + """Resolve a path relative to the base path.""" + if path.startswith("/"): + return Path(path) + return base_path / path + + +def load_config_with_inheritance( + config_path: Union[str, Path], + base_dir: Optional[Union[str, Path]] = None, +) -> DictConfig: + """Load a config file with inheritance support. + + Args: + config_path: Path to the config file + base_dir: Base directory for resolving relative paths. If None, uses config_path's directory + + Returns: + Merged config dictionary + """ + config_path = Path(config_path) + if base_dir is None: + base_dir = config_path.parent + base_dir = Path(base_dir) + + config = OmegaConf.load(config_path) + + # Handle inheritance + if "defaults" in config: + defaults = config.pop("defaults") + if isinstance(defaults, (str, Path)): + defaults = [defaults] + elif isinstance(defaults, ListConfig): + defaults = [str(d) for d in defaults] + + # Load and merge all parent configs + base_config = OmegaConf.create({}) + for default in defaults: + parent_path = resolve_path(base_dir, default) + parent_config = load_config_with_inheritance(parent_path, base_dir) + base_config = OmegaConf.merge(base_config, parent_config) + + # Merge with current config + config = OmegaConf.merge(base_config, config) + + return config + + +def load_config(config_path: Union[str, Path]) -> DictConfig: + """Load a config file with inheritance support and convert it to an OmegaConf object. + + The config inheritance system supports: + + 1. Single inheritance: + ```yaml + # child.yaml + defaults: parent.yaml + common: + value: 43 + ``` + + 2. Multiple inheritance: + ```yaml + # child.yaml + defaults: + - parent1.yaml + - parent2.yaml + common: + value: 44 + ``` + + 3. Nested inheritance: + ```yaml + # parent.yaml + defaults: grandparent.yaml + common: + value: 43 + + # child.yaml + defaults: parent.yaml + common: + value: 44 + ``` + + 4. Variable interpolation: + ```yaml + # parent.yaml + base_value: 42 + derived: + value: ${base_value} + + # child.yaml + defaults: parent.yaml + base_value: 43 # This will update both base_value and derived.value + ``` + + The system handles: + - Relative and absolute paths + - Multiple inheritance + - Nested inheritance + - Variable interpolation + + The inheritance is resolved depth-first, with later configs overriding earlier ones. + This means in multiple inheritance, the last config in the list takes precedence. + + Args: + config_path: Path to the config file + + Returns: + Merged config dictionary + """ + return load_config_with_inheritance(config_path) diff --git a/nemo_reinforcer/utils/logger.py b/nemo_reinforcer/utils/logger.py new file mode 100644 index 0000000000..2286075cc9 --- /dev/null +++ b/nemo_reinforcer/utils/logger.py @@ -0,0 +1,467 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import re +import glob +from abc import ABC, abstractmethod +import logging +from typing import List, Any, Dict, Optional, TypedDict +import wandb +from rich.console import Console +from rich.panel import Panel +from rich.box import ROUNDED +from rich.logging import RichHandler + +from nemo_reinforcer.data.interfaces import LLMMessageLogType +from torch.utils.tensorboard import SummaryWriter + +# Flag to track if rich logging has been configured +_rich_logging_configured = False + + +class WandbConfig(TypedDict): + project: str + name: str + + +class TensorboardConfig(TypedDict): + log_dir: str + + +class LoggerConfig(TypedDict): + log_dir: str + wandb_enabled: bool + tensorboard_enabled: bool + wandb: WandbConfig + tensorboard: TensorboardConfig + + +class LoggerInterface(ABC): + """Abstract base class for logger backends.""" + + @abstractmethod + def log_metrics( + self, metrics: Dict[str, Any], step: int, prefix: Optional[str] = "" + ) -> None: + """Log a dictionary of metrics.""" + pass + + @abstractmethod + def log_hyperparams(self, params: Dict[str, Any]) -> None: + """Log dictionary of hyperparameters.""" + pass + + +class TensorboardLogger(LoggerInterface): + """Tensorboard logger backend.""" + + def __init__(self, cfg: TensorboardConfig, log_dir: Optional[str] = None): + self.writer = SummaryWriter(log_dir=log_dir) + print(f"Initialized TensorboardLogger at {log_dir}") + + def log_metrics( + self, metrics: Dict[str, Any], step: int, prefix: Optional[str] = "" + ) -> None: + """Log metrics to Tensorboard. + + Args: + metrics: Dict of metrics to log + step: Global step value + prefix: Optional prefix for metric names + """ + for name, value in metrics.items(): + if prefix: + name = f"{prefix}/{name}" + self.writer.add_scalar(name, value, step) + + def log_hyperparams(self, params: Dict[str, Any]) -> None: + """Log hyperparameters to Tensorboard. + + Args: + params: Dictionary of hyperparameters to log + """ + # Flatten the params because add_hparams does not support nested dicts + self.writer.add_hparams(flatten_dict(params), {}) + + +class WandbLogger(LoggerInterface): + """Weights & Biases logger backend.""" + + def __init__(self, cfg: WandbConfig, log_dir: Optional[str] = None): + self.run = wandb.init(**cfg, dir=log_dir) + print( + f"Initialized WandbLogger for project {cfg.get('project')}, run {cfg.get('name')} at {log_dir}" + ) + + def log_metrics( + self, metrics: Dict[str, Any], step: int, prefix: Optional[str] = "" + ) -> None: + """Log metrics to wandb. + + Args: + metrics: Dict of metrics to log + step: Global step value + prefix: Optional prefix for metric names + """ + if prefix: + metrics = {f"{prefix}/{k}": v for k, v in metrics.items()} + + self.run.log(metrics, step=step) + + def log_hyperparams(self, params: Dict[str, Any]) -> None: + """Log hyperparameters to wandb. + + Args: + params: Dict of hyperparameters to log + """ + self.run.config.update(params) + + +class Logger(LoggerInterface): + """Main logger class that delegates to multiple backend loggers.""" + + def __init__(self, cfg: LoggerConfig): + """Initialize the logger. + + Args: + cfg: Config dict with the following keys: + - wandb_enabled + - tensorboard_enabled + - wandb + - tensorboard + """ + self.loggers = [] + + self.base_log_dir = cfg["log_dir"] + os.makedirs(self.base_log_dir, exist_ok=True) + + if cfg["wandb_enabled"]: + wandb_log_dir = os.path.join(self.base_log_dir, "wandb") + os.makedirs(wandb_log_dir, exist_ok=True) + wandb_logger = WandbLogger(cfg["wandb"], log_dir=wandb_log_dir) + self.loggers.append(wandb_logger) + + if cfg["tensorboard_enabled"]: + tensorboard_log_dir = os.path.join(self.base_log_dir, "tensorboard") + os.makedirs(tensorboard_log_dir, exist_ok=True) + tensorboard_logger = TensorboardLogger( + cfg["tensorboard"], log_dir=tensorboard_log_dir + ) + self.loggers.append(tensorboard_logger) + + if not self.loggers: + print("No loggers initialized") + + def log_metrics( + self, metrics: Dict[str, Any], step: int, prefix: Optional[str] = "" + ) -> None: + """Log metrics to all enabled backends. + + Args: + metrics: Dict of metrics to log + step: Global step value + prefix: Optional prefix for metric names + """ + for logger in self.loggers: + logger.log_metrics(metrics, step, prefix) + + def log_hyperparams(self, params: Dict[str, Any]) -> None: + """Log hyperparameters to all enabled backends. + + Args: + params: Dict of hyperparameters to log + """ + for logger in self.loggers: + logger.log_hyperparams(params) + + +def flatten_dict(d: Dict[str, Any], sep: str = ".") -> Dict[str, Any]: + """Flatten a nested dictionary.""" + result = {} + + def _flatten(d, parent_key=""): + for key, value in d.items(): + new_key = f"{parent_key}{sep}{key}" if parent_key else key + + if isinstance(value, dict): + _flatten(value, new_key) + else: + result[new_key] = value + + _flatten(d) + return result + + +""" +Rich Console Logging Functionality +--------------------------------- +Functions for setting up rich console logging and visualizing model outputs. +""" + + +def configure_rich_logging( + level: str = "INFO", show_time: bool = True, show_path: bool = True +) -> None: + """Configure rich logging for more visually appealing log output. + + Args: + level: The logging level to use + show_time: Whether to show timestamps in logs + show_path: Whether to show file paths in logs + """ + global _rich_logging_configured + + # Only configure if not already done + if not _rich_logging_configured: + # Configure logging with rich handler + logging.basicConfig( + level=level.upper(), + format="%(message)s", + datefmt="[%X]", + handlers=[ + RichHandler( + rich_tracebacks=True, + show_time=show_time, + show_path=show_path, + markup=True, + ) + ], + ) + _rich_logging_configured = True + + +def print_message_log_samples( + message_logs: List[LLMMessageLogType], + rewards: List[float], + num_samples: int = 5, + step: int = 0, +) -> None: + """Visualization for message logs and rewards using a more visual approach with emoji indicators and horizontal layout. + + Args: + message_logs: List of message logs to sample from + rewards: List of rewards corresponding to each message log + num_samples: Number of samples to display (default: 5) + step: Current training step (for display purposes) + """ + # Make sure rich logging is configured before printing + configure_rich_logging(level="INFO") + + if not message_logs or not rewards: + return + + if num_samples <= 0: + return + + # Sample up to num_samples (or all if less) + num_to_show = min(num_samples, len(message_logs)) + indices = list(range(len(message_logs))) + + # If we have more samples than needed, prioritize showing a mix of high and low rewards + if len(indices) > num_to_show: + # Sort indices by reward + sorted_indices = sorted(indices, key=lambda i: rewards[i], reverse=True) + # Take some from the top and some from the bottom + half = num_to_show // 2 + indices = sorted_indices[:half] + sorted_indices[-half:] + # If num_to_show is odd, add a middle sample + if num_to_show % 2 == 1: + middle_idx = len(sorted_indices) // 2 + indices.append(sorted_indices[middle_idx]) + indices = indices[:num_to_show] + + console = Console() + + # Header with step information + console.rule(f"[bold bright_white on purple4]TRAINING STEP {step}") + + # Count the unique reward values + all_rewards = rewards.copy() + unique_rewards = sorted(set(all_rewards)) + reward_counts = {r: all_rewards.count(r) for r in unique_rewards} + + # Create a bar chart for discrete reward levels + max_count = max(reward_counts.values()) if reward_counts else 1 + + # Create discrete reward level visualization + discrete_lines = [] + discrete_lines.append("[bold bright_white]Discrete Reward Levels:[/]") + + # Get emoji for each reward level + def get_reward_emoji(reward): + if reward >= 0.7: + return "πŸ”₯" # Excellent + elif reward >= 0.3: + return "✨" # Good + elif reward >= -0.5: + return "🟠" # Poor + else: + return "πŸ”΄" # Very poor + + # Create a bar for each discrete reward level + for reward in unique_rewards: + count = reward_counts[reward] + emoji = get_reward_emoji(reward) + bar_len = int((count / max_count) * 20) + + # Choose different bar characters and colors + if reward > 0.5: + bar_char = "β–ˆ" + color = "bright_green" + elif reward > 0: + bar_char = "β–ˆ" + color = "green" + elif reward == 0: + bar_char = "β–’" + color = "bright_white" + elif reward > -0.5: + bar_char = "β–“" + color = "orange3" + else: + bar_char = "β–ˆ" + color = "red" + + bar = f"[{color}]{bar_char * bar_len}[/]" + # Format with color based on reward value + discrete_lines.append( + f"{emoji} Reward [bold {color}]{reward:.4f}[/]: {bar} ({count} samples)" + ) + + # Create a summary panel + avg_reward = sum(all_rewards) / len(all_rewards) if all_rewards else 0 + stats_text = ( + f"[bold]Batch Summary[/]\n" + f"Total Samples: [bright_yellow]{len(all_rewards)}[/]\n" + f"Avg Reward: [bright_blue]{avg_reward:.4f}[/]\n" + f"Min: [orange3]{min(all_rewards):.4f}[/] | Max: [bright_green]{max(all_rewards):.4f}[/]\n\n" + + "\n".join(discrete_lines) + ) + + stats_panel = Panel( + stats_text, + title="[bold purple4]Reward Statistics", + border_style="purple4", + box=ROUNDED, + ) + + # Display the stats panel + console.print(stats_panel) + + # Display the samples with horizontal layout + console.print("\n[bold bright_white]Sample Conversations[/]") + + # Helper function to safely render content that might have problematic markups + def safe_render(content, role_color): + # Fix common problematic patterns that might break Rich markup + # Replace any standalone [/ without matching closing bracket + content = content.replace("[/", "\\[/") + # Replace any standalone [ that isn't followed by a valid tag with escaped version + import re + + content = re.sub(r"\[(?![a-z_]+\s|/[a-z_]+\])", "\\[", content) + return f"[{role_color}]{content}[/]" + + for i, idx in enumerate(indices): + message_log = message_logs[idx] + reward = rewards[idx] + + # Format each message in the conversation + message_parts = [] + for msg in message_log: + role = msg.get("role", "unknown").upper() + content = msg.get("content", "") + + # Choose color based on role - using muted, elegant colors + if role == "SYSTEM": + message_parts.append( + f"[bold #8A2BE2]{role}:[/] {safe_render(content, '#8A2BE2')}" + ) + elif role == "USER": + message_parts.append( + f"[bold #4682B4]{role}:[/] {safe_render(content, '#4682B4')}" + ) + elif role == "ASSISTANT": + message_parts.append( + f"[bold #2E8B57]{role}:[/] {safe_render(content, '#2E8B57')}" + ) + else: + message_parts.append(f"[bold]{role}:[/] {content}") + + # Get reward emoji + emoji = get_reward_emoji(reward) + + # Choose color based on reward + if reward > 0.5: + color = "bright_green" + elif reward > 0: + color = "green" + elif reward == 0: + color = "bright_white" + elif reward > -0.5: + color = "orange3" + else: + color = "red" + + content = "\n\n".join(message_parts) + + panel = Panel( + content, + title=f"[bold]{emoji} Sample {i + 1} | Reward: {reward:.4f}", + border_style=color, + box=ROUNDED, + ) + + console.print(panel) + console.print("") # Add some spacing + + console.rule("[bold bright_white on purple4]End of Samples") + + +def get_next_experiment_dir(base_log_dir): + """Create a new experiment directory with an incremented ID. + + Args: + base_log_dir (str): The base log directory path + + Returns: + str: Path to the new experiment directory with incremented ID + """ + # Check if the log directory already contains an experiment ID pattern (e.g., /exp_001/) + pattern = re.compile(r"exp_(\d+)") + next_exp_id = 1 + + # Check for existing experiment directories + existing_dirs = glob.glob(os.path.join(base_log_dir, "exp_*")) + + if existing_dirs: + # Extract experiment IDs and find the maximum + exp_ids = [] + for dir_path in existing_dirs: + match = pattern.search(dir_path) + if match: + exp_ids.append(int(match.group(1))) + + if exp_ids: + # Increment the highest experiment ID + next_exp_id = max(exp_ids) + 1 + + # Format the new log directory with the incremented experiment ID + new_log_dir = os.path.join(base_log_dir, f"exp_{next_exp_id:03d}") + + # Create the new log directory + os.makedirs(new_log_dir, exist_ok=True) + + return new_log_dir diff --git a/nemo_reinforcer/utils/timer.py b/nemo_reinforcer/utils/timer.py new file mode 100644 index 0000000000..5796b1da39 --- /dev/null +++ b/nemo_reinforcer/utils/timer.py @@ -0,0 +1,246 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from contextlib import contextmanager +from typing import Dict, List, Optional, Union +import numpy as np + + +class Timer: + """A utility for timing code execution. + + Supports two usage patterns: + 1. Explicit start/stop: timer.start("label"), timer.stop("label") + 2. Context manager: with timer.time("label"): ... + + The timer keeps track of multiple timing measurements for each label, + and supports different reductions on these measurements (mean, median, + min, max, std dev). + + Example usage: + ``` + timer = Timer() + + # Method 1: start/stop + timer.start("load_data") + data = load_data() + timer.stop("load_data") + + # Method 2: context manager + with timer.time("model_forward"): + model_outputs = model(inputs) + + # Multiple timing measurements for the same operation + for batch in dataloader: + with timer.time("model_forward_multiple"): + outputs = model(batch) + + # Get all times for one label + model_forward_times = timer.get_elapsed("model_forward_multiple") + + # Get reductions for one label + mean_forward_time = timer.reduce("model_forward_multiple") + max_forward_time = timer.reduce("model_forward_multiple", "max") + ``` + """ + + # Define valid reduction types and their corresponding NumPy functions + _REDUCTION_FUNCTIONS = { + "mean": np.mean, + "median": np.median, + "min": np.min, + "max": np.max, + "std": np.std, + "sum": np.sum, + "count": len, + } + + def __init__(self): + # Dictionary mapping labels to lists of elapsed times + # We store a list of times for each label rather than a single value + # to support multiple timing runs with the same label (e.g., in loops) + # This allows calculating reductions like mean, min, max, and std dev + self._timers: Dict[str, List[float]] = {} + self._start_times: Dict[str, float] = {} + + def start(self, label: str) -> None: + """Start timing for the given label.""" + if label in self._start_times: + raise ValueError(f"Timer '{label}' is already running") + self._start_times[label] = time.perf_counter() + + def stop(self, label: str) -> float: + """Stop timing for the given label and return the elapsed time. + + Args: + label: The label to stop timing for + + Returns: + The elapsed time in seconds + + Raises: + ValueError: If the timer for the given label is not running + """ + if label not in self._start_times: + raise ValueError( + f"Timer '{label}' is not running. Running times: {self._start_times.keys()}" + ) + + elapsed = time.perf_counter() - self._start_times[label] + if label not in self._timers: + self._timers[label] = [] + self._timers[label].append(elapsed) + del self._start_times[label] + return elapsed + + @contextmanager + def time(self, label: str): + """Context manager for timing a block of code. + + Args: + label: The label to use for this timing + + Yields: + None + """ + self.start(label) + try: + yield + finally: + self.stop(label) + + def get_elapsed(self, label: str) -> List[float]: + """Get all elapsed time measurements for a specific label. + + Args: + label: The timing label to get elapsed times for + + Returns: + A list of all elapsed time measurements in seconds + + Raises: + KeyError: If the label doesn't exist + """ + if label not in self._timers: + raise KeyError(f"No timings recorded for '{label}'") + + return self._timers[label] + + def get_latest_elapsed(self, label: str) -> float: + """Get the most recent elapsed time measurement for a specific label. + + Args: + label: The timing label to get the latest elapsed time for + + Returns: + The most recent elapsed time measurement in seconds + + Raises: + KeyError: If the label doesn't exist + IndexError: If the label exists but has no measurements + """ + if label not in self._timers: + raise KeyError(f"No timings recorded for '{label}'") + + if not self._timers[label]: + raise IndexError(f"No measurements recorded for '{label}'") + + return self._timers[label][-1] + + def reduce(self, label: str, operation: str = "mean") -> float: + """Apply a reduction function to timing measurements for the specified label. + + Args: + label: The timing label to get reduction for + operation: The type of reduction to apply. Valid options are: + - "mean": Average time (default) + - "median": Median time + - "min": Minimum time + - "max": Maximum time + - "std": Standard deviation + - "sum": Total time + - "count": Number of measurements + + Returns: + A single float with the reduction result + + Raises: + KeyError: If the label doesn't exist + ValueError: If an invalid operation is provided + """ + if operation not in self._REDUCTION_FUNCTIONS: + valid_reductions = ", ".join(self._REDUCTION_FUNCTIONS.keys()) + raise ValueError( + f"Invalid operation '{operation}'. Valid options are: {valid_reductions}" + ) + + if label not in self._timers: + raise KeyError(f"No timings recorded for '{label}'") + + reduction_func = self._REDUCTION_FUNCTIONS[operation] + return reduction_func(self._timers[label]) + + def get_timing_metrics( + self, reduction_op: Union[str, Dict[str, str]] = "mean" + ) -> Dict[str, List[float]]: + """Get all timing measurements with optional reduction. + + Args: + reduction_op: Either a string specifying a reduction operation to apply to all labels, + or a dictionary mapping specific labels to reduction operations. + Valid reduction operations are: "mean", "median", "min", "max", "std", "sum", "count". + If a label is not in the dictionary, no reduction is applied and all measurements are returned. + + Returns: + A dictionary mapping labels to either: + - A list of all timing measurements for that label (if no reduction specified) + - A single float with the reduction result (if reduction specified) + + Raises: + ValueError: If an invalid reduction operation is provided + """ + if isinstance(reduction_op, str): + reduction_op = {label: reduction_op for label in self._timers} + + results = {} + for label, op in reduction_op.items(): + if label not in self._timers: + continue + + if op in self._REDUCTION_FUNCTIONS: + results[label] = self.reduce(label, op) + else: + results[label] = self._timers[label] + + # Add any labels not in the reduction_op dictionary + for label in self._timers: + if label not in reduction_op: + results[label] = self._timers[label] + + return results + + def reset(self, label: Optional[str] = None) -> None: + """Reset timings for the specified label or all labels. + + Args: + label: Optional label to reset. If None, resets all timers. + """ + if label: + if label in self._timers: + del self._timers[label] + if label in self._start_times: + del self._start_times[label] + else: + self._timers = {} + self._start_times = {} diff --git a/pyproject.toml b/pyproject.toml index 8793fb50ed..eddc57a15b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,31 +1,33 @@ [build-system] -requires = ["setuptools>=42", "wheel", "torch==2.6.0"] +requires = ["setuptools>=42", "wheel"] build-backend = "setuptools.build_meta" [project] -name = "nemo-_placeholder" +name = "nemo-reinforcer" version = "0.0.1" -description = "_Placeholder" +description = "Nemo-Reinforcer: A Scalable and Efficient Post-Training Library for Models Ranging from 1 GPU to 1000s, and from Tiny to >100B Parameters" readme = "README.md" requires-python = ">=3.10" license = {text = "Apache 2.0"} dependencies = [ - "torch==2.6.0" + "torch==2.6.0", + "colored==2.2.3", + "ray[default]==2.43.0", + "transformers", + "wandb", + "numpy", + "datasets", + "rich", + "math-verify", + "accelerate>=0.26", + "tensorboard", + "omegaconf", + "torchdata", + "vllm==0.8.0", ] [tool.setuptools] -packages = ["nemo__placeholder"] - -[dependency-groups] -dev = [ - "pre-commit==3.6.0", - "ruff==0.9.9", -] -test = [ - "pytest>=7.0.0", - "pytest-timeout", - "pytest-cov", -] +packages = ["nemo_reinforcer"] [project.optional-dependencies] build = [ @@ -42,6 +44,16 @@ docs = [ "myst_parser", # For our markdown docs "nvidia-sphinx-theme", # Our NVIDIA theme ] +dev = [ + "pre-commit==3.6.0", + "ruff==0.9.9", +] +test = [ + "pytest>=7.0.0", + "pytest-timeout", + "pytest-cov", +] + [tool.black] line-length = 120 diff --git a/ray.sub b/ray.sub new file mode 100644 index 0000000000..1d5e00db4b --- /dev/null +++ b/ray.sub @@ -0,0 +1,174 @@ +#!/bin/bash +#SBATCH --nodes=2 +#SBATCH --exclusive +#SBATCH --account=ACCOUNT +#SBATCH --job-name=JOB_NAME +#SBATCH --partition=PARTITION +#SBATCH --time=1:0:0 +#SBATCH --dependency=singleton +#SBATCH --gres=gpu:8 + + +set -eou pipefail + +######################################################## +# User defined variables +######################################################## +CONTAINER=$CONTAINER +MOUNTS=$MOUNTS +COMMAND=${COMMAND:-} # This is a script relative to the SLURM_SUBMIT_DIR. If left empty, it will leave the cluster idle after it's brought up. +######################################################## + +COMMON_SRUN_ARGS="" +COMMON_SRUN_ARGS+=" --export=ALL" +COMMON_SRUN_ARGS+=" --no-container-mount-home" +COMMON_SRUN_ARGS+=" --mpi=pmix" +COMMON_SRUN_ARGS+=" --container-mounts=$MOUNTS" +COMMON_SRUN_ARGS+=" --container-image=$CONTAINER" +COMMON_SRUN_ARGS+=" --container-workdir=$SLURM_SUBMIT_DIR" +# TODO: delete these (just for debugging) +COMMON_SRUN_ARGS+=" -p $SLURM_JOB_PARTITION" +COMMON_SRUN_ARGS+=" -A $SLURM_JOB_ACCOUNT" +COMMON_SRUN_ARGS+=" --gres=gpu:8" + +# Create logs directory +LOG_DIR="$SLURM_JOB_ID-logs" +mkdir -p $LOG_DIR + +# Number of GPUs per node +gpus_per_node=8 + +num_retries=5 + +# Getting the node names and IP addresses in the SLURM allocation +nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") +nodes_array=($nodes) +ip_addresses_array=() + +for node in $nodes; do + ip_address=$(host $node | awk '/has address/ { print $4 }') + # Add the IP address to the array + ip_addresses_array+=("$ip_address") +done + +head_node=${nodes_array[0]} +head_node_ip=${ip_addresses_array[0]} + +port=41993 +ip_head=$head_node_ip:$port + +# First we start the head of the ray cluster on one of the physical nodes +# In this case we are giving an entire physical node to the ray head node +# The ray head node is marked by including --head to the ray start command +head_cmd=$(cat <$SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh +# No args launches on the head node +WORKER_NUM=\${1:-0} +if [[ \$WORKER_NUM -eq 0 ]]; then + srun --gres=gpu:8 -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" --jobid $SLURM_JOB_ID --pty bash +else + nodes_array=($nodes) + srun --gres=gpu:8 -A $SLURM_JOB_ACCOUNT -p $SLURM_JOB_PARTITION --overlap --container-name=ray-worker-\$WORKER_NUM --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "\${nodes_array[\$WORKER_NUM]}" --jobid $SLURM_JOB_ID --pty bash +fi +EOF + chmod +x $SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh + echo " bash $SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh" + sleep infinity +fi diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/functional/check_metrics.py b/tests/functional/check_metrics.py new file mode 100644 index 0000000000..b1da6bc924 --- /dev/null +++ b/tests/functional/check_metrics.py @@ -0,0 +1,128 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import sys +import argparse +import statistics +from typing import Dict, Tuple, Any, Union, List +from rich.console import Console +from rich.table import Table + + +# Custom functions for working with dictionary values +def min(value): + """Return the minimum value in a dictionary.""" + return __builtins__.min(float(v) for v in value.values()) + + +def max(value): + """Return the maximum value in a dictionary.""" + return __builtins__.max(float(v) for v in value.values()) + + +def mean(value): + """Return the mean of values in a dictionary.""" + return statistics.mean(float(v) for v in value.values()) + + +def evaluate_check(data: Dict, check: str) -> Tuple[bool, str, object]: + """Evaluate a check against the data. + + Returns: + Tuple of (passed, message, value) + """ + # Create a local context with our custom functions and the data + local_context = {"data": data, "min": min, "max": max, "mean": mean} + + # Extract the value expression from the check + value_expr = check.split(">")[0].split("<")[0].split("==")[0].strip() + + try: + # Try to get the value first + value = eval(value_expr, {"__builtins__": __builtins__}, local_context) + + # Then evaluate the check + result = eval(check, {"__builtins__": __builtins__}, local_context) + if result: + return True, f"PASS: {check}", value + else: + return False, f"FAIL: {check} (condition evaluated to False)", value + except KeyError as e: + return False, f"FAIL: {check} (key not found: {e})", None + except IndexError as e: + return False, f"FAIL: {check} (index error: {e})", None + except Exception as e: + return False, f"FAIL: {check} (error: {e})", None + + +def main(): + parser = argparse.ArgumentParser(description="Check conditions against a JSON file") + parser.add_argument("json_file", help="Path to the JSON file") + parser.add_argument( + "checks", nargs="+", help="Conditions to check, will be eval()'d" + ) + + # Add helpful usage examples + parser.epilog = """ + Examples: + # Check if a specific metric is above a threshold + python check_metrics.py results.json "data['accuracy'] > 0.9" + + # Check multiple conditions + python check_metrics.py results.json "data['precision'] > 0.8" "data['recall'] > 0.7" + + # Use helper functions + python check_metrics.py results.json "min(data['class_f1']) > 0.6" + python check_metrics.py results.json "mean(data['accuracies']) > 0.85" + """ + parser.formatter_class = argparse.RawDescriptionHelpFormatter + args = parser.parse_args() + + # Load the JSON data - simplified + with open(args.json_file, "r") as f: + data = json.load(f) + + # Initialize rich console + console = Console() + + # Create a table + table = Table(title="Metric Checks") + table.add_column("Status", style="bold") + table.add_column("Check", style="dim") + table.add_column("Value", style="cyan") + table.add_column("Message", style="italic") + + # Evaluate all checks + success = True + for check in args.checks: + passed, message, value = evaluate_check(data, check) + + status = "[green]PASS[/green]" if passed else "[red]FAIL[/red]" + value_str = str(value) if value is not None else "N/A" + detail = "" if passed else message.split(": ", 1)[1] + + table.add_row(status, check, value_str, detail) + + if not passed: + success = False + + # Display the table + console.print(table) + + # Exit with appropriate status code + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() diff --git a/tests/functional/grpo.sh b/tests/functional/grpo.sh new file mode 100755 index 0000000000..56ea95026d --- /dev/null +++ b/tests/functional/grpo.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) + +set -eou pipefail + +LOG_DIR=$SCRIPT_DIR/$(basename $0 .sh)-logs +JSON_METRICS=$LOG_DIR/$(basename $0 .sh).json +RUN_LOG=$LOG_DIR/$(basename $0 .sh).log +export RAY_DEDUP_LOGS=0 +export UV_CACHE_DIR=$PROJECT_ROOT/uv_cache + +mkdir -p $LOG_DIR + +cd $PROJECT_ROOT +uv run $PROJECT_ROOT/examples/run_grpo_math.py \ + cluster.gpus_per_node=2 \ + grpo.num_steps=10 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +cd $SCRIPT_DIR +uv run json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run check_metrics.py $JSON_METRICS \ + 'data["timing/train/policy_refit"]["10"] < 3.0' \ + 'data["timing/train/total_step_time"]["10"] < 20.0' \ + 'data["timing/validation/generation"]["10"] < 3.0' \ + 'max(data["train/token_mult_prob_error"]) < 1.05' \ + 'data["validation/avg_length"]["10"] < 1024' \ + diff --git a/tests/functional/hello_world_fsdp_llama/__init__.py b/tests/functional/hello_world_fsdp_llama/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/tests/functional/hello_world_fsdp_llama/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/functional/hello_world_fsdp_llama/example.ipynb b/tests/functional/hello_world_fsdp_llama/example.ipynb new file mode 100644 index 0000000000..4529efa34f --- /dev/null +++ b/tests/functional/hello_world_fsdp_llama/example.ipynb @@ -0,0 +1,350 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", + "Requirement already satisfied: ipywidgets in /usr/local/lib/python3.10/dist-packages (8.1.5)\n", + "Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.10/dist-packages (from ipywidgets) (0.2.2)\n", + "Requirement already satisfied: ipython>=6.1.0 in /usr/local/lib/python3.10/dist-packages (from ipywidgets) (8.28.0)\n", + "Requirement already satisfied: traitlets>=4.3.1 in /usr/local/lib/python3.10/dist-packages (from ipywidgets) (5.14.3)\n", + "Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.10/dist-packages (from ipywidgets) (4.0.13)\n", + "Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.10/dist-packages (from ipywidgets) (3.0.13)\n", + "Requirement already satisfied: decorator in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (5.1.1)\n", + "Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (0.19.1)\n", + "Requirement already satisfied: matplotlib-inline in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (0.1.7)\n", + "Requirement already satisfied: prompt-toolkit<3.1.0,>=3.0.41 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (3.0.48)\n", + "Requirement already satisfied: pygments>=2.4.0 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (2.18.0)\n", + "Requirement already satisfied: stack-data in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (0.6.3)\n", + "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (1.2.2)\n", + "Requirement already satisfied: typing-extensions>=4.6 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (4.12.2)\n", + "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets) (4.9.0)\n", + "Requirement already satisfied: parso<0.9.0,>=0.8.3 in /usr/local/lib/python3.10/dist-packages (from jedi>=0.16->ipython>=6.1.0->ipywidgets) (0.8.4)\n", + "Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.10/dist-packages (from pexpect>4.3->ipython>=6.1.0->ipywidgets) (0.7.0)\n", + "Requirement already satisfied: wcwidth in /usr/local/lib/python3.10/dist-packages (from prompt-toolkit<3.1.0,>=3.0.41->ipython>=6.1.0->ipywidgets) (0.2.13)\n", + "Requirement already satisfied: executing>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from stack-data->ipython>=6.1.0->ipywidgets) (2.1.0)\n", + "Requirement already satisfied: asttokens>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from stack-data->ipython>=6.1.0->ipywidgets) (2.4.1)\n", + "Requirement already satisfied: pure-eval in /usr/local/lib/python3.10/dist-packages (from stack-data->ipython>=6.1.0->ipywidgets) (0.2.3)\n", + "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from asttokens>=2.1.0->stack-data->ipython>=6.1.0->ipywidgets) (1.16.0)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", + "\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.0.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" + ] + } + ], + "source": [ + "!pip install -U ipywidgets" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import ray" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-03-02 23:13:34,875\tINFO worker.py:1654 -- Connecting to existing Ray cluster at address: 10.65.26.15:41993...\n", + "2025-03-02 23:13:34,883\tINFO worker.py:1832 -- Connected to Ray cluster. View the dashboard at \u001b[1m\u001b[32m127.0.0.1:8265 \u001b[39m\u001b[22m\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d16239ac6cf44259a74cf9c02d97c0a6", + "version_major": 2, + "version_minor": 0 + }, + "text/html": [ + "
\n", + "
\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + "\n", + "
Python version:3.10.12
Ray version:2.43.0
Dashboard:http://127.0.0.1:8265
\n", + "\n", + "
\n", + "
\n" + ], + "text/plain": [ + "RayContext(dashboard_url='127.0.0.1:8265', python_version='3.10.12', ray_version='2.43.0', ray_commit='ecdcdc6a6e63dc4bcd6ea16aae256ce4d32a7e2c')" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ray.init()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "16.0" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "avail_workers = ray.cluster_resources()[\"worker_units\"]\n", + "avail_workers" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from examples.hello_world_fsdp_llama import train" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['AutoModelForCausalLM',\n", + " 'AutoTokenizer',\n", + " 'LlamaDecoderLayer',\n", + " 'ModelTrainer',\n", + " 'NASS',\n", + " 'NodeInfo',\n", + " 'RayClusterCoordinator',\n", + " 'Worker',\n", + " 'WorkerGroupResources',\n", + " '__builtins__',\n", + " '__cached__',\n", + " '__doc__',\n", + " '__file__',\n", + " '__loader__',\n", + " '__name__',\n", + " '__package__',\n", + " '__spec__',\n", + " 'dataclass',\n", + " 'dist',\n", + " 'fully_shard',\n", + " 'init_device_mesh',\n", + " 'json',\n", + " 'os',\n", + " 'random',\n", + " 'ray',\n", + " 'register_fsdp_forward_method',\n", + " 'time',\n", + " 'torch']" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dir(train)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Worker node info: worker_node_info=[NodeInfo(node_id='9faf44e0e79cf43aef2eec7b5003ba0ec197a5bf7c07c7e4706bde7f', node_rank=0, node_ip='10.65.26.23'), NodeInfo(node_id='9faf44e0e79cf43aef2eec7b5003ba0ec197a5bf7c07c7e4706bde7f', node_rank=0, node_ip='10.65.26.23'), NodeInfo(node_id='9faf44e0e79cf43aef2eec7b5003ba0ec197a5bf7c07c7e4706bde7f', node_rank=0, node_ip='10.65.26.23'), NodeInfo(node_id='9faf44e0e79cf43aef2eec7b5003ba0ec197a5bf7c07c7e4706bde7f', node_rank=0, node_ip='10.65.26.23'), NodeInfo(node_id='9faf44e0e79cf43aef2eec7b5003ba0ec197a5bf7c07c7e4706bde7f', node_rank=0, node_ip='10.65.26.23'), NodeInfo(node_id='9faf44e0e79cf43aef2eec7b5003ba0ec197a5bf7c07c7e4706bde7f', node_rank=0, node_ip='10.65.26.23'), NodeInfo(node_id='9faf44e0e79cf43aef2eec7b5003ba0ec197a5bf7c07c7e4706bde7f', node_rank=0, node_ip='10.65.26.23'), NodeInfo(node_id='9faf44e0e79cf43aef2eec7b5003ba0ec197a5bf7c07c7e4706bde7f', node_rank=0, node_ip='10.65.26.23')]\n", + "Num physical nodes: self.num_physical_nodes=1\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[36m(ModelTrainer pid=52042, ip=10.65.26.23)\u001b[0m DEBUG: self.logical_gpu_id=2 (process_id, world_size, physical_node_id, physical_node_ip, master_addr, num_workers_per_node)=(1, 8, 0, '10.65.26.23', '10.65.26.23', 8)\n", + "\u001b[36m(ModelTrainer pid=52043, ip=10.65.26.23)\u001b[0m DEBUG: self.logical_gpu_id=4 (process_id, world_size, physical_node_id, physical_node_ip, master_addr, num_workers_per_node)=(6, 8, 0, '10.65.26.23', '10.65.26.23', 8)\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m [Rank 0] Loading model meta-llama/Llama-3.2-1B on CPU...\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m [Rank 0] Starting synthetic training...\n" + ] + } + ], + "source": [ + "worker_resources = train.WorkerGroupResources(\n", + " num_nodes=1, num_gpus_per_node=8, num_cpus_per_worker=16\n", + ")\n", + "coordinator = train.RayClusterCoordinator(train.ModelTrainer, worker_resources)\n", + "coordinator.initialize_workers()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "16.0" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Workers after consuming 1 node should be 16 -> 8\n", + "ray.cluster_resources()[\"worker_units\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m /usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py:818: UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides(), attempting to materialize a grad_output with matching strides... (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/native/cudnn/MHA.cpp:670.)\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 0, Step: 0, Loss: 30.59472656\n", + "\u001b[36m(ModelTrainer pid=52041, ip=10.65.26.23)\u001b[0m [Rank 2] Loading model meta-llama/Llama-3.2-1B on CPU...\u001b[32m [repeated 7x across cluster]\u001b[0m\n", + "\u001b[36m(ModelTrainer pid=52041, ip=10.65.26.23)\u001b[0m [Rank 2] Starting synthetic training...\u001b[32m [repeated 7x across cluster]\u001b[0m\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 0, Step: 1, Loss: 0.55084229\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 0, Step: 2, Loss: 3.00805664\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 0, Step: 3, Loss: 2.95410156\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 0, Step: 4, Loss: 2.15722656\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 0, Step: 5, Loss: 1.41015625\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 0, Step: 6, Loss: 0.75201416\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 0, Step: 7, Loss: 0.16345596\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 0, Step: 8, Loss: 0.07264709\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 0, Step: 9, Loss: 0.07691956\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 1, Step: 0, Loss: 0.03055668\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 1, Step: 1, Loss: 0.00250489\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 1, Step: 2, Loss: 0.00558114\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 1, Step: 3, Loss: 0.03055668\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 1, Step: 4, Loss: 0.05177402\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 1, Step: 5, Loss: 0.06747818\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 1, Step: 6, Loss: 0.06152725\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 1, Step: 7, Loss: 0.04326725\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 1, Step: 8, Loss: 0.02659702\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Epoch: 1, Step: 9, Loss: 0.01250291\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m [Rank 0] Testing loss is close to expected loss: 0.012502908706665039\n", + "\u001b[36m(ModelTrainer pid=52039, ip=10.65.26.23)\u001b[0m Yay! Loss was close :)\n" + ] + }, + { + "data": { + "text/plain": [ + "[None, None, None, None, None, None, None, None]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "coordinator.run(hf_model_name=\"meta-llama/Llama-3.2-1B\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "ray.shutdown()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/functional/hello_world_fsdp_llama/main.py b/tests/functional/hello_world_fsdp_llama/main.py new file mode 100644 index 0000000000..df5b9bf1c9 --- /dev/null +++ b/tests/functional/hello_world_fsdp_llama/main.py @@ -0,0 +1,77 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from ray.job_submission import JobSubmissionClient, JobStatus + + +def main() -> None: + client = JobSubmissionClient("http://127.0.0.1:8265") + print("Connected to head!", flush=True) + + # HACK: for now just + def num_ray_nodes_available() -> int: + import ray + + ray.init() + num_gpus_per_node = 8 # hard coded + num_nodes_avail = ( + int(ray.cluster_resources()["worker_units"]) // num_gpus_per_node + ) + ray.shutdown() + return num_nodes_avail + + job_id = client.submit_job( + entrypoint="RAY_DEDUP_LOGS=0 python3 tests/functional/hello_world_fsdp_llama/train.py", + runtime_env={ + # TODO: disabling for now since it causes issues if my hf_home is in my working dir and ray + # wants to upload it to all workers. you get an error like this: + # 2025-03-02 11:16:48,187 WARNING packaging.py:417 -- File /workspace/hf_home/hub/models--meta-llama--Meta-Llama-3-8b/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920/model-00001-of-00004.safetensors is very large (4746.15MiB). Consider adding this file to the 'excludes' list to skip uploading it: `ray.init(..., runtime_env={'excludes': ['/workspace/hf_home/hub/models--meta-llama--Meta-Llama-3-8b/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920/model-00001-of-00004.safetensors']})` + # "working_dir": "./", + "driver_args": { + # Scope each "workergroup" + "trainer": { + "resources": { + # TODO: read this in from cli args eventually, but for now just use all available + "num_nodes": num_ray_nodes_available(), + "num_gpus_per_node": 8, + "num_cpus_per_worker": 16, + }, + "hf_model_name": "meta-llama/Llama-3.2-1B", + } + }, + "env_vars": { + # TODO: hardcoded, parametrize + "HF_HOME": "/workspace/hf_home", + }, + }, + ) + + print(f"Launched job: {job_id}", flush=True) + prev_logs = "" + while True: + status = client.get_job_status(job_id) + if status in {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}: + if status in {JobStatus.STOPPED, JobStatus.FAILED}: + logs = client.get_job_logs(job_id) + print(logs, flush=True) + break + time.sleep(5) + if status == JobStatus.RUNNING: + logs = client.get_job_logs(job_id) + print(logs[len(prev_logs) :], flush=True) + prev_logs = logs + + +if __name__ == "__main__": + main() diff --git a/tests/functional/hello_world_fsdp_llama/train.py b/tests/functional/hello_world_fsdp_llama/train.py new file mode 100644 index 0000000000..61c823f37f --- /dev/null +++ b/tests/functional/hello_world_fsdp_llama/train.py @@ -0,0 +1,329 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import ray +import os +import random +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import init_device_mesh +from transformers import AutoModelForCausalLM, AutoTokenizer +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy as NASS +from dataclasses import dataclass +from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy +from torch.distributed.fsdp import register_fsdp_forward_method + + +@dataclass +class WorkerGroupResources: + num_nodes: int + num_cpus_per_worker: int = 16 # 128 hyperthread / 8 gpu = 16 cpu/gpu + num_gpus_per_node: int = 8 # will always be true on slurm + + +@dataclass +class NodeInfo: + node_id: str + node_rank: int + node_ip: str + + +# Define the coordinator and worker +class RayClusterCoordinator: + def __init__( + self, worker_cls: type["ModelTrainer"], worker_resources: WorkerGroupResources + ) -> None: + self.worker_cls = worker_cls + self.worker_resources = worker_resources + self.num_workers = ( + worker_resources.num_nodes * worker_resources.num_gpus_per_node + ) + self.num_workers_per_node = worker_resources.num_gpus_per_node + + ray_available_workers = int(ray.cluster_resources()["worker_units"]) + assert self.num_workers // self.num_workers_per_node <= ray_available_workers, ( + f"Only {ray_available_workers} workers available, which is not enough to schedule {self.num_workers} workers with {self.num_workers_per_node} workers per node" + ) + + self.workers_initialized = False + + worker_node_info, self.num_physical_nodes = self._get_schedulable_worker_info() + print(f"Worker node info: {worker_node_info=}") + print(f"Num physical nodes: {self.num_physical_nodes=}") + # Assume there's one worker per GPU + self.workers = [ + worker_cls.options( + num_gpus=1, + num_cpus=worker_resources.num_cpus_per_worker, + resources={"worker_units": 1}, + # Use NodeAffinitySchedulingStrategy to ensure each worker is placed on a specific node + # node_id: Unique ID of the target node for this worker + # soft=False: Strictly enforce placement on the specified node (no fallback to other nodes) + scheduling_strategy=NASS( + node_id=worker_node_info[i].node_id, soft=False + ), + ).remote( + i, + self.num_workers, + worker_node_info[i].node_rank, + worker_node_info[i].node_ip, # TODO: probably can delete this + worker_node_info[ + 0 + ].node_ip, # Arbitrarily make the first worker's hots the master + self.num_workers_per_node, + ) + for i in range(self.num_workers) + ] + + def _get_schedulable_worker_info(self) -> tuple[list[NodeInfo], int]: + """Collects information about available worker nodes in the Ray cluster and prepares + scheduling information for worker actors. + + This method: + 1. Identifies all alive worker nodes with 'worker_units' resources + 2. Sorts them by NodeID for consistent allocation + 3. Calculates how many physical nodes are needed based on workers per node + 4. Verifies that enough nodes are available + 5. Creates a list of NodeInfo objects for each worker + + Returns: + tuple: (worker_node_info, num_nodes_required) + - worker_node_info: List of NodeInfo objects containing node_id, node_rank, and node_ip for each worker + - num_nodes_required: Number of physical nodes needed for all workers + + Raises: + AssertionError: If there aren't enough nodes available to schedule all workers + """ + # Get list of alive worker nodes sorted by NodeID for deterministic allocation + worker_node_info = [] + worker_nodes = sorted( + [ + node + for node in ray.nodes() + if (node["Alive"] and "worker_units" in node["Resources"]) + ], + key=lambda x: x["NodeID"], + ) + + # Calculate required nodes and verify availability + num_nodes_required = self.num_workers // self.num_workers_per_node + num_nodes_available = len(worker_nodes) + assert num_nodes_required <= num_nodes_available + + # Create worker info entries - one per GPU across all needed nodes + worker_nodes = worker_nodes[:num_nodes_required] + for worker_node_id, worker_node in enumerate(worker_nodes): + for _ in range(self.num_workers_per_node): + worker_node_info.append( + NodeInfo( + worker_node["NodeID"], worker_node_id, worker_node["NodeName"] + ) + ) + + return worker_node_info, num_nodes_required + + def initialize_workers(self, **kwargs): + self.worker_init_kwargs = kwargs + ray.get([w.initialize.remote(**kwargs) for i, w in enumerate(self.workers)]) + self.workers_initialized = True + + def run(self, *args, **kwargs): + if not self.workers_initialized: + raise ValueError("""Cannot run workers without initializing them first. + Please call the initialize_workers method of your cluster coordinator first.""") + + worker_results = ray.get([w.run.remote(*args, **kwargs) for w in self.workers]) + return worker_results + + +class Worker: + def __init__( + self, + process_id, + world_size, + physical_node_id, + physical_node_ip, + master_addr: str, + num_workers_per_node: int, + ): + self.process_id = process_id + self.world_size = world_size + self.physical_node_id = physical_node_id + self.host_ip = physical_node_ip + self.master_addr = master_addr + self.logical_gpu_id = int(os.environ["CUDA_VISIBLE_DEVICES"]) + print( + f"DEBUG: {self.logical_gpu_id=} {(process_id, world_size, physical_node_id, physical_node_ip, master_addr, num_workers_per_node)=}" + ) + self.num_workers_per_node = num_workers_per_node + + def get_process_id(self): + return self.process_id + + def get_host_ip(self): + return self.host_ip + + def get_logical_gpu_id(self): + return self.logical_gpu_id + + def get_physical_node_id(self): + return self.physical_node_id + + def initialize(self): + # Set distributed training environment variables + os.environ["RANK"] = str(self.process_id) + os.environ["WORLD_SIZE"] = str(self.world_size) + os.environ["LOCAL_RANK"] = str(self.logical_gpu_id) + os.environ["LOCAL_WORLD_SIZE"] = str(self.num_workers_per_node) + os.environ["MASTER_ADDR"] = self.master_addr + os.environ["MASTER_PORT"] = "29500" + + dist.init_process_group("nccl") + + def run(self, *args, **kwargs): + raise NotImplementedError + + +@ray.remote +class ModelTrainer(Worker): + def __init__( + self, + process_id, + world_size, + physical_node_id, + physical_node_ip, + master_addr, + num_workers_per_node, + ): + super().__init__( + process_id, + world_size, + physical_node_id, + physical_node_ip, + master_addr, + num_workers_per_node, + ) + + def run(self, hf_model_name): + rank = dist.get_rank() + world_size = dist.get_world_size() + ####local_device = torch.device(f"cuda:{rank}") + ####torch.cuda.set_device(local_device) + + print(f"[Rank {rank}] Loading model {hf_model_name} on CPU...") + model = AutoModelForCausalLM.from_pretrained( + hf_model_name, + device_map="cpu", + torch_dtype=torch.bfloat16, + ) + + tokenizer = AutoTokenizer.from_pretrained(hf_model_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # TODO: could oom? + # ------------------------------------------------ + # 3) Move to GPU + Composable FSDP + # (Initialize device mesh, shard submodules, then shard entire model) + # ------------------------------------------------ + model.cuda() + + # Create a device mesh with 'world_size' GPUs in a 1D arrangement. + mesh = init_device_mesh("cuda", (world_size,)) + + param_dtype = torch.bfloat16 + reduce_dtype = torch.float32 + buffer_dtype = torch.float32 + + mp = MixedPrecision( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + buffer_dtype=buffer_dtype, + ) + + model = FullyShardedDataParallel( + model, + device_mesh=mesh, + auto_wrap_policy=size_based_auto_wrap_policy, + mixed_precision=mp, + ) + + # Optionally register "generate" as the forward method so FSDP can handle it properly. + register_fsdp_forward_method(model, "generate") + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + model.train() + + num_epochs = 2 + num_batches = 10 + batch_size = 2 + seq_length = 16 + vocab_size = tokenizer.vocab_size or 32000 + + print(f"[Rank {rank}] Starting synthetic training...") + + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + random.seed(42) + for epoch in range(num_epochs): + for step in range(num_batches): + input_ids = torch.ones( + (batch_size, seq_length), device="cuda", dtype=torch.long + ) * (vocab_size - 1) + attention_mask = torch.ones_like(input_ids) + labels = input_ids.clone() + + optimizer.zero_grad() + outputs = model( + input_ids=input_ids, attention_mask=attention_mask, labels=labels + ) + loss = torch.square(outputs.logits.view(-1)[0]) + loss.backward() + optimizer.step() + + if rank == 0: + print( + f"Epoch: {epoch}, Step: {step}, Loss: {loss.item():.8f}", + flush=True, + ) + + if rank == 0: + expected_loss = 0.012502908706665039 + print( + f"[Rank {rank}] Testing loss is close to expected loss: {expected_loss}" + ) + torch.testing.assert_close(loss.item(), expected_loss) + print("Yay! Loss was close :)") + + +if __name__ == "__main__": + ray.init(address="auto", logging_level=0) + print(json.dumps(json.loads(os.environ["RAY_JOB_CONFIG_JSON_ENV_VAR"]), indent=4)) + driver_args = json.loads(os.environ["RAY_JOB_CONFIG_JSON_ENV_VAR"])["runtime_env"][ + "driver_args" + ] + + # TODO: very simple, need to think thru CLI + trainer_args = driver_args["trainer"] + trainer_resources = trainer_args.pop("resources") + worker_resources = WorkerGroupResources(**trainer_resources) + + coordinator = RayClusterCoordinator(ModelTrainer, worker_resources) + coordinator.initialize_workers() + print("Initialized workers") + # Get the job configuration set during launch. + # This is automatically set by Ray + coordinator.run(**trainer_args) + print("Finished") diff --git a/tests/functional/json_dump_tb_logs.py b/tests/functional/json_dump_tb_logs.py new file mode 100644 index 0000000000..58554eadf9 --- /dev/null +++ b/tests/functional/json_dump_tb_logs.py @@ -0,0 +1,232 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import glob +import json +import os +import datetime +import statistics +from collections import defaultdict +from tensorboard.backend.event_processing import event_accumulator +import sys +from rich.console import Console +from rich.table import Table +from rich.box import SIMPLE +from rich.panel import Panel +from rich.text import Text + +# By default TB tries to be smart about what to load in memory to avoid OOM +# Since we expect every step to be there when we do our comparisons, we explicitly +# set the size guidance to 0 so that we load everything. +SIZE_GUIDANCE = { + event_accumulator.TENSORS: 0, + event_accumulator.SCALARS: 0, +} + +console = Console() +error_console = Console(stderr=True) + + +def merge_tb_logs_to_json(log_dir, output_path, allow_conflicts=False): + """Merge multiple TensorBoard event files into a single JSON file. + + Arguments: + log_dir: Path to directory containing TensorBoard event files (searched recursively) + output_path: Path to save the output JSON file + allow_conflicts: If True, allow multiple values for the same step (last one wins) + + Raises: + ValueError: If conflicting values are found for the same step and allow_conflicts is False + """ + # Find all event files recursively + files = glob.glob(f"{log_dir}/**/events*tfevents*", recursive=True) + files.sort(key=lambda x: os.path.getmtime(x)) + + if not files: + raise FileNotFoundError(f"No TensorBoard event files found under {log_dir}") + + # Display found files in a table + file_table = Table(title="Found TensorBoard Event Files", show_header=True) + file_table.add_column("Index", style="dim") + file_table.add_column("Path", style="green") + file_table.add_column("Last Modified", style="cyan") + + # Keep a map of file index to path for conflict reporting + file_index_map = {} + for i, f in enumerate(files, 1): + file_index_map[f] = i + modified_time = os.path.getmtime(f) + formatted_time = datetime.datetime.fromtimestamp(modified_time).strftime( + "%Y/%m/%d %H:%M:%S" + ) + file_table.add_row(str(i), f, formatted_time) + + console.print(file_table) + + # {metric_name: {step: (value, source_file)}} + merged_data = defaultdict(dict) + + console.print("[bold green]Processing event files...[/bold green]") + + for event_file in files: + console.print(f"Processing {os.path.basename(event_file)}") + + ea = event_accumulator.EventAccumulator(event_file, size_guidance=SIZE_GUIDANCE) + ea.Reload() + + for metric_name in ea.scalars.Keys(): + for scalar in ea.Scalars(metric_name): + step, value = scalar.step, scalar.value + + # Check for conflicts - immediately raise error if not allowing conflicts + if step in merged_data[metric_name]: + existing_value, existing_file = merged_data[metric_name][step] + + # Only consider it a conflict if the values are different + if existing_value != value: + if not allow_conflicts: + # Immediate error if not allowing conflicts + raise ValueError( + f"Conflict detected for metric '{metric_name}' at step {step}:\n" + f" File #{file_index_map[existing_file]}: {existing_file} has value {existing_value}\n" + f" File #{file_index_map[event_file]}: {event_file} has value {value}\n" + f"Use --allow-conflicts to force merging with latest value." + ) + + # Add or override the value + merged_data[metric_name][step] = (value, event_file) + + # Convert defaultdict to regular dict and sort the steps + output_data = {} + for metric_name in sorted(merged_data.keys()): + output_data[metric_name] = { + str(step): merged_data[metric_name][step][ + 0 + ] # Just keep the value, not the source file + for step in sorted(merged_data[metric_name].keys()) + } + + # Create summary table header + console.print("\n[bold cyan]Metrics Summary:[/bold cyan]") + + # Display summary for each metric using tables for better alignment + for metric, steps_data in sorted(output_data.items()): + if not steps_data: + console.print(f"[bold magenta]{metric}[/bold magenta] - No data") + continue + + # Get steps and values as sorted lists + steps = sorted([int(step) for step in steps_data.keys()]) + values = [steps_data[str(step)] for step in steps] + + # Calculate statistics + min_val = min(values) + max_val = max(values) + avg_val = statistics.mean(values) + + # Create metric header with better highlighting + metric_text = Text() + metric_text.append(f"πŸ”Ή ", style="bold blue") + metric_text.append(f"{metric}", style="bold magenta") + metric_text.append(f" - {len(steps)} steps", style="green") + console.print(metric_text) + + # Create statistics panel + stats_text = Text() + stats_text.append("Min: ", style="dim") + stats_text.append(f"{min_val:.6g}", style="red") + stats_text.append(" Max: ", style="dim") + stats_text.append(f"{max_val:.6g}", style="green") + stats_text.append(" Avg: ", style="dim") + stats_text.append(f"{avg_val:.6g}", style="yellow") + console.print(stats_text) + + # Create value table + value_table = Table(show_header=True, header_style="bold", box=SIMPLE) + + # Determine what to display + if len(steps) <= 6: + # Show all steps + display_indices = list(range(len(steps))) + for i in display_indices: + value_table.add_column(f"Step {steps[i]}") + else: + # Show first 3 and last 3 + display_indices = [0, 1, 2, len(steps) - 3, len(steps) - 2, len(steps) - 1] + value_table.add_column(f"Step {steps[0]}") + value_table.add_column(f"Step {steps[1]}") + value_table.add_column(f"Step {steps[2]}") + value_table.add_column("...") + value_table.add_column(f"Step {steps[-3]}") + value_table.add_column(f"Step {steps[-2]}") + value_table.add_column(f"Step {steps[-1]}") + + # Add value row + if len(steps) <= 6: + value_table.add_row(*[f"{values[i]:.6g}" for i in display_indices]) + else: + value_table.add_row( + f"{values[0]:.6g}", + f"{values[1]:.6g}", + f"{values[2]:.6g}", + "...", + f"{values[-3]:.6g}", + f"{values[-2]:.6g}", + f"{values[-1]:.6g}", + ) + + console.print(value_table) + console.print() + + # Write the merged data to JSON file + if output_path: + with open(output_path, "w") as f: + json.dump(output_data, f, indent=2) + + console.print( + f"[bold green]βœ“ Merged data written to {output_path}[/bold green]" + ) + else: + console.print( + f"[bold red]βœ“ To save the merged data, use --output_path[/bold red]" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Merge TensorBoard event files into a single JSON file" + ) + parser.add_argument( + "log_dir", type=str, help="Directory containing TensorBoard event files" + ) + parser.add_argument( + "--output_path", + required=False, + default=None, + type=str, + help="Path to save the output JSON file", + ) + parser.add_argument( + "--allow-conflicts", + action="store_true", + help="Allow conflicting values for the same step (last one wins)", + ) + + args = parser.parse_args() + + try: + merge_tb_logs_to_json(args.log_dir, args.output_path, args.allow_conflicts) + except Exception as e: + error_console.print(f"[bold red]Error: {e}[/bold red]") + sys.exit(1) diff --git a/tests/functional/sft.sh b/tests/functional/sft.sh new file mode 100755 index 0000000000..94d565d700 --- /dev/null +++ b/tests/functional/sft.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) + +set -eou pipefail + +LOG_DIR=$SCRIPT_DIR/$(basename $0 .sh)-logs +JSON_METRICS=$LOG_DIR/$(basename $0 .sh).json +RUN_LOG=$LOG_DIR/$(basename $0 .sh).log +export RAY_DEDUP_LOGS=0 +export UV_CACHE_DIR=$PROJECT_ROOT/uv_cache + +mkdir -p $LOG_DIR + +cd $PROJECT_ROOT +uv run $PROJECT_ROOT/examples/run_sft.py \ + cluster.gpus_per_node=2 \ + sft.num_steps=10 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +cd $SCRIPT_DIR +uv run json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["9"] < 600' \ + 'data["timing/train/sft_train_step"]["9"] < 0.25' + diff --git a/tests/run_unit.sh b/tests/run_unit.sh new file mode 100644 index 0000000000..efc614868b --- /dev/null +++ b/tests/run_unit.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +PROJECT_ROOT=$(realpath ${SCRIPT_DIR}/..) + +set -eou pipefail + +cd $SCRIPT_DIR +GPUS_PER_NODE=$(nvidia-smi -L | grep -c '^GPU') + +if ! command -v pytest >/dev/null 2>&1; then + echo "[ERROR] pytest not found. Make sure it's installed." + exit 1 +elif ! command -v ray >/dev/null 2>&1; then + echo "[ERROR] ray binary not installed, which suggests this package is not installed." + exit 1 +elif [[ $GPUS_PER_NODE -lt 2 ]]; then + echo "[ERROR]: Unit tests need at least 2 GPUs, but found $GPUS_PER_NODE" + exit 1 +fi + +export CUDA_DEVICE_ORDER=PCI_BUS_ID +nvidia-smi +export CUDA_VISIBLE_DEVICES=0,1 +export PYTHONPATH=$(realpath ${SCRIPT_DIR}/..):${PYTHONPATH:-} +export RAY_DEDUP_LOGS=0 + +# Run unit tests +echo "Running unit tests..." +if ! pytest unit/ -s -rA "$@"; then + echo "[ERROR]: Unit tests failed." + exit 1 +fi +echo "Unit tests passed!" diff --git a/tests/run_unit_in_docker.sh b/tests/run_unit_in_docker.sh new file mode 100644 index 0000000000..1190e3d8ec --- /dev/null +++ b/tests/run_unit_in_docker.sh @@ -0,0 +1,39 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +set -eou pipefail + +# Ensure Docker is installed +if ! command -v docker &> /dev/null; then + echo "Error: Docker is not installed or not in PATH." + exit 1 +fi + +# CONTAINER is expected to be set as an environment variable +if [[ -z "${CONTAINER:-}" ]]; then + echo "Error: CONTAINER environment variable is not set." + echo "Usage: CONTAINER= $0 [optional pytest-args...]" + exit 1 +fi + +CONTAINER=${CONTAINER} + +export HF_HOME=${HF_HOME:-$(realpath $SCRIPT_DIR/../hf_home)} +mkdir -p $HF_HOME + +# Check if running in GitLab CI +INTERACTIVE_FLAG="" +if [[ "${CI:-false}" != "true" ]]; then + # Setting this interactively lets us issue a keyboard interrupt. + INTERACTIVE_FLAG="-it" +fi + +# Note: we run as root because: +# 1. running as ray prevents us from writing into the current working directory +# 2. running as ourselves (-u $(id -u):$(id -g)) causes torch compile to fail +# +# The workaround is we launch the job but set umask 000 so all files created as root are rwxrwxrwx. +# We have found that 111 does not always work and can leave the filesystem permissions in a bad state. + +# Run the script inside the Docker container with GPU support +docker run -u root $INTERACTIVE_FLAG --ulimit memlock=-1 --ulimit stack=67108864 --rm --gpus '"device=0,1"' -v "$(realpath $SCRIPT_DIR/..):/workspace" -v $HF_HOME:/hf_home -e HF_TOKEN -e HF_HOME=/hf_home -e HOME=/tmp/ -w /workspace/tests "$CONTAINER" -- bash -x -c "umask 000 && uv run --extra test bash -x ./run_unit.sh $@" diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py new file mode 100644 index 0000000000..c6491e02d6 --- /dev/null +++ b/tests/unit/algorithms/test_grpo.py @@ -0,0 +1,192 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch +import ray +from typing import Dict, List, Tuple + +from nemo_reinforcer.algorithms.grpo import calculate_rewards +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.data.interfaces import DatumSpec, LLMMessageLogType +from nemo_reinforcer.environments.interfaces import EnvironmentInterface + + +@ray.remote(num_cpus=0) +class MockEnvironment(EnvironmentInterface): + def __init__(self, rewards: List[float]): + self.rewards = rewards + self._calls = 0 + + def step( + self, messages: List[LLMMessageLogType], env_info: List[dict] + ) -> Tuple[None, None, List[float], None]: + self._calls += 1 + return None, None, self.rewards, None + + def get_calls(self): + return self._calls + + def reset_calls(self): + self._calls = 0 + return True + + def global_post_process_and_metrics( + self, batch: BatchedDataDict + ) -> Tuple[BatchedDataDict, dict]: + return batch, {} + + +def create_mock_batch( + num_samples: int, + task_names: List[str], + message_logs: List[LLMMessageLogType], + extra_env_info: List[dict] = None, +) -> BatchedDataDict[DatumSpec]: + """Helper function to create a mock batch for testing.""" + if extra_env_info is None: + extra_env_info = [{} for _ in range(num_samples)] + + return BatchedDataDict[DatumSpec]( + { + "task_name": task_names, + "message_log": message_logs, + "extra_env_info": extra_env_info, + "loss_multiplier": torch.ones(num_samples), + } + ) + + +@pytest.fixture(scope="module") +def ray_init(): + """Initialize Ray for testing.""" + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True) + yield + if ray.is_initialized(): + ray.shutdown() + + +@pytest.fixture(scope="module") +def mock_env(ray_init): + """Create a mock environment for single task tests.""" + env = MockEnvironment.remote(rewards=[1.0, 2.0]) + yield env + ray.kill(env) + + +@pytest.fixture(scope="module") +def mock_envs(ray_init): + """Create mock environments for multiple task tests.""" + math_env = MockEnvironment.remote(rewards=[1.0, 2.0]) + code_env = MockEnvironment.remote(rewards=[3.0, 4.0]) + yield {"math": math_env, "code": code_env} + ray.kill(math_env) + ray.kill(code_env) + + +@pytest.fixture(autouse=True) +def reset_env_calls(mock_env, mock_envs): + """Reset call counters before each test.""" + ray.get(mock_env.reset_calls.remote()) + ray.get(mock_envs["math"].reset_calls.remote()) + ray.get(mock_envs["code"].reset_calls.remote()) + yield + + +def test_calculate_rewards_single_task(mock_env): + """Test reward calculation with a single task type.""" + task_to_env = {"math": mock_env} + + # Create test data + task_names = ["math", "math"] + message_logs = [ + [{"role": "user", "content": "1+1"}, {"role": "assistant", "content": "2"}], + [{"role": "user", "content": "2+2"}, {"role": "assistant", "content": "4"}], + ] + batch = create_mock_batch(2, task_names, message_logs) + + # Calculate rewards + rewards, to_env = calculate_rewards(batch, task_to_env) + + # Verify results + assert torch.allclose(rewards, torch.tensor([1.0, 2.0])) + assert len(to_env) == 2 + assert ( + ray.get(mock_env.get_calls.remote()) == 1 + ) # Should only call once for all samples of same task + + +def test_calculate_rewards_multiple_tasks(mock_envs): + """Test reward calculation with multiple task types.""" + # Create test data + task_names = ["math", "math", "code", "code"] + message_logs = [ + [{"role": "user", "content": "1+1"}, {"role": "assistant", "content": "2"}], + [{"role": "user", "content": "2+2"}, {"role": "assistant", "content": "4"}], + [ + {"role": "user", "content": "print('hello')"}, + {"role": "assistant", "content": "hello"}, + ], + [ + {"role": "user", "content": "print('world')"}, + {"role": "assistant", "content": "world"}, + ], + ] + batch = create_mock_batch(4, task_names, message_logs) + + # Calculate rewards + rewards, to_env = calculate_rewards(batch, mock_envs) + + # Verify results + assert torch.allclose(rewards, torch.tensor([1.0, 2.0, 3.0, 4.0])) + assert len(to_env) == 4 + assert ( + ray.get(mock_envs["math"].get_calls.remote()) == 1 + ) # One call for all math samples + assert ( + ray.get(mock_envs["code"].get_calls.remote()) == 1 + ) # One call for all code samples + + +def test_calculate_rewards_empty_batch(mock_env): + """Test reward calculation with an empty batch.""" + task_to_env = {"math": mock_env} + + # Create empty test data + batch = create_mock_batch(0, [], []) + + # Calculate rewards + rewards, to_env = calculate_rewards(batch, task_to_env) + + # Verify results + assert len(rewards) == 0 + assert len(to_env) == 0 + assert ( + ray.get(mock_env.get_calls.remote()) == 0 + ) # Should not call environment for empty batch + + +def test_calculate_rewards_missing_environment(): + """Test reward calculation with a missing environment.""" + # Create test data with unknown task + task_names = ["unknown_task"] + message_logs = [[{"role": "user", "content": "test"}]] + batch = create_mock_batch(1, task_names, message_logs) + + # Try to calculate rewards with missing environment + task_to_env = {} # Empty dict means no environments available + with pytest.raises( + ValueError, match="No environment found for task type: unknown_task" + ): + calculate_rewards(batch, task_to_env) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py new file mode 100644 index 0000000000..02e7a072b6 --- /dev/null +++ b/tests/unit/algorithms/test_loss_functions.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from nemo_reinforcer.algorithms.loss_functions import NLLLoss + + +def test_nll_loss(): + loss_fn = NLLLoss() + + vocab_size = 8 + data = { + "input_ids": torch.arange(vocab_size / 2) + .unsqueeze(0) + .to(torch.int64) + .to("cuda"), + "token_mask": torch.tensor([[0, 0, 1, 1]]).to("cuda"), + "sample_mask": torch.tensor([[1]]).to("cuda"), + } + + ### assume we predict the correct token with high probability + next_token_logits = ( + torch.tensor( + [ + [0, 999.0, 0, 0, 0, 0, 0, 0], + [0, 0, 999.0, 0, 0, 0, 0, 0], + [0, 0, 0, 999.0, 0, 0, 0, 0], + [0, 0, 0, 0, 0.0, 0, 0, 0], ## unused because we don't have a label + ] + ) + .unsqueeze(0) + .to("cuda") + ) + loss, metrics_dict = loss_fn(next_token_logits, data) + torch.testing.assert_allclose(loss.cpu(), torch.tensor(0.0)) + + ## now assume we predict the incorrect token with high probability + next_token_logits = ( + torch.tensor( + [ + [999.0, 0, 0, 0, 0, 0, 0, 0], + [0, 999.0, 0, 0, 0, 0, 0, 0], + [0, 0, 999.0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ] + ) + .unsqueeze(0) + .to("cuda") + ) + loss, metrics_dict = loss_fn(next_token_logits, data) + ## loss per token is 999, and we have two unmasked tokens + torch.testing.assert_allclose(loss.cpu(), torch.tensor(1998.0)) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000000..67de3a36af --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,128 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import os +import random +from typing import Callable +import ray + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import os +import random +from typing import Callable +import ray +from nemo_reinforcer.distributed.virtual_cluster import init_ray + + +@pytest.fixture(scope="session", autouse=True) +def init_ray_cluster(): + """Initialize Ray for the test module and clean up afterward. + + This fixture doesn't need to be called directly. + """ + init_ray() + yield + ray.shutdown() + + +def _setup_distributed(rank, world_size, port, backend="nccl"): + """Initialize the distributed environment for a test (internal use only)""" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) # Use the same port for all processes + + # Initialize the process group + dist.init_process_group(backend=backend, rank=rank, world_size=world_size) + # Set the device for this process + torch.cuda.set_device(rank) + + +def _cleanup_distributed(): + """Clean up the distributed environment after a test (internal use only)""" + dist.destroy_process_group() + + +@pytest.fixture +def distributed_test_runner(): + """Fixture that returns a function to run distributed tests. + + This fixture provides a reusable way to run a test function across multiple processes + with PyTorch distributed communication set up. + """ + + def run_distributed_test( + test_fn: Callable, world_size: int, backend: str = "nccl" + ) -> None: + """Run a test function in a distributed environment. + + Args: + test_fn: The test function to run on each process + world_size: Number of processes to spawn + backend: PyTorch distributed backend to use + """ + # Skip if CUDA is not available and using NCCL backend + if backend == "nccl" and not torch.cuda.is_available(): + pytest.skip("CUDA not available, skipping CUDA-based test") + + # Skip if we don't have enough GPUs for NCCL backend + if backend == "nccl" and torch.cuda.device_count() < world_size: + pytest.skip( + f"Not enough GPUs available. Need {world_size}, got {torch.cuda.device_count()}" + ) + + # Generate a single random port in the main process + port = random.randint(10000, 20000) + + # Run the test on multiple processes + mp.spawn( + _distributed_test_wrapper, + args=(test_fn, world_size, port, backend), + nprocs=world_size, + join=True, + ) + + return run_distributed_test + + +def _distributed_test_wrapper( + rank: int, test_fn: Callable, world_size: int, port: int, backend: str +) -> None: + """Wrapper function that sets up the distributed environment before running the test function. + Internal use only - use distributed_test_runner fixture instead. + + Args: + rank: Process rank + test_fn: The test function to run + world_size: Total number of processes + port: Port to use for distributed communication + backend: PyTorch distributed backend to use + """ + try: + # Setup the distributed environment + _setup_distributed(rank, world_size, port, backend=backend) + + # Run the actual test function + test_fn(rank, world_size) + + # Clean up + _cleanup_distributed() + except Exception as e: + print(f"Error in rank {rank}: {e}") + _cleanup_distributed() + raise diff --git a/tests/unit/data/test_hf_datasets.py b/tests/unit/data/test_hf_datasets.py new file mode 100644 index 0000000000..5fe634a2a7 --- /dev/null +++ b/tests/unit/data/test_hf_datasets.py @@ -0,0 +1,50 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from transformers import AutoTokenizer +from nemo_reinforcer.data.hf_datasets.squad import SquadDataset + + +@pytest.mark.skip(reason="dataset download is flaky") +def test_squad_dataset(): + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") + squad_dataset = SquadDataset() + + # check that the dataset is formatted correctly + for example in squad_dataset.formatted_ds["train"].take(5): + assert "messages" in example + assert len(example["messages"]) == 3 + + assert example["messages"][0]["role"] == "system" + assert example["messages"][1]["role"] == "user" + assert example["messages"][2]["role"] == "assistant" + + ## check that applying chat template works as expected + default_templated = tokenizer.apply_chat_template( + example["messages"], + chat_template=squad_dataset.task_spec.custom_template, + tokenize=False, + add_generation_prompt=False, + add_special_tokens=False, + ) + + assert default_templated == ( + "Context: " + + example["messages"][0]["content"] + + " Question: " + + example["messages"][1]["content"] + + " Answer: " + + example["messages"][2]["content"] + ) diff --git a/tests/unit/data/test_llm_message_utils.py b/tests/unit/data/test_llm_message_utils.py new file mode 100644 index 0000000000..37f92cdc45 --- /dev/null +++ b/tests/unit/data/test_llm_message_utils.py @@ -0,0 +1,396 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from typing import Dict, List +from transformers import AutoTokenizer + +from nemo_reinforcer.data.llm_message_utils import ( + message_log_to_flat_messages, + get_keys_from_message_log, + batched_message_log_to_flat_message, + get_formatted_message_log, + add_loss_mask_to_message_log, + get_first_index_that_differs, +) +from nemo_reinforcer.data.interfaces import LLMMessageLogType, TaskDataSpec + + +@pytest.fixture +def simple_message_log() -> LLMMessageLogType: + """Fixture for a single message with tensor and text data.""" + return [ + { + "input_ids": torch.tensor([1, 2, 3]), + "attention_mask": torch.tensor([1, 1, 1]), + "text": "test", + } + ] + + +@pytest.fixture +def multiple_messages_log() -> LLMMessageLogType: + """Fixture for multiple messages with tensor and text data.""" + return [ + { + "input_ids": torch.tensor([1, 2]), + "attention_mask": torch.tensor([1, 1]), + "text": "first", + }, + { + "input_ids": torch.tensor([3, 4]), + "attention_mask": torch.tensor([1, 1]), + "text": "second", + }, + ] + + +@pytest.fixture +def uneven_message_logs() -> List[LLMMessageLogType]: + """Fixture for message logs of different lengths.""" + return [ + [ # First sequence (shorter) + { + "input_ids": torch.tensor([1, 2]), + "role": "user", + } + ], + [ # Second sequence (longer) + { + "input_ids": torch.tensor([3, 4, 5]), + "role": "assistant", + } + ], + ] + + +@pytest.fixture +def raw_chat_message_log() -> List[LLMMessageLogType]: + """Fixture for chat message logs.""" + return [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + +@pytest.fixture +def tokenized_non_chat_message_log() -> List[LLMMessageLogType]: + return [ + [ + { + "text": "some input text", + "token_ids": torch.tensor([0, 1, 2, 3, 4, 5, 6]), + "context_length": 3, + "answer_length": 4, + } + ] + ] + + +@pytest.fixture +def tokenized_chat_message_log() -> List[LLMMessageLogType]: + return [ + [ + { + "role": "system", + "content": "system message", + "token_ids": torch.tensor([0, 1, 2, 3, 4, 5]), + }, + { + "role": "user", + "content": "user message", + "token_ids": torch.tensor([6, 7, 8]), + }, + { + "role": "assistant", + "content": "assistant message", + "token_ids": torch.tensor([9, 10]), + }, + ] + ] + + +def test_message_log_to_flat_messages_empty() -> None: + """Test message_log_to_flat_messages with empty input.""" + result = message_log_to_flat_messages([]) + assert result == {}, "Empty input should return empty dictionary" + + +def test_message_log_to_flat_messages_missing_keys() -> None: + """Test message_log_to_flat_messages with messages having different keys.""" + message_log: LLMMessageLogType = [ + {"input_ids": torch.tensor([1, 2]), "text": "first"}, + {"input_ids": torch.tensor([3, 4]), "attention_mask": torch.tensor([1, 1])}, + ] + result = message_log_to_flat_messages(message_log) + assert torch.equal(result["input_ids"], torch.tensor([1, 2, 3, 4])) + assert result["text"] == ["first"] + assert torch.equal(result["attention_mask"], torch.tensor([1, 1])) + + +def test_concatenate_messages_different_shapes() -> None: + """Test message_log_to_flat_messages with tensors of different shapes.""" + message_log: LLMMessageLogType = [ + {"input_ids": torch.tensor([[1, 2], [3, 4]])}, # 2D tensor + {"input_ids": torch.tensor([5, 6])}, # 1D tensor + ] + with pytest.raises( + RuntimeError, + match=r"tensors for key='input_ids' must have same number of dimensions", + ): + message_log_to_flat_messages(message_log) + + +def test_get_keys_from_messages_empty() -> None: + """Test get_keys_from_message_log with empty input.""" + assert get_keys_from_message_log([], ["key1"]) == [] + + +def test_get_keys_from_messages_empty_keys() -> None: + """Test get_keys_from_message_log with empty keys list.""" + message_log: LLMMessageLogType = [{"key1": "val1"}] + assert get_keys_from_message_log(message_log, []) == [{}] + + +def test_get_keys_from_messages_all_missing() -> None: + """Test get_keys_from_message_log when all requested keys are missing.""" + message_log: LLMMessageLogType = [{"key1": "val1"}] + assert get_keys_from_message_log(message_log, ["nonexistent"]) == [{}] + + +def test_batch_pad_message_log_single_item() -> None: + """Test batch_pad_message_log with single-item batch.""" + message_log_batch = [ + [{"input_ids": torch.tensor([1, 2, 3])}], + ] + result, input_lengths = batched_message_log_to_flat_message(message_log_batch) + assert result["input_ids"].shape == (1, 3) + assert input_lengths.shape == (1,) + assert torch.equal(input_lengths, torch.tensor([3], dtype=torch.int32)) + + +def test_batch_pad_message_log_empty_batch() -> None: + """Test batch_pad_message_log with empty batch.""" + result, input_lengths = batched_message_log_to_flat_message([]) + assert len(result) == 0 + assert input_lengths.numel() == 0 + + +def test_batch_pad_message_log_no_tensors() -> None: + """Test batch_pad_message_log with messages containing no tensors.""" + message_log_batch = [ + [{"text": "first"}], + [{"text": "second"}], + ] + result, input_lengths = batched_message_log_to_flat_message(message_log_batch) + assert "text" in result + assert isinstance(result["text"], list) + assert result["text"] == ["first", "second"] + assert input_lengths.numel() == 0 + + +def test_batch_pad_messages_mixed_dtypes() -> None: + """Test batch_pad_message_log with tensors of different dtypes.""" + message_log_batch = [ + [{"input_ids": torch.tensor([1, 2], dtype=torch.long)}], + [{"input_ids": torch.tensor([3.0, 4.0, 5.0], dtype=torch.float)}], + ] + with pytest.raises(RuntimeError, match="expected consistent types"): + batched_message_log_to_flat_message(message_log_batch) + + +@pytest.mark.parametrize("device", ["cuda", "meta"]) +def test_batch_pad_message_log_different_devices(device: str) -> None: + """Test batch_pad_message_log with tensors on different devices.""" + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if device == "meta" and not hasattr(torch.device(device), "type"): + pytest.skip(f"Device {device} not available") + + message_log_batch = [ + [{"input_ids": torch.tensor([1, 2], device="cpu")}], + [{"input_ids": torch.tensor([3, 4, 5], device=device)}], + ] + with pytest.raises(RuntimeError, match="expected tensors on the same device"): + batched_message_log_to_flat_message(message_log_batch) + + +def test_message_log_to_flat_messages_single( + simple_message_log: LLMMessageLogType, +) -> None: + """Test message_log_to_flat_messages with a single message.""" + result = message_log_to_flat_messages(simple_message_log) + assert torch.equal(result["input_ids"], simple_message_log[0]["input_ids"]) + assert torch.equal( + result["attention_mask"], simple_message_log[0]["attention_mask"] + ) + assert result["text"] == [simple_message_log[0]["text"]] + + +def test_message_log_to_flat_messages_multiple( + multiple_messages_log: LLMMessageLogType, +) -> None: + """Test message_log_to_flat_messages with multiple messages.""" + result = message_log_to_flat_messages(multiple_messages_log) + assert torch.equal(result["input_ids"], torch.tensor([1, 2, 3, 4])) + assert torch.equal(result["attention_mask"], torch.tensor([1, 1, 1, 1])) + assert result["text"] == ["first", "second"] + + +def test_get_keys_from_messages() -> None: + """Test get_keys_from_message_log with various key combinations.""" + message_log: LLMMessageLogType = [ + {"key1": "val1", "key2": "val2", "key3": "val3"}, + {"key1": "val4", "key2": "val5", "key3": "val6"}, + ] + + # Test getting all keys + result = get_keys_from_message_log(message_log, ["key1", "key2", "key3"]) + assert result == message_log + + # Test getting subset of keys + result = get_keys_from_message_log(message_log, ["key1", "key2"]) + assert result == [ + {"key1": "val1", "key2": "val2"}, + {"key1": "val4", "key2": "val5"}, + ] + + # Test with non-existent key + result = get_keys_from_message_log(message_log, ["key1", "nonexistent"]) + assert result == [{"key1": "val1"}, {"key1": "val4"}] + + +def test_batch_pad_message_log_basic( + uneven_message_logs: List[LLMMessageLogType], +) -> None: + """Test batch_pad_message_log with right padding.""" + result, input_lengths = batched_message_log_to_flat_message(uneven_message_logs) + + # Check shapes + assert result["input_ids"].shape == (2, 3) + assert input_lengths.shape == (2,) + + # Expected tensors for right padding + expected_ids = torch.tensor([[1, 2, 0], [3, 4, 5]]) + expected_lengths = torch.tensor([2, 3], dtype=torch.int32) + + assert torch.equal(result["input_ids"], expected_ids) + assert torch.equal(input_lengths, expected_lengths) + + +def test_batch_pad_message_log_custom_pad_value( + uneven_message_logs: List[LLMMessageLogType], +) -> None: + """Test batch_pad_message_log with custom padding values.""" + pad_value_dict: Dict[str, int] = {"input_ids": -100} + result, input_lengths = batched_message_log_to_flat_message( + uneven_message_logs, pad_value_dict=pad_value_dict + ) + + assert torch.equal( + result["input_ids"], + torch.tensor([[1, 2, -100], [3, 4, 5]]), + ) + assert torch.equal( + input_lengths, + torch.tensor([2, 3], dtype=torch.int32), + ) + + +def test_get_formatted_message_log( + raw_chat_message_log: LLMMessageLogType, +) -> None: + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") + + ## get expected result + formatted_system_message = tokenizer.apply_chat_template( + [raw_chat_message_log[0]], + tokenize=False, + add_generation_prompt=False, + add_special_tokens=False, + ) + formatted_user_message = tokenizer.apply_chat_template( + [raw_chat_message_log[1]], + tokenize=False, + add_generation_prompt=False, + add_special_tokens=False, + ) + formatted_assistant_message = tokenizer.apply_chat_template( + [raw_chat_message_log[2]], + tokenize=False, + add_generation_prompt=False, + add_special_tokens=False, + ) + + ## text should be equivalent to if we apply chat template + ## to each turn separately and manually remove the bot string + ## from the intermediate turns + bot_str = "<|begin_of_text|>" + expected_text = [ + formatted_system_message, + formatted_user_message[len(bot_str) :], + formatted_assistant_message[len(bot_str) :], + ] + + task_data_spec = TaskDataSpec( + task_name="test", + ) + result = get_formatted_message_log(raw_chat_message_log, tokenizer, task_data_spec) + actual_text = [m["content"] for m in result] + + assert actual_text == expected_text + + +def test_add_loss_mask_to_chat_message_log( + tokenized_chat_message_log: LLMMessageLogType, +): + add_loss_mask_to_message_log( + tokenized_chat_message_log, roles_to_train_on=["assistant"] + ) + assert torch.equal( + tokenized_chat_message_log[0][0]["token_loss_mask"], + torch.tensor([0, 0, 0, 0, 0, 0]), + ) + assert torch.equal( + tokenized_chat_message_log[0][1]["token_loss_mask"], torch.tensor([0, 0, 0]) + ) + assert torch.equal( + tokenized_chat_message_log[0][2]["token_loss_mask"], torch.tensor([1, 1]) + ) + + ## test training on multiple roles + add_loss_mask_to_message_log( + tokenized_chat_message_log, + roles_to_train_on=["assistant", "system"], + ) + assert torch.equal( + tokenized_chat_message_log[0][0]["token_loss_mask"], + torch.tensor([1, 1, 1, 1, 1, 1]), + ) + assert torch.equal( + tokenized_chat_message_log[0][1]["token_loss_mask"], torch.tensor([0, 0, 0]) + ) + assert torch.equal( + tokenized_chat_message_log[0][2]["token_loss_mask"], torch.tensor([1, 1]) + ) + + +def test_get_first_index_that_differs(): + assert get_first_index_that_differs("hello", "hello") == 5 + assert get_first_index_that_differs("hello", "hello world") == 5 + assert get_first_index_that_differs("hello world", "hello") == 5 + assert get_first_index_that_differs("hi1", "hello2") == 1 + assert get_first_index_that_differs("hello2", "hi1") == 1 diff --git a/tests/unit/distributed/test_batched_data_dict.py b/tests/unit/distributed/test_batched_data_dict.py new file mode 100644 index 0000000000..acda98b6c2 --- /dev/null +++ b/tests/unit/distributed/test_batched_data_dict.py @@ -0,0 +1,203 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict + + +def test_shard_by_batch_size_basic(): + """Test basic functionality of shard_by_batch_size with tensor data.""" + # Create a sample batch with tensor data + batch = BatchedDataDict( + { + "tensor_data": torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + "other_tensor": torch.tensor([10, 11, 12, 13, 14, 15, 16, 17]), + } + ) + + # Shard with batch_size=4, shards=2 + sharded = batch.shard_by_batch_size(shards=2, batch_size=4) + + # Verify output structure + assert len(sharded) == 2, f"Expected 2 shards, got {len(sharded)}" + + # Verify first shard content (first elements of each chunk) + assert torch.equal(sharded[0]["tensor_data"], torch.tensor([0, 1, 4, 5])) + assert torch.equal(sharded[0]["other_tensor"], torch.tensor([10, 11, 14, 15])) + + # Verify second shard content (second elements of each chunk) + assert torch.equal(sharded[1]["tensor_data"], torch.tensor([2, 3, 6, 7])) + assert torch.equal(sharded[1]["other_tensor"], torch.tensor([12, 13, 16, 17])) + + +def test_shard_by_batch_size_list_data(): + """Test shard_by_batch_size with list data.""" + # Create a sample batch with list data + batch = BatchedDataDict( + { + "list_data": ["A", "B", "C", "D", "E", "F", "G", "H"], + "tensor_data": torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), + } + ) + + # Shard with batch_size=4, shards=2 + sharded = batch.shard_by_batch_size(shards=2, batch_size=4) + + # Verify output structure + assert len(sharded) == 2 + + # Verify first shard content + assert sharded[0]["list_data"] == ["A", "B", "E", "F"] + assert torch.equal(sharded[0]["tensor_data"], torch.tensor([0, 1, 4, 5])) + + # Verify second shard content + assert sharded[1]["list_data"] == ["C", "D", "G", "H"] + assert torch.equal(sharded[1]["tensor_data"], torch.tensor([2, 3, 6, 7])) + + +def test_shard_by_batch_size_larger_example(): + """Test shard_by_batch_size with a larger example with multiple chunks and shards.""" + # Create a batch with 12 elements + batch = BatchedDataDict( + {"tensor_data": torch.arange(12), "list_data": [f"item_{i}" for i in range(12)]} + ) + + # Shard with batch_size=3, shards=3 + sharded = batch.shard_by_batch_size(shards=3, batch_size=3) + + # Verify we get 3 shards + assert len(sharded) == 3 + + # Expected results: + # Chunk 1: [0, 1, 2], Chunk 2: [3, 4, 5], Chunk 3: [6, 7, 8], Chunk 4: [9, 10, 11] + # Shard 1: [0, 3, 6, 9] + # Shard 2: [1, 4, 7, 10] + # Shard 3: [2, 5, 8, 11] + + # Verify tensor content + assert torch.equal(sharded[0]["tensor_data"], torch.tensor([0, 3, 6, 9])) + assert torch.equal(sharded[1]["tensor_data"], torch.tensor([1, 4, 7, 10])) + assert torch.equal(sharded[2]["tensor_data"], torch.tensor([2, 5, 8, 11])) + + # Verify list content + assert sharded[0]["list_data"] == ["item_0", "item_3", "item_6", "item_9"] + assert sharded[1]["list_data"] == ["item_1", "item_4", "item_7", "item_10"] + assert sharded[2]["list_data"] == ["item_2", "item_5", "item_8", "item_11"] + + +def test_shard_by_batch_size_2d_tensor(): + """Test shard_by_batch_size with 2D tensor data.""" + # Create a batch with 2D tensors + batch = BatchedDataDict( + { + "features": torch.tensor( + [ + [1, 2, 3], # 0 + [4, 5, 6], # 1 + [7, 8, 9], # 2 + [10, 11, 12], # 3 + [13, 14, 15], # 4 + [16, 17, 18], # 5 + ] + ) + } + ) + + # Shard with batch_size=3, shards=3 + sharded = batch.shard_by_batch_size(shards=3, batch_size=3) + + # Verify we get 3 shards + assert len(sharded) == 3 + + # Expected results by index: + # Chunk 1: [0, 1, 2], Chunk 2: [3, 4, 5] + # Shard 1: [0, 3] + # Shard 2: [1, 4] + # Shard 3: [2, 5] + + # Verify tensor content + expected_0 = torch.tensor([[1, 2, 3], [10, 11, 12]]) + expected_1 = torch.tensor([[4, 5, 6], [13, 14, 15]]) + expected_2 = torch.tensor([[7, 8, 9], [16, 17, 18]]) + + assert torch.equal(sharded[0]["features"], expected_0) + assert torch.equal(sharded[1]["features"], expected_1) + assert torch.equal(sharded[2]["features"], expected_2) + + +def test_shard_by_batch_size_edge_cases(): + """Test edge cases for shard_by_batch_size.""" + # Case 1: Single batch, multiple shards + batch = BatchedDataDict({"data": torch.tensor([0, 1, 2, 3])}) + + sharded = batch.shard_by_batch_size(shards=2, batch_size=4) + assert len(sharded) == 2 + assert torch.equal(sharded[0]["data"], torch.tensor([0, 1])) + assert torch.equal(sharded[1]["data"], torch.tensor([2, 3])) + + # Case 2: Multiple batches, single shard + batch = BatchedDataDict({"data": torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])}) + + sharded = batch.shard_by_batch_size(shards=1, batch_size=2) + assert len(sharded) == 1 + assert torch.equal(sharded[0]["data"], torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])) + + +def test_shard_by_batch_size_validation(): + """Test validation checks in shard_by_batch_size.""" + # Create a batch + batch = BatchedDataDict({"data": torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])}) + + # Case 1: batch_size not a divisor of total_batch_size + with pytest.raises( + AssertionError, match="Total batch size.*is not a multiple of batch_size" + ): + batch.shard_by_batch_size(shards=2, batch_size=3) + + # Case 2: shards not a divisor of batch_size + # First make a batch that's divisible by batch_size to reach the second assertion + batch_for_case2 = BatchedDataDict({"data": torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])}) + with pytest.raises(AssertionError, match="Batch size.*is not a multiple of shards"): + batch_for_case2.shard_by_batch_size(shards=3, batch_size=4) + + # Case 3: Different batch sizes across keys + inconsistent_batch = BatchedDataDict( + { + "data1": torch.tensor([0, 1, 2, 3]), + "data2": torch.tensor([0, 1, 2]), + } # Different length + ) + + with pytest.raises( + AssertionError, match="Batch sizes are not the same across the rollout batch" + ): + inconsistent_batch.shard_by_batch_size(shards=2, batch_size=2) + + +def test_shard_by_batch_size_matches_example(): + """Test that shard_by_batch_size behaves as described in the docstring example.""" + # Create the example data: [A A B B C C D D] + batch = BatchedDataDict({"data": ["A", "A", "B", "B", "C", "C", "D", "D"]}) + + # Shard with batch_size=2, shards=2 + sharded = batch.shard_by_batch_size(shards=2, batch_size=2) + + # Verify output structure + assert len(sharded) == 2 + + # Expected output: + # Element 0: [A B C D] (first elements from each chunk) + # Element 1: [A B C D] (second elements from each chunk) + assert sharded[0]["data"] == ["A", "B", "C", "D"] + assert sharded[1]["data"] == ["A", "B", "C", "D"] diff --git a/tests/unit/distributed/test_cluster_visualization.py b/tests/unit/distributed/test_cluster_visualization.py new file mode 100644 index 0000000000..9253579cea --- /dev/null +++ b/tests/unit/distributed/test_cluster_visualization.py @@ -0,0 +1,131 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch, MagicMock +import pytest + +from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster + + +@pytest.fixture(autouse=True) +def mock_virtual_cluster_pg(): + # Mock the _init_placement_groups and get_placement_groups methods to avoid actually initializing placement groups + with ( + patch( + "nemo_reinforcer.distributed.virtual_cluster.RayVirtualCluster.get_placement_groups" + ) as mock_get_pg, + patch( + "nemo_reinforcer.distributed.virtual_cluster.RayVirtualCluster._init_placement_groups" + ) as mock_init_pg, + ): + mock_get_pg.return_value = [] + mock_init_pg.return_value = [] + yield + + +def test_empty_cluster_visualization(capsys): + """Test visualization of an empty cluster.""" + # Create a empty cluster + cluster = RayVirtualCluster( + bundle_ct_per_node_list=[], + use_gpus=False, + name="test-empty", + ) + + # Test visualization + cluster.print_cluster_grid() + + # Capture the output + out, _ = capsys.readouterr() + assert "Empty Ray Cluster" in out + + +def test_cluster_grid(capsys): + """Test visualization of a cluster grid.""" + # Create a cluster with a configuration but don't actually allocate resources + cluster = RayVirtualCluster( + bundle_ct_per_node_list=[2, 3], + use_gpus=False, + name="test-visual", + max_colocated_worker_groups=1, + ) + + cluster.print_cluster_grid() + + # Capture the output + out, _ = capsys.readouterr() + print(out) + assert "Ray Cluster: 2 nodes, 5 GPUs" in out + assert "0.0" in out # First node, first GPU + assert "0.1" in out # First node, second GPU + assert "1.0" in out # Second node, first GPU + assert "1.2" in out # Second node, third GPU + + +def test_global_visualization_formatting(capsys): + """Test global visualization formatting without actual worker groups.""" + cluster = RayVirtualCluster( + bundle_ct_per_node_list=[2, 2], + use_gpus=False, + name="test-global", + max_colocated_worker_groups=1, + ) + + cluster.print_all_worker_groups([]) + + # Capture the output + out, _ = capsys.readouterr() + print(out) + assert "Ray Cluster Global View: 2 nodes, 4 GPUs" in out + + +def test_with_mock_worker_groups(capsys): + """Test visualization with mock worker groups.""" + # Create a cluster with a configuration + cluster = RayVirtualCluster( + bundle_ct_per_node_list=[2, 3], + use_gpus=False, + name="test-workers", + max_colocated_worker_groups=1, + ) + + worker_group1 = MagicMock() + worker_group1.name_prefix = "policy" + worker_group1.world_size = 2 + worker_group1.worker_metadata = [ + {"node_idx": 0, "local_rank": 0}, # First worker on node 0, GPU 0 + {"node_idx": 1, "local_rank": 0}, # Second worker on node 1, GPU 0 + ] + + worker_group2 = MagicMock() + worker_group2.name_prefix = "policy_generate" + worker_group2.world_size = 3 + worker_group2.worker_metadata = [ + {"node_idx": 0, "local_rank": 1}, # First worker on node 0, GPU 1 + {"node_idx": 1, "local_rank": 1}, # Second worker on node 1, GPU 1 + {"node_idx": 1, "local_rank": 2}, # Third worker on node 1, GPU 2 + ] + + cluster.print_all_worker_groups([worker_group1, worker_group2]) + + # Capture the output + out, _ = capsys.readouterr() + print(out) + + # Check for key elements in the output + assert "Ray Cluster Global View: 2 nodes, 5 GPUs" in out + assert "G0" in out # First worker group + assert "G1" in out # Second worker group + assert "policy" in out # First worker group name + assert "policy_generate" in out # Second worker group name diff --git a/tests/unit/distributed/test_collectives.py b/tests/unit/distributed/test_collectives.py new file mode 100644 index 0000000000..e8cd5c4f88 --- /dev/null +++ b/tests/unit/distributed/test_collectives.py @@ -0,0 +1,67 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from nemo_reinforcer.distributed.collectives import ( + rebalance_nd_tensor, + gather_jagged_object_lists, +) + + +def run_rebalance_test(rank, world_size): + """Test function for rebalance_nd_tensor""" + # Create different sized tensors on each GPU + # Rank 0: batch size 3, Rank 1: batch size 5, Rank 2: batch size 2 + batch_sizes = [3, 5, 2] + my_batch_size = batch_sizes[rank] + + tensor = torch.ones( + (my_batch_size, 4), dtype=torch.float32, device=f"cuda:{rank}" + ) * (rank + 1) + result = rebalance_nd_tensor(tensor) + + # Verify the shape is correct (sum of all batch sizes) + total_batch_size = sum(batch_sizes) + assert result.shape[0] == total_batch_size, ( + f"Expected shape {total_batch_size}, got {result.shape[0]}" + ) + assert result.shape[1:] == tensor.shape[1:], "Feature dimensions should match" + + +def run_gather_test(rank, world_size): + """Test function for gather_jagged_object_lists""" + object_lists = [ + ["obj0", "obj1"], # rank 0: 2 objects + ["obj2", "obj3", "obj4"], # rank 1: 3 objects + ["obj5"], # rank 2: 1 object + ] + my_objects = object_lists[rank] + + result = gather_jagged_object_lists(my_objects) + + expected = ["obj0", "obj1", "obj2", "obj3", "obj4", "obj5"] + assert len(result) == len(expected), ( + f"Expected {len(expected)} objects, got {len(result)}" + ) + assert set(result) == set(expected), "All objects should be gathered" + + +def test_rebalance_nd_tensor(distributed_test_runner): + """Test rebalance_nd_tensor by spawning multiple processes""" + distributed_test_runner(run_rebalance_test, world_size=3) + + +def test_gather_jagged_object_lists(distributed_test_runner): + """Test gather_jagged_object_lists by spawning multiple processes""" + distributed_test_runner(run_gather_test, world_size=3) diff --git a/tests/unit/environments/test_math_environment.py b/tests/unit/environments/test_math_environment.py new file mode 100644 index 0000000000..7f6d8784e5 --- /dev/null +++ b/tests/unit/environments/test_math_environment.py @@ -0,0 +1,213 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import ray +from nemo_reinforcer.environments.math_environment import MathEnvironment +import time + + +@pytest.fixture(scope="module") +def math_env(): + """Create a MathEnvironment actor for testing.""" + env = MathEnvironment.options( + runtime_env={"py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE} + ).remote({"num_workers": 2}) + yield env + # Clean up the actor and wait for it to be killed + env.shutdown.remote() + ray.kill(env) + # Give some time for cleanup + time.sleep(0.1) + + +@pytest.fixture +def basic_test_data(): + """Common test data for basic math problems.""" + return { + "message_log_batch": [ + [ + {"role": "user", "content": "What is 2 + 2?"}, + {"role": "assistant", "content": "2 + 2 = \\boxed{4}"}, + ], + [ + {"role": "user", "content": "What is 3 * 4?"}, + {"role": "assistant", "content": "3 * 4 = \\boxed{12}"}, + ], + [ + {"role": "user", "content": "What is 10 - 5?"}, + {"role": "assistant", "content": "10 - 5 = \\boxed{5}"}, + ], + ], + "metadata": [ + {"ground_truth": "4"}, + {"ground_truth": "\\boxed{12}"}, + {"ground_truth": "\\boxed{5}"}, + ], + } + + +@pytest.fixture +def mixed_test_data(): + """Test data with mix of correct and incorrect responses.""" + return { + "message_log_batch": [ + [ + {"role": "user", "content": "What is 2 + 2?"}, + {"role": "assistant", "content": "2 + 2 = \\boxed{\\frac{8}{2}}"}, + ], + [ + {"role": "user", "content": "What is 3 * 4?"}, + {"role": "assistant", "content": "3 * 4 = 13"}, + ], + [ + {"role": "user", "content": "What is 10 - 5?"}, + {"role": "assistant", "content": "10 - 5 = \\boxed{5}"}, + ], + ], + "metadata": [ + {"ground_truth": "4.0"}, + {"ground_truth": "\\boxed{12}"}, + {"ground_truth": "\\boxed{5}"}, + ], + } + + +@pytest.fixture +def multiple_assistant_test_data(): + """Test data with multiple assistant messages in conversations.""" + return { + "message_log_batch": [ + [ + {"role": "user", "content": "What is 2 + 2?"}, + {"role": "assistant", "content": "Let me think..."}, + {"role": "assistant", "content": "2 + 2 = \\boxed{4}"}, + ], + [ + {"role": "user", "content": "What is 3 * 4?"}, + {"role": "assistant", "content": "I'll calculate that..."}, + {"role": "assistant", "content": "3 * 4 = \\boxed{12}"}, + ], + ], + "metadata": [{"ground_truth": "4"}, {"ground_truth": "\\boxed{12}"}], + } + + +def test_math_env_step_basic(math_env, basic_test_data): + """Test basic functionality of MathEnvironment step with simple messages.""" + observations, updated_metadata, rewards, done = ray.get( + math_env.step.remote( + basic_test_data["message_log_batch"], basic_test_data["metadata"] + ) + ) + + # Check observations + assert len(observations) == 3, "Should return observations for all 3 messages" + assert all(obs["role"] == "user" for obs in observations), ( + "All observations should be from user" + ) + assert all(obs["content"] == "correct" for obs in observations), ( + "All responses should be correct" + ) + + # Check metadata + assert len(updated_metadata) == 3, "Should return metadata for all 3 messages" + assert updated_metadata == basic_test_data["metadata"], ( + "Metadata should be unchanged" + ) + + # Check rewards and done flags + assert rewards.shape == (3,), "Rewards should be a tensor of shape (3,)" + assert all(rewards == 1.0), "All rewards should be 1.0 for correct answers" + assert done.shape == (3,), "Done flags should be a tensor of shape (3,)" + assert all(done == 1.0), "All done flags should be 1.0" + + +def test_math_env_step_mixed(math_env, mixed_test_data): + """Test MathEnvironment step with a mix of correct and incorrect responses.""" + observations, updated_metadata, rewards, done = ray.get( + math_env.step.remote( + mixed_test_data["message_log_batch"], mixed_test_data["metadata"] + ) + ) + + # Check observations and rewards + assert len(observations) == 3, "Should return observations for all 3 messages" + assert observations[0]["content"] == "correct", "First response should be correct" + assert observations[1]["content"] == "incorrect", ( + "Second response should be incorrect" + ) + assert observations[2]["content"] == "correct", "Third response should be correct" + + assert rewards.shape == (3,), "Rewards should be a tensor of shape (3,)" + assert rewards[0] == 1.0, "First reward should be 1.0" + assert rewards[1] == 0.0, "Second reward should be 0.0" + assert rewards[2] == 1.0, "Third reward should be 1.0" + + +def test_math_env_step_empty(math_env): + """Test MathEnvironment step with empty input.""" + observations, updated_metadata, rewards, done = ray.get( + math_env.step.remote([], []) + ) + + # Check all outputs are empty + assert len(observations) == 0, "Should return empty observations list" + assert len(updated_metadata) == 0, "Should return empty metadata list" + assert rewards.shape == (0,), "Should return empty rewards tensor" + assert done.shape == (0,), "Should return empty done tensor" + + +def test_math_env_step_multiple_assistant_messages( + math_env, multiple_assistant_test_data +): + """Test MathEnvironment step with multiple assistant messages in a conversation.""" + observations, updated_metadata, rewards, done = ray.get( + math_env.step.remote( + multiple_assistant_test_data["message_log_batch"], + multiple_assistant_test_data["metadata"], + ) + ) + + # Check that only the last assistant message is used + assert len(observations) == 2, "Should return observations for both conversations" + assert all(obs["content"] == "correct" for obs in observations), ( + "All responses should be correct" + ) + assert all(rewards == 1.0), "All rewards should be 1.0" + + +@pytest.mark.parametrize("batch_size", [1, 2, 10, 25, 101]) +def test_math_env_various_batches(math_env, batch_size): + """Test MathEnvironment step with different batch sizes.""" + message_log_batch = [ + [ + {"role": "user", "content": "What is 2 + 2?"}, + {"role": "assistant", "content": "2 + 1.333 = \\boxed{\\frac{10}{3}}"}, + ] + ] * batch_size + metadata = [{"ground_truth": "3.33333333"}] * batch_size + + observations, updated_metadata, rewards, done = ray.get( + math_env.step.remote(message_log_batch, metadata) + ) + + # Check outputs + assert len(observations) == batch_size, ( + f"Should return observations for all {batch_size} messages" + ) + assert all(obs["content"] == "correct" for obs in observations), ( + "All responses should be correct" + ) + assert all(rewards == 1.0), "All rewards should be 1.0" + assert all(done == 1.0), "All done flags should be 1.0" diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py new file mode 100644 index 0000000000..adf1aad824 --- /dev/null +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -0,0 +1,485 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +import ray + +from transformers import AutoTokenizer + +from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.models.generation.vllm import VllmGeneration, VllmConfig + + +# Skip all tests if no CUDA or vLLM +pytestmark = [ + pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 1, + reason="CUDA not available or insufficient GPUs", + ) +] + + +# Define basic vLLM test config +basic_vllm_test_config: VllmConfig = { + "model_name": "meta-llama/Llama-3.2-1B", # Small model for testing + "dtype": "bfloat16", + "max_new_tokens": 10, + "temperature": 1.0, + "top_p": 1.0, + "top_k": None, + "vllm_cfg": { + "tensor_parallel_size": 1, + "gpu_memory_utilization": 0.3, + "max_model_len": 1024, + }, +} + + +@pytest.fixture(scope="module") +def check_vllm_available(): + """Skip tests if vLLM is not installed.""" + try: + import vllm # noqa: F401 + except ImportError: + pytest.skip("vLLM not installed") + + +@pytest.fixture(scope="module") +def cluster(): + """Create a virtual cluster for testing.""" + # Create a cluster with 1 node that has 2 GPU bundles + virtual_cluster = RayVirtualCluster( + bundle_ct_per_node_list=[2], # 1 node with 2 GPU bundle + use_gpus=True, + max_colocated_worker_groups=2, + num_gpus_per_node=torch.cuda.device_count(), # Use available GPUs + name="vllm-test-cluster", + ) + yield virtual_cluster + virtual_cluster.shutdown() + + +@pytest.fixture(scope="function") +def tokenizer(): + """Initialize tokenizer for the test model.""" + model_name = basic_vllm_test_config["model_name"] + tokenizer = AutoTokenizer.from_pretrained(model_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + +@pytest.fixture(scope="function") +def policy(cluster, check_vllm_available): + """Initialize the vLLM policy.""" + policy = VllmGeneration(cluster, basic_vllm_test_config) + yield policy + + # Ensure policy is properly shutdown + try: + policy.shutdown() + # Force garbage collection to help release resources + import gc + + gc.collect() + torch.cuda.empty_cache() + except Exception as e: + print(f"Error during policy cleanup: {e}") + + +@pytest.fixture(scope="function") +def test_input_data(tokenizer): + """Create test input data for inference.""" + test_prompts = [ + "Hello, my name is", + "The capital of France is", + ] + + # Tokenize prompts + encodings = tokenizer( + test_prompts, + padding="max_length", + max_length=20, + truncation=True, + return_tensors="pt", + padding_side="right", + ) + + # Calculate input lengths from attention mask + input_lengths = encodings["attention_mask"].sum(dim=1).to(torch.int32) + + # Create input data dictionary + return BatchedDataDict( + { + "input_ids": encodings["input_ids"], + "input_lengths": input_lengths, + } + ) + + +def test_vllm_policy_generation(policy, test_input_data, tokenizer): + """Test vLLM policy generation capabilities.""" + # Test generation + print("Testing generation...") + outputs = policy.generate(test_input_data) + + # Validate outputs format + assert "output_ids" in outputs, "output_ids not found in generation output" + assert "logprobs" in outputs, "logprobs not found in generation output" + assert "generation_lengths" in outputs, ( + "generation_lengths not found in generation output" + ) + assert "unpadded_sequence_lengths" in outputs, ( + "unpadded_sequence_lengths not found in generation output" + ) + + # Validate outputs shape and content + assert outputs["output_ids"].shape[0] == len(test_input_data["input_ids"]), ( + "Wrong batch size in output" + ) + assert outputs["generation_lengths"].shape[0] == len( + test_input_data["input_ids"] + ), "Wrong batch size in generation_lengths" + + # Decode and check outputs + generated_sequences = outputs["output_ids"] + generated_texts = tokenizer.batch_decode( + generated_sequences, skip_special_tokens=True + ) + + print(f"Generated texts: {generated_texts}") + + # All texts should have a non-zero length and be longer than inputs + assert all(len(text) > 0 for text in generated_texts), ( + "Some generated texts are empty" + ) + + +@pytest.mark.timeout(140) +def test_vllm_generation_with_hf_training(cluster, tokenizer): + """1. Use vLLM for generation + 2. Use HF policy for training and logprob computation + + This test validates that the two policies can work together. + """ + from nemo_reinforcer.models.policy.hf_policy import HfPolicy + from tests.unit.test_utils import nll_loss + + # Create separate configs for each policy + vllm_config = basic_vllm_test_config.copy() + + # Create HF-specific config with required parameters + hf_config = { + "model_name": basic_vllm_test_config["model_name"], + # Required training parameters + "train_global_batch_size": 4, + "train_micro_batch_size": 1, + "learning_rate": 5e-6, + "logprob_batch_size": 1, + "max_new_tokens": 16, + "do_sample": False, + } + + vllm_policy = None + hf_policy = None + + try: + prompts = [ + "Write a story about a magical forest", + "Explain how photosynthesis works", + "What are the benefits of exercise?", + "Describe the water cycle", + "What is the capital of France?", + "Who is the president of the USA?", + "What is the capital of the moon?", + "Where is the sun?", + ] + + # Tokenize the prompts the same way as in test_hf_ray_policy + tokenized = tokenizer( + prompts, + padding=True, + truncation=True, + max_length=64, + return_tensors="pt", + padding_side="right", + ) + # Calculate input lengths from attention mask + input_lengths = tokenized["attention_mask"].sum(dim=1).to(torch.int32) + + test_input_data = BatchedDataDict( + { + "input_ids": tokenized["input_ids"], + "input_lengths": input_lengths, + } + ) + + # Create both policies + print("Creating vLLM policy...") + vllm_policy = VllmGeneration(cluster, vllm_config) + + print("Creating HF policy...") + hf_policy = HfPolicy(cluster, hf_config) + + # Step 1: Use vLLM for generation + print("Using vLLM policy for fast generation...") + generation_results = vllm_policy.generate(test_input_data) + vllm_policy.finish_generation() + # Validate generation outputs + assert "output_ids" in generation_results, ( + "output_ids not found in vLLM generation output" + ) + assert "logprobs" in generation_results, ( + "logprobs not found in vLLM generation output" + ) + + # Decode generations + generated_texts = tokenizer.batch_decode( + generation_results["output_ids"], skip_special_tokens=True + ) + print(f"vLLM generated texts: {generated_texts}") + + # Run logprob calculation with HF policy to verify + + fprop_logprob_data = BatchedDataDict( + { + "input_ids": generation_results["output_ids"], + "input_lengths": generation_results["unpadded_sequence_lengths"], + } + ) + # Get logprobs from HF policy + fprop_results = hf_policy.get_logprobs(fprop_logprob_data) + # Zero out logprobs for input tokens + + print(f"HF logprobs: {fprop_results['logprobs']}") + print(f"vLLM logprobs: {generation_results['logprobs']}") + + # Validate that the logprobs are correct (comparing vLLM generation logprobs with HF computed logprobs) + + # Create a mask for padding tokens to only include tokens up to generation_lengths + padding_mask = torch.zeros_like( + generation_results["logprobs"], dtype=torch.bool + ) + for i, (input_len, total_valid_len) in enumerate( + zip( + test_input_data.get("input_lengths"), + generation_results["unpadded_sequence_lengths"], + ) + ): + padding_mask[i, input_len:total_valid_len] = True + + abs_diff = torch.abs(generation_results["logprobs"] - fprop_results["logprobs"]) + masked_abs_diff = abs_diff.masked_select(padding_mask) + avg_prob_mult_error = ( + torch.mean(torch.exp(masked_abs_diff)) + if masked_abs_diff.numel() > 0 + else torch.tensor(0.0) + ) + + print(f"Average probability multiplicative error: {avg_prob_mult_error}") + assert avg_prob_mult_error <= 1.043, "vLLM and HF logprobs should closely match" + + # Step 2: Prepare simplified training data (smaller and with padding removed to prevent OOM) + # Use a very small sequence for training to ensure it works + max_seq_len = min(40, generation_results["output_ids"].shape[1]) + # cap generation lengths to max_seq_len + generation_results["unpadded_sequence_lengths"] = torch.clamp( + generation_results["unpadded_sequence_lengths"], max=max_seq_len + ) + + train_input_ids = generation_results["output_ids"][:, :max_seq_len] + token_loss_mask = torch.ones_like(train_input_ids) + # Only compute loss on generated tokens, not input + input_len = test_input_data.get("input_ids").size(1) + token_loss_mask[:, :input_len] = 0 + + for idx, length in enumerate(generation_results["unpadded_sequence_lengths"]): + token_loss_mask[idx, length:] = 0 + + train_data = BatchedDataDict( + { + "input_ids": train_input_ids, + "input_lengths": generation_results["unpadded_sequence_lengths"], + "token_loss_mask": token_loss_mask, + } + ) + + # Step 3: Try a minimal training step with HF policy + print("Training with HF policy (single step)...") + hf_policy.prepare_for_training() + + # Just do one training step to verify it works + results = hf_policy.train(train_data, nll_loss) + print(f"Training loss: {results['loss']}") + + hf_policy.finish_training() + + # Step 4: Use vLLM for generation again to complete the workflow + print("Using vLLM for generation again...") + vllm_policy.prepare_for_generation() + final_generation = vllm_policy.generate(test_input_data) + assert "output_ids" in final_generation, ( + "Final generation should contain output_ids" + ) + + print("Successfully demonstrated vLLM generation + HF training workflow!") + + finally: + # Clean up resources + print("Cleaning up resources...") + if vllm_policy: + vllm_policy.shutdown() + if hf_policy and hasattr(hf_policy, "shutdown"): + hf_policy.shutdown() + + +def test_vllm_policy_tensor_parallel(cluster, tokenizer): + """Test vLLM policy with tensor parallelism > 1.""" + # Skip if less than 2 GPUs are available + if torch.cuda.device_count() < 2: + pytest.skip("Tensor parallelism test requires at least 2 GPUs") + + # Configure with tensor_parallel_size=2 + tp_config = basic_vllm_test_config.copy() + tp_config["tensor_parallel_size"] = 2 + + # Ensure we specify the distributed executor backend + tp_config["vllm_kwargs"] = {"distributed_executor_backend": "ray"} + + vllm_policy = None + try: + vllm_policy = VllmGeneration(cluster, tp_config) + + # Create simple test input + test_prompts = ["Hello, my name is", "The capital of France is"] + encodings = tokenizer( + test_prompts, + padding="max_length", + max_length=10, + truncation=True, + return_tensors="pt", + padding_side="right", + ) + + test_input_data = BatchedDataDict( + { + "input_ids": encodings["input_ids"], + "input_lengths": encodings["attention_mask"].sum(dim=1).to(torch.int32), + } + ) + + # Test generation with tensor parallelism + outputs = vllm_policy.generate(test_input_data) + + vllm_policy.finish_generation() + vllm_policy.prepare_for_generation() + # Validate outputs + # Test generation with tensor parallelism + outputs = vllm_policy.generate(test_input_data) + + assert "output_ids" in outputs, "output_ids not found in generation output" + assert outputs["output_ids"].shape[0] == 2, "Wrong batch size in output" + + # Decode and check output + generated_text = tokenizer.decode( + outputs["output_ids"][0], skip_special_tokens=True + ) + print(f"Generated text with TP=2: {generated_text}") + assert len(generated_text) > 0, "Generated text is empty" + + finally: + # Clean up resources + if vllm_policy: + vllm_policy.shutdown() + + +@pytest.mark.timeout(60) +@pytest.mark.parametrize("tensor_parallel_size", [1, 2]) +def test_vllm_policy_weight_update(cluster, tokenizer, tensor_parallel_size): + """Test that weights can be updated from HF to vLLM policy.""" + # Skip if requesting tensor_parallel_size=2 but less than 2 GPUs are available + if tensor_parallel_size > 1 and torch.cuda.device_count() < 2: + pytest.skip( + f"Tensor parallelism test with tp={tensor_parallel_size} requires at least {tensor_parallel_size} GPUs" + ) + + # Create HF policy + from nemo_reinforcer.models.policy.hf_policy import HfPolicy + + # Create separate configs for each policy + vllm_config = basic_vllm_test_config.copy() + vllm_config["tensor_parallel_size"] = tensor_parallel_size + + # Add vllm_kwargs only if using tensor parallelism + if tensor_parallel_size > 1: + vllm_config["vllm_kwargs"] = {"distributed_executor_backend": "ray"} + + # Create HF-specific config with required parameters + hf_config = { + "model_name": basic_vllm_test_config["model_name"], + # Required training parameters + "train_global_batch_size": 4, + "train_micro_batch_size": 1, + "learning_rate": 5e-6, + "logprob_batch_size": 1, + "max_new_tokens": 16, + "do_sample": False, + } + + hf_policy = HfPolicy(cluster, hf_config) + print(f"hf_policy created: {hf_policy}", flush=True) + # hf_policy.finish_training() + vllm_policy = VllmGeneration(cluster, vllm_config) + print( + f"vllm_policy created with tensor_parallel_size={tensor_parallel_size}: {vllm_policy}", + flush=True, + ) + + # Test generation with tensor parallelism + vllm_policy.finish_generation() + # hf_policy.prepare_for_training() + + # Zero out the weights in the HF model via workers + ray.get( + [worker.zero_out_weights.remote() for worker in hf_policy.worker_group.workers] + ) + print("Zeroed out weights in HF policy") + # Get device IDs + training_device_id = ray.get( + hf_policy.worker_group.workers[0].report_device_id.remote() + ) + worker_device_id = ray.get( + vllm_policy.worker_group.workers[0].report_device_id.remote() + ) + + # Ensure they are on the same device + assert training_device_id == worker_device_id, ( + "Training actor and worker should be on the same device" + ) + + # Use our new utility methods for weight update + # Get IPC handles from the HF policy + ipc_handles = hf_policy.get_weights_ipc_handles() + print("Got IPC handles from HF policy") + vllm_policy.prepare_for_generation() + # Update weights in the VllmGeneration + assert vllm_policy.update_weights(ipc_handles), "Weight update should succeed" + + # Check if weights have been updated + assert vllm_policy._check_all_weights_changed(), "Weights should be updated to zero" + + # Clean up + vllm_policy.shutdown() diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py new file mode 100644 index 0000000000..0f51fbf792 --- /dev/null +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -0,0 +1,514 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import ray +import pytest +import pprint +import torch + +from nemo_reinforcer.models.policy import PolicyConfig +from nemo_reinforcer.models.policy.hf_policy import HfPolicy +from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.algorithms.interfaces import LossFunction +from tests.unit.test_utils import simple_loss, nll_loss +from transformers import AutoTokenizer + + +basic_llama_test_config: PolicyConfig = { + "model_name": "meta-llama/Llama-3.2-1B", + "generation_batch_size": 1, # Small batch size for testing + "train_global_batch_size": 4, + "train_micro_batch_size": 1, + "learning_rate": 5e-6, + "logprob_batch_size": 1, + "generation": { + "backend": "hf", + "temperature": 1.0, + "max_new_tokens": 16, # Small number of tokens for testing + "top_p": 1.0, + "top_k": None, + }, + "scheduler": { + "name": "torch.optim.lr_scheduler.CosineAnnealingLR", + "kwargs": { + "T_max": 100, + }, + }, +} + + +@pytest.fixture(scope="function") +def gc_collect(): + """Helper function to force garbage collection after a test""" + import gc + + gc.collect() + + +@pytest.fixture +def policy_setup(): + """Setup and teardown for policy tests - creates a virtual cluster and policy.""" + policy = None + cluster = None + + cluster_name = "test" + print(f"Creating virtual cluster '{cluster_name}'...") + + cluster = RayVirtualCluster( + name=cluster_name, + bundle_ct_per_node_list=[2], # Single node, 2 gpus + use_gpus=True, + num_gpus_per_node=2, # Using both GPUs + max_colocated_worker_groups=1, # Only one worker group + ) + + config = basic_llama_test_config + + print("Creating HfPolicy...") + policy = HfPolicy(cluster=cluster, config=config) + + yield policy, cluster + + # Clean up after the test + print("Cleaning up resources for test") + cluster.shutdown() + policy.worker_group.shutdown() + + +@pytest.mark.timeout(180) +def test_hf_policy_init(policy_setup): + policy, cluster = policy_setup + + # Verify cluster and policy were properly created + assert policy is not None, "Policy was not created properly" + assert cluster is not None, "Cluster was not created properly" + + # Verify we have two workers, one per GPU + assert len(policy.worker_group.workers) == 2, "Should have 2 workers, one per GPU" + + # Check workers are alive + worker_alive = ray.get([w.is_alive.remote() for w in policy.worker_group.workers]) + assert all(worker_alive), f"Not all workers are alive: {worker_alive}" + + # Get GPU info from both workers to verify GPU usage + print("\nGetting GPU information from workers...") + gpu_infos = ray.get([w.get_gpu_info.remote() for w in policy.worker_group.workers]) + print("\nGPU Information:") + for i, info in enumerate(gpu_infos): + print(f"\nWorker {i} GPU Info:") + pprint.pprint(info) + + # Check 1: Verify workers have different ranks + gpu_ranks = [info["rank"] for info in gpu_infos] + assert len(set(gpu_ranks)) == 2, f"Expected 2 different ranks, got {gpu_ranks}" + assert set(gpu_ranks) == {0, 1}, f"Expected ranks 0 and 1, got {gpu_ranks}" + + # Check 2: Verify workers have different local_ranks + local_ranks = [info["local_rank"] for info in gpu_infos] + assert len(set(local_ranks)) == 2, ( + f"Expected 2 different local_ranks, got {local_ranks}" + ) + assert set(local_ranks) == {0, 1}, ( + f"Expected local_ranks 0 and 1, got {local_ranks}" + ) + + # Check 3: Verify workers have different CUDA_VISIBLE_DEVICES + cuda_visible_devices = [ + info["env_vars"].get("CUDA_VISIBLE_DEVICES") for info in gpu_infos + ] + assert len(set(cuda_visible_devices)) == 2, ( + f"Expected different CUDA_VISIBLE_DEVICES, got {cuda_visible_devices}" + ) + + # Check 4: Verify all workers report correct world_size + for info in gpu_infos: + assert info["world_size"] == 2, ( + f"Expected world_size=2, got {info['world_size']}" + ) + assert info["env_vars"]["WORLD_SIZE"] == "2", ( + f"Expected WORLD_SIZE=2, got {info['env_vars']['WORLD_SIZE']}" + ) + + # Check 5: Verify significant GPU memory is allocated (at least 1GB) on both GPUs + for info in gpu_infos: + assert info["memory_allocated_mb"] > 1000, ( + f"Not enough memory allocated on GPU for rank {info['rank']}: {info['memory_allocated_mb']:.2f} MB" + ) + + # Check 6: Verify model parameters are on CUDA devices for both workers + for info in gpu_infos: + param_sample = list(info["parameter_sample"].values())[0] + assert "cuda" in param_sample["device"], ( + f"Parameter not on CUDA device: {param_sample['device']}" + ) + + # Check 8: Verify same model parameters are being tracked across workers + param_names = [list(info["parameter_sample"].keys())[0] for info in gpu_infos] + assert len(set(param_names)) == 1, ( + f"Workers are not tracking the same parameter: {param_names}" + ) + + # Check 9: Both workers should see their device as cuda:0 (correct distributed behavior) + for info in gpu_infos: + param_device = list(info["parameter_sample"].values())[0]["device"] + assert param_device == "cuda:0", ( + f"Expected parameter device to be cuda:0, got {param_device}" + ) + + +@pytest.fixture +def training_setup(): + """Setup and teardown specifically for training tests.""" + policy = None + cluster = None + data = None + loss_fn = None + + try: + # Create resources with unique name + cluster_name = "test-train" + print(f"Creating training virtual cluster '{cluster_name}'...") + + cluster = RayVirtualCluster( + name=cluster_name, + bundle_ct_per_node_list=[2], # Single node, 2 gpus + use_gpus=True, + num_gpus_per_node=2, # Using both GPUs + max_colocated_worker_groups=1, # Only one worker group + ) + + config = basic_llama_test_config + + print("Creating training HfPolicy...") + policy = HfPolicy(cluster=cluster, config=config) + + # Create a test batch + print("Creating test batch...") + # set random seed + torch.manual_seed(42) + + # Create test input_ids and attention_mask + input_ids = torch.randint(0, 32000, (8, 128)) # 8 sequences, each of length 128 + attention_mask = torch.ones(8, 128) + + # Calculate input_lengths (all sequences are full length in this test) + input_lengths = attention_mask.sum(dim=1).to(torch.int32) + + data = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "attention_mask": attention_mask, # Keep for compatibility with loss functions + "labels": torch.randint(0, 32000, (8, 128)), + } + ) + + # Create loss function + loss_fn: LossFunction = simple_loss + + # Provide the resources to the test + yield policy, cluster, data, loss_fn + + except Exception as e: + print(f"Error during training setup: {e}") + pytest.skip(f"Training setup failed: {e}") + finally: + # Clean up after the test + print("Cleaning up resources for test") + cluster.shutdown() + policy.worker_group.shutdown() + + +@pytest.mark.timeout(180) +def test_hf_policy_training(training_setup): + def verify_loss_tensor(loss_tensor): + assert not torch.isnan(loss_tensor).any(), "Loss should not be NaN" + assert not torch.isinf(loss_tensor).any(), "Loss should not be Inf" + return loss_tensor + + policy, cluster, data, loss_fn = training_setup + + # Verify resources were created properly + assert policy is not None, "Training policy was not created properly" + assert cluster is not None, "Training cluster was not created properly" + assert data is not None, "Test data was not created properly" + assert loss_fn is not None, "Loss function was not created properly" + + # Call prepare_for_training if available + print("\nPreparing for training...") + policy.prepare_for_training() + + losses = [] + for steps in range(4): + results = policy.train(data, loss_fn) + + # Verify results + assert "loss" in results, "Training results should contain 'loss'" + loss_tensor = results["loss"] + verify_loss_tensor(loss_tensor) + losses.append(loss_tensor[-1].item()) + + print(f"Training loss: {results['loss']}") + + policy.finish_training() + + # Verify loss changed between iterations (model parameters were updated) + assert losses[0] > losses[-1], "Loss should decrease over training iterations" + + +@pytest.fixture +def generation_setup(): + """Setup and teardown specifically for generation tests.""" + policy = None + cluster = None + data = None + + try: + # Create resources with unique name + cluster_name = "test-generate" + print(f"Creating generation virtual cluster '{cluster_name}'...") + + cluster = RayVirtualCluster( + name=cluster_name, + bundle_ct_per_node_list=[2], # Single node, 2 gpus + use_gpus=True, + num_gpus_per_node=2, # Using both GPUs + max_colocated_worker_groups=1, # Only one worker group + ) + + config = basic_llama_test_config + + print("Creating generation HfPolicy...") + policy = HfPolicy(cluster=cluster, config=config) + + # Create a test batch + print("Creating test batch...") + torch.manual_seed(42) # For reproducibility + + prompts = [ + "Write a story about a magical forest", + "Explain how photosynthesis works", + "What are the benefits of exercise?", + "Describe the water cycle", + "What is the capital of France?", + "Who is the president of the USA?", + "What is the capital of the moon?", + "Where is the sun?", + ] + + expected_generations = [ + "Write a story about a magical forest. The forest is magical because it is full of magical creatures. The creatures are", + "Explain how photosynthesis works\nExplain how photosynthesis works\nPhotosynthesis is the process by which plants", + "What are the benefits of exercise? The benefits of exercise are many and varied. It is a great way to improve", + "Describe the water cycle in your own words.\nDescribe the water cycle in your own words.\nDescribe the", + "What is the capital of France? A. Paris B. New York C. Washington D. Baton Rouge\nA", + "Who is the president of the USA? Who is the president of the USA? Who is the president of the USA?", + "What is the capital of the moon? A. Houston B. New York C. Washington D. Denver\nA.", + "Where is the sun? Where is the moon? Where is the earth? Where is the sky? Where", + ] + + # Tokenize the prompts + tokenizer = AutoTokenizer.from_pretrained(config["model_name"]) + tokenizer.pad_token = tokenizer.eos_token + tokenized = tokenizer( + prompts, + padding=True, + truncation=True, + max_length=64, + return_tensors="pt", + padding_side="right", + ) + + # Calculate input lengths from attention mask + input_lengths = tokenized["attention_mask"].sum(dim=1).to(torch.int32) + + data = BatchedDataDict( + { + "input_ids": tokenized["input_ids"], + "input_lengths": input_lengths, + } + ) + + # Provide the resources to the test + yield policy, cluster, data, tokenizer, prompts, expected_generations + + except Exception as e: + print(f"Error during generation setup: {e}") + pytest.skip(f"Generation setup failed: {e}") + finally: + # Clean up after the test + print("Cleaning up resources for test") + cluster.shutdown() + policy.worker_group.shutdown() + + +@pytest.mark.timeout(180) +def test_hf_policy_generation(generation_setup): + policy, cluster, data, tokenizer, prompts, expected_generations = generation_setup + + # Verify resources were created properly + assert policy is not None, "Generation policy was not created properly" + assert cluster is not None, "Generation cluster was not created properly" + assert data is not None, "Test data was not created properly" + + # Call prepare_for_generation if available + print("Preparing for generation...") + policy.prepare_for_generation() + + # Generate text + print("Generating text...") + results = policy.generate(data, greedy=True) + + # Verify results + assert "output_ids" in results, "Generation results should contain 'output_ids'" + output_ids = results["output_ids"] + + # run logprob calculation manually to verify + fprop_logprob_data = BatchedDataDict( + { + "input_ids": results.get("output_ids"), + "input_lengths": results.get("unpadded_sequence_lengths"), + } + ) + fprop_results = policy.get_logprobs(fprop_logprob_data) + for i, length in enumerate(data["input_lengths"]): + fprop_results["logprobs"][i, :length] = 0 + + for i, valid_seq_len in enumerate(results["unpadded_sequence_lengths"]): + fprop_results["logprobs"][i, valid_seq_len:] = 0 + + # Basic validation of output shape and content + assert isinstance(output_ids, torch.Tensor), "Output should be a tensor" + assert output_ids.dim() == 2, ( + "Output should be 2-dimensional [batch_size, seq_length]" + ) + assert output_ids.size(0) == data.get("input_ids").size(0), ( + "Output batch size should match input" + ) + assert output_ids.size(1) > data.get("input_ids").size(1), ( + "Output should be longer than input" + ) + + # validate that the logprobs are correct + avg_prob_mult_error = torch.mean( + torch.exp(torch.abs(results["logprobs"] - fprop_results["logprobs"])) + ) + print(f"avg prob mult error: {avg_prob_mult_error}") + assert avg_prob_mult_error <= 1.025 + + # get logprobs for the expected generations + expected_tokenized = tokenizer( + expected_generations, + padding=True, + truncation=True, + max_length=64, + return_tensors="pt", + padding_side="right", + ) + + # Calculate input_lengths for expected generations + expected_lengths = expected_tokenized["attention_mask"].sum(dim=1).to(torch.int32) + + expected_data = BatchedDataDict( + { + "input_ids": expected_tokenized["input_ids"], + "input_lengths": expected_lengths, + } + ) + + expected_logprobs = policy.get_logprobs(expected_data)["logprobs"] + mean_lps = torch.mean(expected_logprobs * expected_tokenized["attention_mask"]) + assert mean_lps > -1.7, "Expected logprobs should be greater than -1.7" + assert mean_lps < -1.4, "Expected logprobs should be less than -1.4" + + # Call finish_generation if available + print("Finishing generation...") + policy.finish_generation() + + +@pytest.mark.timeout(180) +def test_all_hf_policy_generation_lps_ref_training(generation_setup): + policy, cluster, data, tokenizer, prompts, expected_generations = generation_setup + + # Verify resources were created properly + assert policy is not None, "Generation policy was not created properly" + assert cluster is not None, "Generation cluster was not created properly" + assert data is not None, "Test data was not created properly" + + # Create reference data by generating with the model + print("creating some data") + ref_results = policy.generate(data, greedy=True) + + # Create training data with reference outputs + token_loss_mask = torch.ones_like(ref_results["output_ids"]) + token_loss_mask[:, : data.get("input_ids").size(1)] = 0 + + for idx, length in enumerate(ref_results["unpadded_sequence_lengths"]): + token_loss_mask[idx, length:] = 0 + + train_data = BatchedDataDict( + { + "input_ids": ref_results["output_ids"], + "input_lengths": ref_results["unpadded_sequence_lengths"], + "token_loss_mask": token_loss_mask, + } + ) + + fprop_logprobs = policy.get_logprobs(train_data)["logprobs"] + + loss_fn: LossFunction = nll_loss + + # Train for a few steps + policy.prepare_for_training() + losses = [] + for step in range(8): + results = policy.train(train_data, loss_fn) + + # Verify results + assert "loss" in results, "Training results should contain 'loss'" + loss_tensor = results["loss"] + assert not torch.isnan(loss_tensor).any(), "Loss should not be NaN" + assert not torch.isinf(loss_tensor).any(), "Loss should not be Inf" + losses.append(loss_tensor[-1].item()) + + print(f"Training loss at step {step}: {results['loss']}") + + policy.finish_training() + + post_train_reference_logprobs = policy.get_reference_policy_logprobs(train_data)[ + "reference_logprobs" + ] + post_train_fprop_logprobs = policy.get_logprobs(train_data)["logprobs"] + + # Verify that the reference policy logprobs match the original policy logprobs + assert torch.allclose(fprop_logprobs, post_train_reference_logprobs), ( + "Logprobs from policy before training and reference policy after training should match" + ) + + # Calculate NLL before and after training + pre_train_nll = -torch.sum(fprop_logprobs * token_loss_mask, dim=-1) + post_train_nll = -torch.sum(post_train_fprop_logprobs * token_loss_mask, dim=-1) + print(f"Pre-training NLL: {pre_train_nll.mean().item()}") + print(f"Post-training NLL: {post_train_nll.mean().item()}") + + # Verify that training improved the model's predictions on every sample + assert torch.all(post_train_nll < pre_train_nll), ( + "Model should improve at predicting its own generations after training" + ) + assert torch.all(post_train_nll < 5), ( + "Model should improve at predicting its own generations after training" + ) + + # Verify loss decreased during training + assert losses[0] > losses[-1], "Loss should decrease over training iterations" diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py new file mode 100644 index 0000000000..2773fd20f2 --- /dev/null +++ b/tests/unit/test_utils.py @@ -0,0 +1,49 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Tuple +import torch + +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict + + +def simple_loss( + next_token_logits: torch.Tensor, data: BatchedDataDict +) -> Tuple[torch.Tensor, Dict[str, Any]]: + # Just return mean of logprobs as the loss for testing + loss = next_token_logits.mean() + metrics = {"test_metric": loss.item() * 0.5} + return loss, metrics + + +# Create a simple masked NLL loss function +def nll_loss( + next_token_logits: torch.Tensor, data: BatchedDataDict +) -> Tuple[torch.Tensor, Dict[str, Any]]: + # logits shape: [batch_size, seq_len, vocab_size] + # Get the next token logits for each position + next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token + next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) + logprobs = next_token_logprobs[:, :-1] # Remove last position's logits + + # Gather the logprobs for the actual next tokens + token_logprobs = logprobs.gather(dim=-1, index=next_tokens.unsqueeze(-1)).squeeze( + -1 + ) + + # Only compute loss on generated tokens (not input tokens) + # by applying the token_loss_mask (shifted by 1 since we're predicting next tokens) + token_loss_mask = data.get("token_loss_mask")[:, 1:].cuda() + loss = -torch.sum(token_logprobs * token_loss_mask) + + return loss, {"loss": loss.item()} diff --git a/tests/unit/utils/test_checkpoint.py b/tests/unit/utils/test_checkpoint.py new file mode 100644 index 0000000000..fe8f2aac67 --- /dev/null +++ b/tests/unit/utils/test_checkpoint.py @@ -0,0 +1,249 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import pytest +import torch +import numpy as np +from pathlib import Path +from nemo_reinforcer.utils.checkpoint import CheckpointManager + + +@pytest.fixture +def checkpoint_dir(tmp_path): + return tmp_path.resolve() / "checkpoints" + + +@pytest.fixture +def checkpoint_config(checkpoint_dir): + return { + "enabled": True, + "checkpoint_dir": checkpoint_dir, + "metric_name": "loss", + "higher_is_better": False, + "keep_top_k": 3, + } + + +@pytest.fixture +def checkpoint_manager(checkpoint_config): + return CheckpointManager(checkpoint_config) + + +def test_init_tmp_checkpoint(checkpoint_manager, checkpoint_dir): + # Test creating a new checkpoint + step = 1 + training_info = {"loss": 0.5, "tensor": torch.tensor(0.5), "numpy": np.array(0.5)} + run_config = {"model": "test"} + + save_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info, run_config) + + # Check if directory was created + assert save_dir.exists() + assert save_dir.name.startswith("tmp_step_") + + # Check if training metadata was saved correctly + with open(save_dir / "training_info.json", "r") as f: + saved_metadata = json.load(f) + assert saved_metadata["loss"] == 0.5 + assert isinstance(saved_metadata["tensor"], (int, float)) + assert isinstance(saved_metadata["numpy"], (int, float)) + + # Check if config was saved + with open(save_dir / "config.json", "r") as f: + saved_config = json.load(f) + assert saved_config == run_config + + +def test_finalize_checkpoint(checkpoint_manager, checkpoint_dir): + # Create a temporary checkpoint + step = 1 + training_info = {"loss": 0.5} + tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info) + + # Complete the checkpoint + checkpoint_manager.finalize_checkpoint(tmp_dir) + + # Check if temporary directory was renamed correctly + assert not tmp_dir.exists() + assert (checkpoint_dir / f"step_{step}").exists() + + +def test_remove_old_checkpoints(checkpoint_manager, checkpoint_dir): + # Create multiple checkpoints with different loss values + steps = [1, 2, 3, 4, 5, 6] + losses = [0.5, 0.3, 0.7, 0.2, 0.4, 0.8] + + for step, loss in zip(steps, losses): + training_info = {"loss": loss} + tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info) + checkpoint_manager.finalize_checkpoint(tmp_dir) + + # Check if only top-k checkpoints are kept + remaining_dirs = list(checkpoint_dir.glob("step_*")) + assert ( + len(remaining_dirs) == checkpoint_manager.keep_top_k + 1 + ) # +1 because we exclude the latest + + # Verify the remaining checkpoints are the ones with lowest loss + remaining_losses = [] + for dir_path in remaining_dirs: + with open(dir_path / "training_info.json", "r") as f: + metadata = json.load(f) + remaining_losses.append(metadata["loss"]) + + assert sorted(remaining_losses) == sorted(losses)[ + : checkpoint_manager.keep_top_k + ] + [0.8] # exclude latest + + +def test_remove_old_checkpoints_topk_bias_recent_if_equal( + checkpoint_manager, checkpoint_dir +): + # Create multiple checkpoints with the same loss value + # Create multiple checkpoints with the same loss value + steps = [1, 2, 3, 4, 10, 12] + losses = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5] # All checkpoints have the same loss + + for step, loss in zip(steps, losses): + training_info = {"loss": loss} + tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info) + checkpoint_manager.finalize_checkpoint(tmp_dir) + + # Check if only top-k checkpoints are kept + remaining_dirs = list(checkpoint_dir.glob("step_*")) + assert ( + len(remaining_dirs) == checkpoint_manager.keep_top_k + ) # +1 because we exclude the latest + + # When all losses are equal, the most recent checkpoints should be kept + # (excluding the latest which is always kept) + remaining_steps = [] + for dir_path in remaining_dirs: + step_num = int(dir_path.name.split("_")[1]) + remaining_steps.append(step_num) + + # Should keep the most recent checkpoints (highest step numbers) + expected_steps = sorted(steps)[-checkpoint_manager.keep_top_k :] + assert sorted(remaining_steps) == sorted(expected_steps) + + +def test_get_best_checkpoint_path(checkpoint_manager, checkpoint_dir): + # Create multiple checkpoints with different loss values + steps = [1, 2, 3] + losses = [0.5, 0.3, 0.7] + + for step, loss in zip(steps, losses): + training_info = {"loss": loss} + tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info) + checkpoint_manager.finalize_checkpoint(tmp_dir) + + # Get best checkpoint path + best_path = checkpoint_manager.get_best_checkpoint_path() + + # Verify it's the checkpoint with lowest loss + with open(Path(best_path) / "training_info.json", "r") as f: + metadata = json.load(f) + assert metadata["loss"] == min(losses) + + +def test_get_latest_checkpoint_path(checkpoint_manager, checkpoint_dir): + # Create multiple checkpoints + steps = [1, 2, 3] + + for step in steps: + training_info = {"loss": 0.5} + tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info) + checkpoint_manager.finalize_checkpoint(tmp_dir) + + # Get latest checkpoint path + latest_path = checkpoint_manager.get_latest_checkpoint_path() + + # Verify it's the checkpoint with highest step number + assert Path(latest_path).name == f"step_{max(steps)}" + + +def test_load_training_metadata(checkpoint_manager, checkpoint_dir): + # Create a checkpoint + step = 1 + training_info = {"loss": 0.5} + tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info) + checkpoint_manager.finalize_checkpoint(tmp_dir) + + # Load training metadata + metadata = checkpoint_manager.load_training_info(checkpoint_dir / f"step_{step}") + + # Verify metadata was loaded correctly + assert metadata == training_info + + +def test_checkpoint_without_keep_top_k(tmp_path): + # Test checkpoint manager without keep_top_k + config = { + "enabled": True, + "checkpoint_dir": str((tmp_path.resolve() / "checkpoints")), + "metric_name": "loss", + "higher_is_better": False, + "keep_top_k": None, + } + manager = CheckpointManager(config) + + # Create multiple checkpoints + steps = [1, 2, 3] + for step in steps: + training_info = {"loss": 0.5} + tmp_dir = manager.init_tmp_checkpoint(step, training_info) + manager.finalize_checkpoint(tmp_dir) + + # Verify all checkpoints are kept + remaining_dirs = list(Path(tmp_path.resolve() / "checkpoints").glob("step_*")) + assert len(remaining_dirs) == len(steps) + + +def test_load_checkpoint_empty_dir(checkpoint_manager, checkpoint_dir): + """Test that loading from an empty checkpoint directory returns None.""" + # Get latest checkpoint path from empty directory + latest_path = checkpoint_manager.get_latest_checkpoint_path() + assert latest_path is None + + # Load training metadata from None path + metadata = checkpoint_manager.load_training_info(None) + assert metadata is None + + +def test_get_latest_checkpoint_path_across_digits(checkpoint_manager, checkpoint_dir): + """Test that getting latest checkpoint works correctly when crossing digit boundaries. + This ensures we're doing numerical comparison rather than string comparison, + as string comparison would incorrectly order step_9 > step_10. + """ + # Create checkpoints with steps that cross digit boundary + steps = [8, 9, 10, 11] + + for step in steps: + training_info = {"loss": 0.5} + tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info) + checkpoint_manager.finalize_checkpoint(tmp_dir) + + # Get latest checkpoint path + latest_path = checkpoint_manager.get_latest_checkpoint_path() + + # Verify it's the checkpoint with highest numerical step (11) + assert Path(latest_path).name == f"step_{max(steps)}" + + # Double check that all checkpoints exist and are properly ordered + all_checkpoints = sorted( + [d for d in Path(checkpoint_dir).glob("step_*")], + key=lambda x: int(x.name.split("_")[1]), + ) + assert len(all_checkpoints) == checkpoint_manager.keep_top_k + assert all_checkpoints[-1].name == f"step_{max(steps)}" diff --git a/tests/unit/utils/test_config.py b/tests/unit/utils/test_config.py new file mode 100644 index 0000000000..245d9e0053 --- /dev/null +++ b/tests/unit/utils/test_config.py @@ -0,0 +1,198 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +from pathlib import Path + +import pytest + +from nemo_reinforcer.utils.config import load_config + + +@pytest.fixture +def temp_config_dir(): + """Create a temporary directory for test configs.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +def create_test_config(config_dir: Path, name: str, content: str): + """Create a test config file.""" + config_path = config_dir / name + config_path.write_text(content) + return config_path + + +def test_single_inheritance(temp_config_dir): + """Test basic inheritance from a single parent.""" + # Create parent config + parent_content = """ + common: + value: 42 + parent_only: + value: 100 + """ + create_test_config(temp_config_dir, "parent.yaml", parent_content) + + # Create child config + child_content = """ + defaults: parent.yaml + common: + value: 43 + child_only: + value: 200 + """ + child_path = create_test_config(temp_config_dir, "child.yaml", child_content) + + # Load and verify + config = load_config(child_path) + assert config.common.value == 43 # Child overrides parent + assert config.parent_only.value == 100 # Parent value preserved + assert config.child_only.value == 200 # Child-only value exists + + +def test_multiple_inheritance(temp_config_dir): + """Test inheritance from multiple parents.""" + # Create first parent + parent1_content = """ + common: + value: 42 + parent1_only: + value: 100 + """ + create_test_config(temp_config_dir, "parent1.yaml", parent1_content) + + # Create second parent + parent2_content = """ + common: + value: 43 + parent2_only: + value: 200 + """ + create_test_config(temp_config_dir, "parent2.yaml", parent2_content) + + # Create child config + child_content = """ + defaults: + - parent1.yaml + - parent2.yaml + common: + value: 44 + child_only: + value: 300 + """ + child_path = create_test_config(temp_config_dir, "child.yaml", child_content) + + # Load and verify + config = load_config(child_path) + assert config.common.value == 44 # Child overrides both parents + assert config.parent1_only.value == 100 # First parent value preserved + assert config.parent2_only.value == 200 # Second parent value preserved + assert config.child_only.value == 300 # Child-only value exists + + +def test_absolute_path_inheritance(temp_config_dir): + """Test inheritance using absolute paths.""" + # Create parent config + parent_content = """ + common: + value: 42 + """ + parent_path = create_test_config(temp_config_dir, "parent.yaml", parent_content) + + # Create child config with absolute path + child_content = f""" + defaults: {parent_path} + common: + value: 43 + """ + child_path = create_test_config(temp_config_dir, "child.yaml", child_content) + + # Load and verify + config = load_config(child_path) + assert config.common.value == 43 # Child overrides parent + + +def test_no_inheritance(temp_config_dir): + """Test config without inheritance.""" + content = """ + common: + value: 42 + """ + config_path = create_test_config(temp_config_dir, "config.yaml", content) + + # Load and verify + config = load_config(config_path) + assert config.common.value == 42 + + +def test_nested_inheritance(temp_config_dir): + """Test nested inheritance (parent inherits from grandparent).""" + # Create grandparent config + grandparent_content = """ + common: + value: 42 + grandparent_only: + value: 100 + """ + create_test_config(temp_config_dir, "grandparent.yaml", grandparent_content) + + # Create parent config + parent_content = """ + defaults: grandparent.yaml + common: + value: 43 + parent_only: + value: 200 + """ + create_test_config(temp_config_dir, "parent.yaml", parent_content) + + # Create child config + child_content = """ + defaults: parent.yaml + common: + value: 44 + child_only: + value: 300 + """ + child_path = create_test_config(temp_config_dir, "child.yaml", child_content) + + # Load and verify + config = load_config(child_path) + assert config.common.value == 44 # Child overrides all + assert config.grandparent_only.value == 100 # Grandparent value preserved + assert config.parent_only.value == 200 # Parent value preserved + assert config.child_only.value == 300 # Child-only value exists + + +def test_interpolation(temp_config_dir): + """Test that interpolation works with inherited configs.""" + # Create parent config + parent_content = """ + base_value: 42 + derived: + value: ${base_value} + """ + create_test_config(temp_config_dir, "parent.yaml", parent_content) + + # Create child config + child_content = """ + defaults: parent.yaml + base_value: 43 + """ + child_path = create_test_config(temp_config_dir, "child.yaml", child_content) + + # Load and verify + config = load_config(child_path) + assert config.base_value == 43 + assert config.derived.value == 43 # Interpolation uses child's base_value diff --git a/tests/unit/utils/test_logger.py b/tests/unit/utils/test_logger.py new file mode 100644 index 0000000000..d54c8748f5 --- /dev/null +++ b/tests/unit/utils/test_logger.py @@ -0,0 +1,332 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import tempfile +from unittest.mock import patch + +import pytest + +from nemo_reinforcer.utils.logger import ( + Logger, + TensorboardLogger, + WandbLogger, + flatten_dict, +) + + +class TestFlattenDict: + """Test the flatten_dict utility function.""" + + def test_empty_dict(self): + """Test flattening an empty dictionary.""" + assert flatten_dict({}) == {} + + def test_flat_dict(self): + """Test flattening a dictionary that is already flat.""" + d = {"a": 1, "b": 2, "c": 3} + assert flatten_dict(d) == d + + def test_nested_dict(self): + """Test flattening a nested dictionary.""" + d = {"a": 1, "b": {"c": 2, "d": 3}, "e": {"f": {"g": 4}}} + expected = {"a": 1, "b.c": 2, "b.d": 3, "e.f.g": 4} + assert flatten_dict(d) == expected + + def test_custom_separator(self): + """Test flattening with a custom separator.""" + d = {"a": 1, "b": {"c": 2, "d": 3}} + expected = {"a": 1, "b_c": 2, "b_d": 3} + assert flatten_dict(d, sep="_") == expected + + +class TestTensorboardLogger: + """Test the TensorboardLogger class.""" + + @pytest.fixture + def temp_dir(self): + """Create a temporary directory for logs.""" + temp_dir = tempfile.mkdtemp() + yield temp_dir + shutil.rmtree(temp_dir) + + @patch("nemo_reinforcer.utils.logger.SummaryWriter") + def test_init(self, mock_summary_writer, temp_dir): + """Test initialization of TensorboardLogger.""" + cfg = {"log_dir": temp_dir} + logger = TensorboardLogger(cfg, log_dir=temp_dir) + + # The log_dir is passed to SummaryWriter but not stored as an attribute + mock_summary_writer.assert_called_once_with(log_dir=temp_dir) + + @patch("nemo_reinforcer.utils.logger.SummaryWriter") + def test_log_metrics(self, mock_summary_writer, temp_dir): + """Test logging metrics to TensorboardLogger.""" + cfg = {"log_dir": temp_dir} + logger = TensorboardLogger(cfg, log_dir=temp_dir) + + metrics = {"loss": 0.5, "accuracy": 0.8} + step = 10 + logger.log_metrics(metrics, step) + + # Check that add_scalar was called for each metric + mock_writer = mock_summary_writer.return_value + assert mock_writer.add_scalar.call_count == 2 + mock_writer.add_scalar.assert_any_call("loss", 0.5, 10) + mock_writer.add_scalar.assert_any_call("accuracy", 0.8, 10) + + @patch("nemo_reinforcer.utils.logger.SummaryWriter") + def test_log_metrics_with_prefix(self, mock_summary_writer, temp_dir): + """Test logging metrics with a prefix to TensorboardLogger.""" + cfg = {"log_dir": temp_dir} + logger = TensorboardLogger(cfg, log_dir=temp_dir) + + metrics = {"loss": 0.5, "accuracy": 0.8} + step = 10 + prefix = "train" + logger.log_metrics(metrics, step, prefix) + + # Check that add_scalar was called for each metric with prefix + mock_writer = mock_summary_writer.return_value + assert mock_writer.add_scalar.call_count == 2 + mock_writer.add_scalar.assert_any_call("train/loss", 0.5, 10) + mock_writer.add_scalar.assert_any_call("train/accuracy", 0.8, 10) + + @patch("nemo_reinforcer.utils.logger.SummaryWriter") + def test_log_hyperparams(self, mock_summary_writer, temp_dir): + """Test logging hyperparameters to TensorboardLogger.""" + cfg = {"log_dir": temp_dir} + logger = TensorboardLogger(cfg, log_dir=temp_dir) + + params = {"lr": 0.001, "batch_size": 32, "model": {"hidden_size": 128}} + logger.log_hyperparams(params) + + # Check that add_hparams was called with flattened params + mock_writer = mock_summary_writer.return_value + mock_writer.add_hparams.assert_called_once() + # First argument should be flattened dict + called_params = mock_writer.add_hparams.call_args[0][0] + assert called_params == { + "lr": 0.001, + "batch_size": 32, + "model.hidden_size": 128, + } + + +class TestWandbLogger: + """Test the WandbLogger class.""" + + @pytest.fixture + def temp_dir(self): + """Create a temporary directory for logs.""" + temp_dir = tempfile.mkdtemp() + yield temp_dir + shutil.rmtree(temp_dir) + + @patch("nemo_reinforcer.utils.logger.wandb") + def test_init_custom_config(self, mock_wandb, temp_dir): + """Test initialization of WandbLogger with custom config.""" + cfg = { + "project": "custom-project", + "name": "custom-run", + "entity": "custom-entity", + "group": "custom-group", + "tags": ["tag1", "tag2"], + } + WandbLogger(cfg, log_dir=temp_dir) + + mock_wandb.init.assert_called_once_with( + project="custom-project", + name="custom-run", + entity="custom-entity", + group="custom-group", + tags=["tag1", "tag2"], + dir=temp_dir, + ) + + @patch("nemo_reinforcer.utils.logger.wandb") + def test_log_metrics(self, mock_wandb): + """Test logging metrics to WandbLogger.""" + cfg = {} + logger = WandbLogger(cfg) + + metrics = {"loss": 0.5, "accuracy": 0.8} + step = 10 + logger.log_metrics(metrics, step) + + # Check that log was called with metrics and step + mock_run = mock_wandb.init.return_value + mock_run.log.assert_called_once_with(metrics, step=step) + + @patch("nemo_reinforcer.utils.logger.wandb") + def test_log_metrics_with_prefix(self, mock_wandb): + """Test logging metrics with a prefix to WandbLogger.""" + cfg = {} + logger = WandbLogger(cfg) + + metrics = {"loss": 0.5, "accuracy": 0.8} + step = 10 + prefix = "train" + logger.log_metrics(metrics, step, prefix) + + # Check that log was called with prefixed metrics and step + mock_run = mock_wandb.init.return_value + expected_metrics = {"train/loss": 0.5, "train/accuracy": 0.8} + mock_run.log.assert_called_once_with(expected_metrics, step=step) + + @patch("nemo_reinforcer.utils.logger.wandb") + def test_log_hyperparams(self, mock_wandb): + """Test logging hyperparameters to WandbLogger.""" + cfg = {} + logger = WandbLogger(cfg) + + params = {"lr": 0.001, "batch_size": 32, "model": {"hidden_size": 128}} + logger.log_hyperparams(params) + + # Check that config.update was called with params + mock_run = mock_wandb.init.return_value + mock_run.config.update.assert_called_once_with(params) + + +class TestLogger: + """Test the main Logger class.""" + + @pytest.fixture + def temp_dir(self): + """Create a temporary directory for logs.""" + temp_dir = tempfile.mkdtemp() + yield temp_dir + shutil.rmtree(temp_dir) + + @patch("nemo_reinforcer.utils.logger.WandbLogger") + @patch("nemo_reinforcer.utils.logger.TensorboardLogger") + def test_init_no_loggers(self, mock_tb_logger, mock_wandb_logger, temp_dir): + """Test initialization with no loggers enabled.""" + cfg = { + "wandb_enabled": False, + "tensorboard_enabled": False, + "log_dir": temp_dir, + } + logger = Logger(cfg) + + assert len(logger.loggers) == 0 + mock_tb_logger.assert_not_called() + mock_wandb_logger.assert_not_called() + + @patch("nemo_reinforcer.utils.logger.WandbLogger") + @patch("nemo_reinforcer.utils.logger.TensorboardLogger") + def test_init_wandb_only(self, mock_tb_logger, mock_wandb_logger, temp_dir): + """Test initialization with only WandbLogger enabled.""" + cfg = { + "wandb_enabled": True, + "tensorboard_enabled": False, + "wandb": {"project": "test-project"}, + "log_dir": temp_dir, + } + logger = Logger(cfg) + + assert len(logger.loggers) == 1 + mock_wandb_logger.assert_called_once() + wandb_cfg = mock_wandb_logger.call_args[0][0] + assert wandb_cfg == {"project": "test-project"} + mock_tb_logger.assert_not_called() + + @patch("nemo_reinforcer.utils.logger.WandbLogger") + @patch("nemo_reinforcer.utils.logger.TensorboardLogger") + def test_init_tensorboard_only(self, mock_tb_logger, mock_wandb_logger, temp_dir): + """Test initialization with only TensorboardLogger enabled.""" + cfg = { + "wandb_enabled": False, + "tensorboard_enabled": True, + "tensorboard": {"log_dir": "test_logs"}, + "log_dir": temp_dir, + } + logger = Logger(cfg) + + assert len(logger.loggers) == 1 + mock_tb_logger.assert_called_once() + tb_cfg = mock_tb_logger.call_args[0][0] + assert tb_cfg == {"log_dir": "test_logs"} + mock_wandb_logger.assert_not_called() + + @patch("nemo_reinforcer.utils.logger.WandbLogger") + @patch("nemo_reinforcer.utils.logger.TensorboardLogger") + def test_init_both_loggers(self, mock_tb_logger, mock_wandb_logger, temp_dir): + """Test initialization with both loggers enabled.""" + cfg = { + "wandb_enabled": True, + "tensorboard_enabled": True, + "wandb": {"project": "test-project"}, + "tensorboard": {"log_dir": "test_logs"}, + "log_dir": temp_dir, + } + logger = Logger(cfg) + + assert len(logger.loggers) == 2 + mock_wandb_logger.assert_called_once() + wandb_cfg = mock_wandb_logger.call_args[0][0] + assert wandb_cfg == {"project": "test-project"} + + mock_tb_logger.assert_called_once() + tb_cfg = mock_tb_logger.call_args[0][0] + assert tb_cfg == {"log_dir": "test_logs"} + + @patch("nemo_reinforcer.utils.logger.WandbLogger") + @patch("nemo_reinforcer.utils.logger.TensorboardLogger") + def test_log_metrics(self, mock_tb_logger, mock_wandb_logger, temp_dir): + """Test logging metrics to all enabled loggers.""" + cfg = { + "wandb_enabled": True, + "tensorboard_enabled": True, + "wandb": {"project": "test-project"}, + "tensorboard": {"log_dir": "test_logs"}, + "log_dir": temp_dir, + } + logger = Logger(cfg) + + # Create mock logger instances + mock_wandb_instance = mock_wandb_logger.return_value + mock_tb_instance = mock_tb_logger.return_value + + metrics = {"loss": 0.5, "accuracy": 0.8} + step = 10 + logger.log_metrics(metrics, step) + + # Check that log_metrics was called on both loggers + mock_wandb_instance.log_metrics.assert_called_once_with(metrics, step, "") + mock_tb_instance.log_metrics.assert_called_once_with(metrics, step, "") + + @patch("nemo_reinforcer.utils.logger.WandbLogger") + @patch("nemo_reinforcer.utils.logger.TensorboardLogger") + def test_log_hyperparams(self, mock_tb_logger, mock_wandb_logger, temp_dir): + """Test logging hyperparameters to all enabled loggers.""" + cfg = { + "wandb_enabled": True, + "tensorboard_enabled": True, + "wandb": {"project": "test-project"}, + "tensorboard": {"log_dir": "test_logs"}, + "log_dir": temp_dir, + } + logger = Logger(cfg) + + # Create mock logger instances + mock_wandb_instance = mock_wandb_logger.return_value + mock_tb_instance = mock_tb_logger.return_value + + params = {"lr": 0.001, "batch_size": 32} + logger.log_hyperparams(params) + + # Check that log_hyperparams was called on both loggers + mock_wandb_instance.log_hyperparams.assert_called_once_with(params) + mock_tb_instance.log_hyperparams.assert_called_once_with(params) diff --git a/tests/unit/utils/test_timer.py b/tests/unit/utils/test_timer.py new file mode 100644 index 0000000000..28eebb5b42 --- /dev/null +++ b/tests/unit/utils/test_timer.py @@ -0,0 +1,189 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import pytest +import numpy as np +from unittest.mock import patch + +from nemo_reinforcer.utils.timer import Timer + + +class TestTimer: + @pytest.fixture + def timer(self): + return Timer() + + def test_start_stop(self, timer): + """Test basic start/stop functionality.""" + timer.start("test_label") + time.sleep(0.01) # Small sleep to ensure measurable time + elapsed = timer.stop("test_label") + + # Check that elapsed time is positive + assert elapsed > 0 + + # Check that the timer recorded the measurement + assert "test_label" in timer._timers + assert len(timer._timers["test_label"]) == 1 + + # Check that the start time was removed + assert "test_label" not in timer._start_times + + def test_start_already_running(self, timer): + """Test that starting an already running timer raises an error.""" + timer.start("test_label") + with pytest.raises(ValueError): + timer.start("test_label") + + def test_stop_not_running(self, timer): + """Test that stopping a timer that isn't running raises an error.""" + with pytest.raises(ValueError): + timer.stop("nonexistent_label") + + def test_context_manager(self, timer): + """Test the context manager functionality.""" + with timer.time("test_context"): + time.sleep(0.01) # Small sleep to ensure measurable time + + # Check that the timer recorded the measurement + assert "test_context" in timer._timers + assert len(timer._timers["test_context"]) == 1 + + def test_multiple_measurements(self, timer): + """Test recording multiple measurements for the same label.""" + for _ in range(3): + timer.start("multiple") + time.sleep(0.01) # Small sleep to ensure measurable time + timer.stop("multiple") + + # Check that all measurements were recorded + assert len(timer._timers["multiple"]) == 3 + + def test_get_elapsed(self, timer): + """Test retrieving elapsed times.""" + # Record some measurements + for _ in range(3): + timer.start("get_elapsed_test") + time.sleep(0.01) # Small sleep to ensure measurable time + timer.stop("get_elapsed_test") + + # Get the elapsed times + elapsed_times = timer.get_elapsed("get_elapsed_test") + + # Check that we got the right number of measurements + assert len(elapsed_times) == 3 + + # Check that all times are positive + for t in elapsed_times: + assert t > 0 + + def test_get_elapsed_nonexistent(self, timer): + """Test that getting elapsed times for a nonexistent label raises an error.""" + with pytest.raises(KeyError): + timer.get_elapsed("nonexistent_label") + + def test_reduce_mean(self, timer): + """Test the mean reduction.""" + # Create known measurements + timer._timers["reduction_test"] = [1.0, 2.0, 3.0] + + # Get the mean + mean = timer.reduce("reduction_test", "mean") + + # Check the result + assert mean == 2.0 + + def test_reduce_default(self, timer): + """Test that the default reduction is mean.""" + # Create known measurements + timer._timers["reduction_default"] = [1.0, 2.0, 3.0] + + # Get the reduction without specifying type + result = timer.reduce("reduction_default") + + # Check that it's the mean + assert result == 2.0 + + def test_reduce_all_types(self, timer): + """Test all reduction types.""" + # Create known measurements + timer._timers["all_reductions"] = [1.0, 2.0, 3.0, 4.0, 5.0] + + # Test each reduction type + assert timer.reduce("all_reductions", "mean") == 3.0 + assert timer.reduce("all_reductions", "median") == 3.0 + assert timer.reduce("all_reductions", "min") == 1.0 + assert timer.reduce("all_reductions", "max") == 5.0 + assert timer.reduce("all_reductions", "sum") == 15.0 + + # For std, just check it's a reasonable value (avoid floating point comparison issues) + std = timer.reduce("all_reductions", "std") + np_std = np.std([1.0, 2.0, 3.0, 4.0, 5.0]) + assert abs(std - np_std) < 1e-6 + + def test_reduce_invalid_type(self, timer): + """Test that an invalid reduction type raises an error.""" + timer._timers["invalid_reduction"] = [1.0, 2.0, 3.0] + + with pytest.raises(ValueError): + timer.reduce("invalid_reduction", "invalid_type") + + def test_reduce_nonexistent_label(self, timer): + """Test that getting a reduction for a nonexistent label raises an error.""" + with pytest.raises(KeyError): + timer.reduce("nonexistent_label") + + def test_reset_specific_label(self, timer): + """Test resetting a specific label.""" + # Create some measurements + timer._timers["reset_test1"] = [1.0, 2.0] + timer._timers["reset_test2"] = [3.0, 4.0] + + # Reset one label + timer.reset("reset_test1") + + # Check that only that label was reset + assert "reset_test1" not in timer._timers + assert "reset_test2" in timer._timers + + def test_reset_all(self, timer): + """Test resetting all labels.""" + # Create some measurements + timer._timers["reset_all1"] = [1.0, 2.0] + timer._timers["reset_all2"] = [3.0, 4.0] + + # Start a timer + timer.start("running_timer") + + # Reset all + timer.reset() + + # Check that everything was reset + assert len(timer._timers) == 0 + assert len(timer._start_times) == 0 + + @patch("time.perf_counter") + def test_precise_timing(self, mock_perf_counter, timer): + """Test that timing is accurate using mocked time.""" + # Set up mock time to return specific values + mock_perf_counter.side_effect = [10.0, 15.0] # Start time, stop time + + # Time something + timer.start("precise_test") + elapsed = timer.stop("precise_test") + + # Check the elapsed time + assert elapsed == 5.0 + assert timer._timers["precise_test"][0] == 5.0 diff --git a/tools/autoformat.sh b/tools/autoformat.sh new file mode 100644 index 0000000000..c18d31e4b0 --- /dev/null +++ b/tools/autoformat.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +set -euo pipefail + +GIT_VERSION=$(git version | awk '{print $3}') +GIT_MAJOR=$(echo $GIT_VERSION | awk -F. '{print $1}') +GIT_MINOR=$(echo $GIT_VERSION | awk -F. '{print $2}') + +if [[ $GIT_MAJOR -eq 2 && $GIT_MINOR -lt 31 ]]; then + echo "Git version must be at least 2.31.0. Found $GIT_VERSION" + exit 1 +fi + +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd) +PACKAGE_ROOT=$(realpath $SCRIPT_DIR/../nemo_reinforcer) + +ruff check $PACKAGE_ROOT --fix +ruff format $PACKAGE_ROOT \ No newline at end of file diff --git a/tools/copyright.sh b/tools/copyright.sh new file mode 100644 index 0000000000..c08f410b84 --- /dev/null +++ b/tools/copyright.sh @@ -0,0 +1,65 @@ +#!/bin/bash +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Files ending with .py should have Copyright notice in the first line. +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +# Move to the project root +cd $SCRIPT_DIR/.. +find_files_with_missing_copyright() { +find ./nemo_reinforcer/ ./docs/*.py ./examples/ ./tests/ -type f -name '*.py' | while read path; do + echo -en $path"\t" + head -2 $path | grep -iv 'coding=' | head -1 +done \ + | egrep -iv 'Copyright.*NVIDIA CORPORATION.*All rights reserved.' \ + | grep -iv 'BSD 3-Clause License' \ + | grep -iv 'Copyright.*Microsoft' \ + | grep -iv 'Copyright.*The Open AI Team' \ + | grep -iv 'Copyright.*The Google AI' \ + | grep -iv 'Copyright.*Facebook' | while read line; do + echo $line | cut -d' ' -f1 + done +} + + +declare RESULT=($(find_files_with_missing_copyright)) # (..) = array + +if [ "${#RESULT[@]}" -gt 0 ]; then + echo "Error: Found files with missing copyright:" + for (( i=0; i<"${#RESULT[@]}"; i++ )); do + echo "path= ${RESULT[$i]}" + done + cat <