add trt_llm support
Signed-off-by: Gerald Shen <[email protected]>
gshennvm committed May 3, 2024
1 parent e4776bb commit e61121b
Showing 26 changed files with 1,133 additions and 384 deletions.
37 changes: 37 additions & 0 deletions Accelerated-RLHF.md
@@ -0,0 +1,37 @@
# Accelerated Reinforcement Learning From Human Feedback

For more details beyond this usage guide, please see the NeMo-Aligner [paper](https://arxiv.org/abs/2405.01481).

## Description
Response generation during the RLHF PPO rollout phase makes up the majority of the RLHF step time, taking as much as 90% of total training time if left unoptimized. To address this bottleneck, we use [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) and its fast inference kernels to accelerate the generation stage. In our ablation experiments we observed a 6.96x speedup with our TRT-LLM integration, and we are working on improving this speedup further.

## Environment

We are working on adding all of our dependencies to the [NeMo-FW-Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo). In the meantime, we provide a [Dockerfile](Dockerfile) that builds an image with all of the required dependencies.
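
For example, from the root of the repository the image can be built along the following lines (the image tag is just an example name, not a published image):

```bash
# Build the NeMo-Aligner + TensorRT-LLM image from the provided Dockerfile.
docker build -t nemo-aligner-trtllm:latest .
```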

## How it works
At the start of RLHF training, we compile the engine with TRT-LLM. This first compilation takes longer than subsequent steps; afterwards, we simply reuse the compiled engine and push updated model weights to it. Training is still done with [NeMo-FW](https://github.com/NVIDIA/NeMo), which contains efficient training kernels.
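
A minimal sketch of this flow is shown below. The helper functions are hypothetical stand-ins used only to illustrate the compile-once / refit-every-step pattern; the real logic lives in NeMo-Aligner's PPO trainer and the TRT-LLM bindings.

```python
# Illustrative sketch only: these helpers are stand-ins, not the NeMo-Aligner or TRT-LLM API.

def compile_engine(model):
    """One-time TRT-LLM engine build; the slowest part of the first PPO step."""
    return {"weights": model["weights"]}

def refit_engine(engine, model):
    """Per-step weight push: reuse the compiled engine, only swap in new weights."""
    engine["weights"] = model["weights"]

def generate(engine, prompts):
    """Fast rollout generation with TRT-LLM kernels (stubbed)."""
    return [prompt + " <response>" for prompt in prompts]

def ppo_update(model, rollouts):
    """PPO training step with NeMo-FW kernels (stubbed)."""
    model["weights"] += 1

model = {"weights": 0}
engine = compile_engine(model)       # compiled once, at the start of training
for step in range(3):                # every PPO step afterwards
    refit_engine(engine, model)      # cheap compared to recompiling
    rollouts = generate(engine, ["prompt A", "prompt B"])
    ppo_update(model, rollouts)      # the engine can optionally be unloaded here
                                     # (trainer.ppo.trt_llm.unload_engine_train)
```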

## Usage Guide

To begin, please follow the RLHF usage guide on the [Tutorials](https://docs.nvidia.com/nemo-framework/user-guide/latest/ModelAlignment/index.html) page. All other configurations work just as before; with TRT-LLM we have added the [trainer.ppo.trt_llm](examples/nlp/gpt/conf/gpt_ppo_actor.yaml#L39) subconfig to the PPO actor.
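
For example, assuming the standard RLHF actor launch from the tutorial, TRT-LLM generation can be switched on purely through Hydra overrides. The values below are illustrative; `model_type` already defaults to `LLaMAForCausalLM` in the shipped config.

```bash
python examples/nlp/gpt/train_gpt_ppo_actor.py \
    <usual RLHF actor arguments from the tutorial> \
    trainer.ppo.trt_llm.enable=True \
    trainer.ppo.trt_llm.reshard=True \
    trainer.ppo.trt_llm.model_type=LLaMAForCausalLM
```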

## Performance tuning guide
There are a few configurations to consider when using TRT-LLM with RLHF.

* `trainer.ppo.trt_llm.enable`: Turns TRT-LLM acceleration on or off.
* `trainer.ppo.trt_llm.reshard`: If this flag is on and TRT-LLM is enabled, the model is resharded from pipeline parallelism to tensor parallelism only for inference; NeMo training still runs with pipeline parallelism. When this option is active, distributed groups within the TRT-LLM inference context treat pipeline-parallel groups as data-parallel groups, so caution is needed when handling data sharding.
* `trainer.ppo.trt_llm.unload_engine_train`: If this flag is enabled, the engine is unloaded during training. Unloading frees up more memory for training, but comes at the cost of the time spent onloading the engine again. In our experience, the most optimal configuration is to reduce the rollout micro batch size and keep the engine loaded while training (i.e., set this flag to `False`), as in the example below.
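
A sketch of that trade-off is shown here; the rollout micro batch size of 4 is only an example value, and the right number depends on your model and memory budget.

```bash
# Keep the engine resident during training and shrink the rollout micro batch instead.
python examples/nlp/gpt/train_gpt_ppo_actor.py \
    <usual RLHF actor arguments> \
    trainer.ppo.trt_llm.enable=True \
    trainer.ppo.trt_llm.unload_engine_train=False \
    model.ppo.rollout_micro_batch_size=4
```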

During the TRT-LLM optimization phase, we also noticed that data-parallel workers can have significantly different generation times. To balance the load, a Flask server hosted on rank 0 acts as a distributed queue and hands out work to the other workers. This can be enabled with `trainer.ppo.flask_server.enable=True`.

## Performance
We are working on improving the performance of our TRT-LLM integration and will post the most up-to-date numbers in this README as we go. The current performance numbers are as follows:

| Actor + Critic Node Count | Time per PPO Step in seconds | Estimated Time to train [hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) | Scaling from Base |
|---------------------------|----------------------------- |------------------------------------------------------------------------------------|-------------------|
| 8 + 8 | 253.8 | 11.1 hours | 1 |
| 16 + 16 | 143.4 | 6.3 hours | **1.77x** |
| 32 + 32 | 81.2 | 3.5 hours | **3.13x** |

Time per PPO step with a LLaMa2 70B actor and critic. The number of rollout samples is 1024, and the training global batch size is 128.
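
As a rough sanity check on the "Estimated Time to train" column, the numbers are consistent with roughly one pass over hh-rlhf's training comparisons at 1024 rollout samples per step; the dataset size below is an assumption about how the estimate was computed.

```python
# Back-of-the-envelope check of the 8 + 8 row; the dataset size is an assumption.
dataset_size = 160_000             # approx. hh-rlhf training comparisons
rollout_samples_per_step = 1024    # from the caption above
step_time_s = 253.8                # 8 + 8 nodes, from the table

steps = dataset_size / rollout_samples_per_step
hours = steps * step_time_s / 3600
print(f"~{steps:.0f} PPO steps, ~{hours:.1f} hours")   # ~156 steps, ~11.0 hours
```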
31 changes: 23 additions & 8 deletions Dockerfile
@@ -1,14 +1,14 @@
 # CUDA 12.3
-FROM nvcr.io/nvidia/pytorch:24.01-py3
+FROM nvcr.io/nvidia/pytorch:24.02-py3

 ### config tags
-ARG APEX_TAG=master
-ARG TE_TAG=release_v1.4
-ARG MLM_TAG=43792028f003ed25a3ee8c5a0d4cad82317d81b5
-ARG NEMO_TAG=9d86acd5ebf3cec020f84dfe7e25c109506803b1
+ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
+ARG TE_TAG=bfe21c3d68b0a9951e5716fb520045db53419c5e
+ARG MLM_TAG=fbb375d4b5e88ce52f5f7125053068caff47f93f
+ARG NEMO_TAG=10274c941841c9cc30d1db50699d7523851d9fea
 ARG PYTRITON_VERSION=0.4.1
 ARG PROTOBUF_VERSION=4.24.4
-ARG ALIGNER_COMMIT=main
+ARG ALIGNER_COMMIT=v0.3.0.trtllm

 # if you get errors building TE or Apex, decrease this to 4
 ARG MAX_JOBS=8
@@ -37,7 +37,7 @@ RUN pip uninstall -y apex && \
git fetch origin $APEX_TAG && \
git checkout FETCH_HEAD; \
fi && \
-pip install install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./
+pip install -e . -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam --group_norm"

# place any util pkgs here
RUN pip install --upgrade-strategy only-if-needed nvidia-pytriton==$PYTRITON_VERSION
@@ -77,4 +77,19 @@ RUN git clone https://github.com/NVIDIA/NeMo-Aligner.git && \
fi && \
pip install --no-deps -e .

-WORKDIR /workspace
+# Git LFS
+RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \
+    apt-get install git-lfs && \
+    git lfs install
+
+# TRTLLM-0.9
+RUN git clone https://github.com/NVIDIA/TensorRT-LLM.git && \
+    cd TensorRT-LLM && \
+    git checkout v0.9.0 && \
+    git apply ../NeMo-Aligner/trtllm.patch && \
+    . docker/common/install_tensorrt.sh && \
+    python3 ./scripts/build_wheel.py --trt_root /usr/local/tensorrt
+
+RUN cd TensorRT-LLM && \
+    pip install ./build/tensorrt_llm*.whl
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.3/compat/lib.real/
5 changes: 4 additions & 1 deletion README.md
@@ -1,5 +1,9 @@
# NVIDIA NeMo-Aligner

# Latest News
- We released a beta version of accelerated generation support in the RLHF pipeline. This is still very much a work in progress, but it adds a significant speedup to RLHF training. For more details, see [Accelerated-RLHF](Accelerated-RLHF.md) and the special [Accelerated-RLHF-Release]().
- The [NeMo-Aligner paper](https://arxiv.org/abs/2405.01481) is now out on arXiv!

## Introduction

NeMo-Aligner is a scalable toolkit for efficient model alignment. The toolkit has support for state of the art model alignment algorithms such as SteerLM, DPO and Reinforcement Learning from Human Feedback (RLHF). These algorithms enable users to align language models to be more safe, harmless and helpful. Users can do end-to-end model alignment on a wide range of model sizes and take advantage of all the parallelism techniques to ensure their model alignment is done in a performant and resource efficient manner.
@@ -55,7 +59,6 @@ Alternatively, you can build the NeMo Dockerfile here [NeMo Dockerfile](https://
## Future work
- Add Rejection Sampling support
- We will continue improving the stability of the PPO learning phase.
- Improve the performance of RLHF

## Contributing
We welcome community contributions! Please refer to [CONTRIBUTING.md](https://github.com/NVIDIA/NeMo-Aligner/blob/main/CONTRIBUTING.md) for guidelines.
32 changes: 30 additions & 2 deletions examples/nlp/gpt/conf/gpt_ppo_actor.yaml
@@ -27,12 +27,38 @@ trainer:
model_gbs: ${model.global_batch_size}
model_mbs: ${model.micro_batch_size}

# the sequence length to pad the rollout batch to
# this reduces fragmentation at the cost of using more
# the sequence length to pad the rollout batch for training to
# reduce fragmentation at the cost of using more
# memory, set to null if we don't want to pad it
# to a constant size
# if actual seq length is higher than this a warning will be raised
# but will not crash and training will still proceed
rollout_batch_seq_length: null

# Accelerate training times by accelerating inference stage using TRTLLM
trt_llm:
enable: False
reshard: True # if True then reshard the model into TP only for inference

# TRTLLM preallocates activation memory according to the number of input tokens
# By default, assume the max input length is half of the model sequence length
max_input_len: ${int_div:${model.encoder_seq_length}, 2}
max_input_tokens: ${multiply:${trainer.ppo.trt_llm.max_input_len}, ${model.ppo.rollout_micro_batch_size}}

# The model type TRTLLM will build, supported models are listed at:
# https://github.com/NVIDIA/TensorRT-LLM/blob/v0.9.0/tensorrt_llm/models/__init__.py#L75
model_type: LLaMAForCausalLM

# Unload and reload the TRTLLM engine before and after the training stage
# Reloading the engine incurs a constant time overhead
unload_engine_train: False

flask_server:
# flask server that acts as a worker pool to balance out generation time
enable: False
port: 12345
host: null # if not provided, it will be automatically set

# no need to change these
logger: False # logger provided by exp_manager
enable_checkpointing: False
@@ -92,6 +118,8 @@ pretrained_checkpoint:
model:

ppo:
trt_llm: ${trainer.ppo.trt_llm}

# training generation mbs
rollout_micro_batch_size: 8
num_rollout_samples: 512
4 changes: 4 additions & 0 deletions examples/nlp/gpt/conf/gpt_ppo_critic.yaml
@@ -60,6 +60,10 @@ model:
mean: null
std: null

# if the inference microbatch size is big, it's possible
# to split it using forward mbs and run inference iteratively
forward_mbs: ${trainer.ppo.inference_micro_batch_size}

# RM args
use_avg_pool: False # Whether to use avg pool to sum across the sequence dim in reward model
force_head_dtype: float32 # enforce specific dtype for the final projection in the model head
6 changes: 5 additions & 1 deletion examples/nlp/gpt/conf/training_rm.yaml
@@ -70,6 +70,10 @@ model:
global_batch_size: 64
megatron_amp_O2: True

# if the inference microbatch size is big, it's possible
# to split it using forward mbs and run inference iteratively
forward_mbs: 8

encoder_seq_length: 4096
max_position_embeddings: ${model.encoder_seq_length}

@@ -110,4 +114,4 @@ model:
# define fields from the base model's config that should be ignored when merging with this config.
overwrite_base_config:
data:
data_prefix: True
data_prefix: True
57 changes: 47 additions & 10 deletions examples/nlp/gpt/train_gpt_ppo_actor.py
@@ -11,8 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
import threading
from functools import partial

import torch
import torch.multiprocessing as mp
from flask import Flask, request
from megatron.core import parallel_state
from megatron.core.utils import divide
from omegaconf.omegaconf import OmegaConf
@@ -21,7 +26,7 @@
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.algorithms.ppo import PPOTrainer
from nemo_aligner.algorithms.ppo import DefaultBatchIterator, HTTPBatchIterator, PPOTrainer, SharedSet
from nemo_aligner.data.nlp.builders import (
build_dataloader,
build_train_valid_test_rlhf_datasets,
@@ -30,6 +35,7 @@
from nemo_aligner.models.nlp.gpt.megatron_gpt_ppo_actor import MegatronGPTActorModel
from nemo_aligner.models.nlp.gpt.reward_critic_clients import RemoteGPTRMCriticClient
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.server_utils import FutureResult
from nemo_aligner.utils.train_script_utils import (
CustomLoggerWrapper,
add_custom_checkpoint_callback,
@@ -89,10 +95,8 @@ def main(cfg) -> None:
# TODO: log this restore path
if trainer_restore_path is not None:
custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer)
consumed_samples = custom_trainer_state_dict["consumed_samples"]
else:
custom_trainer_state_dict = None
consumed_samples = 0

init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False))

@@ -113,22 +117,22 @@
eos_id = ptl_model.tokenizer.eos_id

# collate fn to pad to the max seq length in the batch
collate_fn = collate_with_pad_to_max_batch(max_seqlen, eos_id, cfg)
collate_fn = collate_with_pad_to_max_batch(max_seqlen, eos_id, cfg, generate_masks_and_position_ids=False)

train_dataloader = build_dataloader(
train_dataloader_builder = partial(
build_dataloader,
cfg=cfg,
dataset=train_ds,
consumed_samples=consumed_samples,
mbs=cfg.model.ppo.rollout_micro_batch_size,
gbs=cfg.model.ppo.num_rollout_samples,
collate_fn=collate_fn,
load_gbs=False,
)

val_dataloader = build_dataloader(
val_dataloader_builder = partial(
build_dataloader,
cfg=cfg,
dataset=validation_ds,
consumed_samples=0,
mbs=cfg.model.ppo.val_rollout_micro_batch_size,
gbs=cfg.model.ppo.num_val_samples,
collate_fn=collate_fn,
@@ -160,14 +164,47 @@
rm_critic = RemoteGPTRMCriticClient(cfg.remote_critic_rm)
timer = Timer(cfg.exp_manager.get("max_time_per_run"))

batch_iterator_cls = DefaultBatchIterator
flask_cfg = cfg.trainer.ppo.flask_server
if flask_cfg.enable:
# only rank 0 has a not None shared set
shared_set = None

# TODO: we might be able to just broadcast the hostname
# so the user don't have to specify it
flask_host = flask_cfg.host
flask_port = flask_cfg.port
if flask_host is None:
# automatically get rank 0's host and broadcast it if not specified
ip_address = [socket.gethostbyname(socket.gethostname())]
torch.distributed.broadcast_object_list(ip_address, src=0, group=None, device=torch.cuda.current_device())
flask_host = ip_address[0]

if torch.distributed.get_rank() == 0:
lock = threading.Lock()
shared_set = SharedSet(lock)
app = Flask(__name__)

# TODO: add batch size
@app.route("/get_idx", methods=["PUT"])
def get_http_idx():
batch_size = request.get_json()["batch_size"]
return shared_set.get_idx(batch_size)

threading.Thread(target=lambda: app.run(host=flask_host, port=flask_port, use_reloader=False)).start()

batch_iterator_cls = partial(HTTPBatchIterator, shared_set, flask_host, flask_port)

ppo_trainer = PPOTrainer(
cfg=cfg.trainer.ppo,
model=ptl_model,
optimizer=optimizer,
scheduler=scheduler,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
train_dataloader_builder=train_dataloader_builder,
val_dataloader_builder=val_dataloader_builder,
collate_fn=collate_fn,
rm_critic=rm_critic,
batch_iterator_cls=batch_iterator_cls,
logger=logger,
ckpt_callback=ckpt_callback,
run_timer=timer,
6 changes: 4 additions & 2 deletions nemo_aligner/algorithms/critic_server_trainer.py
@@ -17,7 +17,6 @@

import numpy as np
import torch
from megatron.core import parallel_state
from megatron.core.utils import divide
from pytriton.decorators import batch, sample
from pytriton.model_config import ModelConfig, Tensor
@@ -29,6 +28,7 @@
from nemo.utils import logging
from nemo_aligner.servers.constants import ServerSignal
from nemo_aligner.servers.server_callables import run_rm_or_critic_inference
from nemo_aligner.utils import parallel_state
from nemo_aligner.utils.distributed import SyncTimer, broadcast_2d_tensor
from nemo_aligner.utils.server_utils import lock_method, pad_input
from nemo_aligner.utils.train_utils import clip_gradients
@@ -102,6 +102,7 @@ def server_infer(self, **inputs: np.ndarray) -> Dict[str, np.ndarray]:
torch.distributed.broadcast(choice, 0)

rewards, values, exceeded = self.run_inference(inputs=inputs)

output = {
"values": values,
"exceeded": exceeded,
@@ -245,6 +246,7 @@ def run_training(self, tokens=None, returns=None, prev_values=None, mask=None):
"prev_values": prev_values,
"mask": mask,
}

batch["tokens"] = broadcast_2d_tensor(batch["tokens"], src=0, group=None, dtype=torch.int64)
batch["returns"] = broadcast_2d_tensor(batch["returns"], src=0, group=None, dtype=torch.float32)
batch["prev_values"] = broadcast_2d_tensor(batch["prev_values"], src=0, group=None, dtype=torch.float32)
@@ -297,9 +299,9 @@ def run_training(self, tokens=None, returns=None, prev_values=None, mask=None):
self.step += 1

self.model.finish_training()

torch.cuda.synchronize()
torch.distributed.barrier()

return loss_mean

def save(self, extra_candidates=None, is_train_end=False, save_top_only=False):