Add distopt support for FP8 params and BF16 optimizer state #7909

Merged 21 commits on Jan 12, 2024 (the diff below shows changes from 1 commit).

Commits (21):
- b031db6: Add distopt support for FP8 params and BF16 optimizer state (timmoon10, Dec 11, 2023)
- f3b6c82: Merge branch 'main' into distopt-fp8-bf16-state (timmoon10, Dec 14, 2023)
- 67d34a2: Merge branch 'main' into distopt-fp8-bf16-state (ericharper, Dec 15, 2023)
- bd3c42a: Merge branch 'main' into distopt-fp8-bf16-state (timmoon10, Dec 18, 2023)
- 12fd07c: Removed unused import (timmoon10, Dec 18, 2023)
- dd22a18: Update PyTorch container in Jenkins pipeline (timmoon10, Dec 18, 2023)
- c00ae09: Merge branch 'main' into distopt-fp8-bf16-state (timmoon10, Dec 19, 2023)
- 2f11cb2: Merge branch 'main' into distopt-fp8-bf16-state (timmoon10, Dec 19, 2023)
- 9f951fc: Merge branch 'main' into distopt-fp8-bf16-state (timmoon10, Dec 21, 2023)
- f40a38a: Use custom container with Apex bugfixes (timmoon10, Dec 21, 2023)
- b74a2ed: Merge branch 'main' into distopt-fp8-bf16-state (timmoon10, Dec 27, 2023)
- 2f8fc4e: Upgrade to PyTorch 23.11 container (timmoon10, Dec 27, 2023)
- 1bdd209: Merge branch 'main' into distopt-fp8-bf16-state (timmoon10, Dec 28, 2023)
- 8177404: Merge branch 'main' into distopt-fp8-bf16-state (timmoon10, Jan 2, 2024)
- b6890f4: Update Apex commit (timmoon10, Jan 2, 2024)
- fab79e4: Merge branch 'main' into distopt-fp8-bf16-state (timmoon10, Jan 8, 2024)
- 3af2ed5: Merge branch 'main' into distopt-fp8-bf16-state (timmoon10, Jan 9, 2024)
- d0b93e3: Merge branch 'main' into distopt-fp8-bf16-state (timmoon10, Jan 10, 2024)
- d2a7b67: Merge branch 'main' into distopt-fp8-bf16-state (timmoon10, Jan 11, 2024)
- 84ef6b7: Merge branch 'main' into distopt-fp8-bf16-state (timmoon10, Jan 11, 2024)
- e586072: Merge branch 'main' into distopt-fp8-bf16-state (ericharper, Jan 12, 2024)
Dockerfile (6 changes: 3 additions & 3 deletions)

@@ -53,12 +53,12 @@ RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \
# Distributed Adam support for multiple dtypes
RUN git clone https://github.com/NVIDIA/apex.git && \
cd apex && \
git checkout 52e18c894223800cb611682dce27d88050edf1de && \
pip install install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./
git checkout a2f6683b10fb4c29ab57c9e3d16957db76a8a5ba && \
pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./

RUN git clone https://github.com/NVIDIA/TransformerEngine.git && \
cd TransformerEngine && \
git fetch origin 8eae4ce2b8fdfbbe525fc8bfecb0df5498cc9687 && \
git fetch origin ff760a9d838ae4617600cccb22131d0359ce0296 && \
git checkout FETCH_HEAD && \
git submodule init && git submodule update && \
NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install .
Jenkinsfile (2 changes: 1 addition & 1 deletion)

@@ -61,7 +61,7 @@ pipeline {
steps {
sh 'git clone https://github.com/NVIDIA/TransformerEngine.git && \
cd TransformerEngine && \
git fetch origin e6676c53f26f6ef072943c909d136cf2a39c1d90 && \
git fetch origin ff760a9d838ae4617600cccb22131d0359ce0296 && \
git checkout FETCH_HEAD && \
git submodule init && git submodule update && \
NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install .'
README.rst (6 changes: 3 additions & 3 deletions)

@@ -295,8 +295,8 @@ To install Apex, run

git clone https://github.com/NVIDIA/apex.git
cd apex
git checkout 52e18c894223800cb611682dce27d88050edf1de
pip install install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./
git checkout a2f6683b10fb4c29ab57c9e3d16957db76a8a5ba
pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./

It is highly recommended to use the NVIDIA PyTorch or NeMo container if having issues installing Apex or any other dependencies.

@@ -335,7 +335,7 @@ Transformer Engine requires PyTorch to be built with CUDA 11.8.

Flash Attention
~~~~~~~~~~~~~~~~~~~~
Transformer Engine already supports Flash Attention for GPT models. If you want to use Flash Attention for non-causal models, please install `flash-attn <https://github.com/HazyResearch/flash-attention>`_. If you want to use Flash Attention with attention bias (introduced from position encoding, e.g. Alibi), please also install triton pinned version following the `implementation <https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3>`_.
Transformer Engine already supports Flash Attention for GPT models. If you want to use Flash Attention for non-causal models, please install `flash-attn <https://github.com/HazyResearch/flash-attention>`_. If you want to use Flash Attention with attention bias (introduced from position encoding, e.g. Alibi), please also install triton pinned version following the `implementation <https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3>`_.

.. code-block:: bash

(next file; its name and change count were not captured in the page extract)

@@ -39,7 +39,7 @@
from nemo.collections.nlp.parts import utils_funcs
from nemo.collections.nlp.parts.nlp_overrides import NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, GradScaler
from nemo.core.optim import MainParamsOptimizerWrapper, prepare_lr_scheduler
from nemo.utils import AppState, logging
from nemo.utils import AppState, logging, str_to_dtype
from nemo.utils.get_rank import is_global_rank_zero

try:
@@ -457,19 +457,39 @@ def setup_optimization(
self, optim_config: Optional[Union[DictConfig, Dict]] = None, optim_kwargs: Optional[Dict[str, Any]] = None,
):
optim_kwargs = {} if optim_kwargs is None else optim_kwargs.copy()

def get_config_arg(key: str, default_value: Optional[Any] = None) -> Any:
"""Get keyword argument from config"""
val = None
if val is None and optim_kwargs:
val = optim_kwargs.get(key, None)
if val is None and optim_config:
val = optim_config.get(key, None)
if val is None and self._cfg.optim:
val = self._cfg.optim.get(key, None)
if val is None:
val = default_value
return val

if self.with_distributed_adam:

# Allocate contiguous buffer to avoid extra copies
optim_kwargs['contiguous_grad_buffer'] = True
# Allocate contiguous grad buffer to avoid extra copies
optim_kwargs['contiguous_grad_buffer'] = get_config_arg('contiguous_grad_buffer', True)
if self.megatron_amp_O2 and not optim_kwargs['contiguous_grad_buffer']:
raise ValueError(
"Distributed Adam optimizer requires contiguous param buffer for O2. "
"Either enable contiguous_grad_buffer or disable megatron_amp_O2."
)

# Make sure optimizer state is in FP32
optim_dtype = torch.float32
# Optimizer dtype
optim_dtype = str_to_dtype(get_config_arg('dtype', torch.float32))
optim_kwargs['dtype'] = optim_dtype

# Make sure embedding grad reductions are in FP32
for name, param in self.named_parameters():
if 'word_embedding' in name or 'position_embedding' in name or 'output_layer' in name:
param._with_fp32_optimizer = True
if optim_dtype == torch.float32:
for name, param in self.named_parameters():
if 'word_embedding' in name or 'position_embedding' in name or 'output_layer' in name:
param._with_fp32_optimizer = True

# Match param allgather with model dtype
model_dtype = torch.float32
@@ -478,7 +498,9 @@
optim_kwargs['param_sync_dtype'] = model_dtype

# Determine whether to store master params in optimizer
if optim_dtype == model_dtype:
if self.cfg.get('fp8_params', False):
optim_kwargs['store_params'] = True
elif optim_dtype == model_dtype:
optim_kwargs['store_params'] = False
elif optim_dtype == torch.float32 and model_dtype == torch.bfloat16:
optim_kwargs['store_params'] = False
@@ -545,9 +567,11 @@ def configure_optimizers(self):
if self.with_distributed_adam:

# Initialize param buckets if explicitly provided
if hasattr(self, 'distributed_adam_buckets'):
if getattr(self, 'distributed_adam_buckets', None):
for bucket in self.distributed_adam_buckets:
self._optimizer.init_params_bucket(bucket)
self._optimizer.init_params_bucket(self.parameters())
if hasattr(self, 'distributed_adam_buckets'):
del self.distributed_adam_buckets

# Make sure all params are initialized so main grads are
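To make the new optimizer-state options easier to follow, here is a minimal, self-contained sketch of the two decisions introduced in `setup_optimization` above: the layered config lookup and the choice of whether the distributed optimizer keeps its own master copy of the params. The standalone function names and signatures are illustrative only (the PR implements this inline via `get_config_arg` and `optim_kwargs`), and the final fallback branch is an assumption, since the remaining cases are collapsed in the diff view.

```python
import torch

def lookup_config_arg(key, optim_kwargs=None, optim_config=None, model_optim_cfg=None, default=None):
    """Sketch of the lookup precedence used by get_config_arg above:
    explicit optimizer kwargs win over the optim_config argument, which
    wins over the model's own optim config, which wins over the default."""
    for source in (optim_kwargs, optim_config, model_optim_cfg):
        if source and source.get(key) is not None:
            return source.get(key)
    return default

def resolve_store_params(optim_dtype: torch.dtype, model_dtype: torch.dtype, fp8_params: bool) -> bool:
    """Mirror of the store_params decision visible in the hunk above."""
    if fp8_params:
        # FP8 model params cannot act as master weights, so the optimizer
        # must keep its own higher-precision copy.
        return True
    if optim_dtype == model_dtype:
        # Optimizer state and model params share a dtype; no master copy needed.
        return False
    if optim_dtype == torch.float32 and model_dtype == torch.bfloat16:
        # BF16 params with FP32 optimizer state can also skip a full master
        # copy (Apex can rebuild the FP32 value from the BF16 param plus
        # stored remainder bits).
        return False
    # Assumed conservative default; the remaining branches are not shown above.
    return True
```

In particular, enabling `fp8_params` forces a master copy, while the existing FP32-state / BF16-param path keeps working unchanged.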
nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py (21 changes: 10 additions & 11 deletions)

@@ -16,6 +16,7 @@
import os
import queue
import warnings
from contextlib import nullcontext
from dataclasses import fields
from functools import partial
from typing import Any, Dict, Iterator, List, Optional, Union
@@ -234,11 +235,15 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None),
)
else:
self.model = build_model(
model_provider_func=self.model_provider_func,
wrap_with_ddp=False,
virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None),
)
build_model_context = nullcontext
if HAVE_TE and self.cfg.get('fp8', False) and self.cfg.get('fp8_params', False):
build_model_context = transformer_engine.pytorch.fp8_model_init
with build_model_context():
self.model = build_model(
model_provider_func=self.model_provider_func,
wrap_with_ddp=False,
virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None),
)

# if we're not using interleaved, then self.model is a module.
if self.cfg.get('virtual_pipeline_model_parallel_size', None) is None:
@@ -472,12 +477,6 @@ def configure_optimizers(self):
[p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)]
)
buckets.reverse()
used_params = set()
for bucket in buckets:
used_params.update(bucket)
remaining_params = [p for p in self.parameters() if p not in used_params]
if remaining_params:
buckets.append(remaining_params)
self.distributed_adam_buckets = buckets

return super().configure_optimizers()
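The `__init__` hunk above wraps model construction in Transformer Engine's `fp8_model_init()` context when both `fp8` and `fp8_params` are enabled, so TE modules allocate their parameters directly as FP8 tensors instead of creating higher-precision weights first. A minimal sketch of that pattern, assuming a dict-like `cfg` and a hypothetical `model_provider_func` standing in for NeMo's `build_model` call:

```python
from contextlib import nullcontext

try:
    import transformer_engine.pytorch as te

    HAVE_TE = True
except ImportError:
    HAVE_TE = False

def build_with_optional_fp8_params(cfg, model_provider_func):
    """Enter fp8_model_init() only when FP8 params are requested; otherwise
    use a no-op context so the normal build path is unchanged."""
    build_model_context = nullcontext
    if HAVE_TE and cfg.get('fp8', False) and cfg.get('fp8_params', False):
        build_model_context = te.fp8_model_init
    with build_model_context():
        return model_provider_func()
```

Combined with the `store_params=True` branch in `setup_optimization`, the model params stay in FP8 while the distributed optimizer keeps the master weights on its side.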