Add sharding for speechlm and vlm (#11876)
* Add sharding for speechlm and vlm

Signed-off-by: Boxiang Wang <[email protected]>

* Add ci test for VLM

Signed-off-by: Boxiang Wang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: BoxiangW <[email protected]>

---------

Signed-off-by: Boxiang Wang <[email protected]>
Signed-off-by: BoxiangW <[email protected]>
Co-authored-by: BoxiangW <[email protected]>
Authored by BoxiangW on Jan 23, 2025
1 parent 7e24313 · commit cc365b6
Showing 5 changed files with 198 additions and 15 deletions.
13 changes: 13 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -3637,6 +3637,7 @@ jobs:
TRANSFORMERS_OFFLINE=1 python tests/collections/vlm/hf/peft_hf.py --model /home/TestData/vlm/qwen2-2b/ --max-steps 3 --strategy fsdp --devices 2
AFTER_SCRIPT: |
rm -rf nemo_experiments
L2_VLM_HF_Transformer_PEFT_4bit:
needs: [ cicd-test-container-setup ]
uses: ./.github/workflows/_test_template.yml
@@ -3648,6 +3649,17 @@ jobs:
AFTER_SCRIPT: |
rm -rf nemo_experiments
L2_VLM_HF_Transformer_SFT_FSDP2:
needs: [ cicd-test-container-setup ]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_VLM_HF_Transformer_SFT_FSDP2') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure-gpus-1
SCRIPT: |
TRANSFORMERS_OFFLINE=1 python tests/collections/vlm/hf/sft_fsdp2.py --model /home/TestData/vlm/qwen2-2b/ --max-steps 3
AFTER_SCRIPT: |
rm -rf nemo_experiments
L2_HF_Transformer_PEFT:
needs: [ cicd-test-container-setup ]
uses: ./.github/workflows/_test_template.yml
@@ -5092,6 +5104,7 @@ jobs:
- L2_VLM_HF_Transformer_PEFT
- L2_VLM_HF_Transformer_PEFT_FSDP
- L2_VLM_HF_Transformer_PEFT_4bit
- L2_VLM_HF_Transformer_SFT_FSDP2
- L2_HF_Transformer_SFT_2gpu_nemorun
- L2_HF_Transformer_SFT_TE_Acceleration
- L2_HF_Transformer_PT
(diff of the speechlm HF AutoModelForSpeechSeq2Seq wrapper; file header not captured)
@@ -20,6 +20,7 @@
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.llm import fn
from nemo.lightning import io
from nemo.lightning.pytorch.strategies.utils import fsdp2_strategy_parallelize
from nemo.utils import logging


@@ -94,6 +95,10 @@ def configure_model(self, train=True):
config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=self.trust_remote_code)
self.model = AutoModelForSpeechSeq2Seq.from_config(config, trust_remote_code=self.trust_remote_code)

# Apply FSDP2 and TP to the model
if self.device_mesh is not None:
fsdp2_strategy_parallelize(self.model, device_mesh=self.device_mesh, model_type="speech_seq2seq")

if train:
self.model.train()

@@ -104,7 +109,7 @@ def forward(self, input_features, decoder_input_ids, attention_mask=None):
decoder_input_ids=decoder_input_ids,
)

def training_step(self, batch):
def training_step(self, batch, batch_idx=None):
outputs = self.forward(input_features=batch["input_features"], decoder_input_ids=batch["decoder_input_ids"])
loss_mask = batch.get('loss_mask', None)
if loss_mask is not None:
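
In the configure_model hunk above, sharding runs only when self.device_mesh has already been populated by the training strategy; how that mesh is constructed is not part of this diff. As a rough illustration only (the mesh shape and dimension names are assumptions for this sketch, not NeMo's actual layout), a two-GPU, data-parallel-only mesh like the one the new CI test exercises could be built as:

from torch.distributed.device_mesh import init_device_mesh

# Illustrative only: 2 data-parallel ranks and no tensor parallelism, mirroring
# FSDP2Strategy(data_parallel_size=2, tensor_parallel_size=1) in the new test script.
# Requires an initialized process group, e.g. a torchrun --nproc-per-node=2 launch.
mesh = init_device_mesh("cuda", (2, 1), mesh_dim_names=("data_parallel", "tensor_parallel"))
dp_mesh = mesh["data_parallel"]  # sub-mesh a helper like fsdp2_strategy_parallelize would shard over
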
(diff of the vlm HF AutoModelForImageTextToText wrapper; file header not captured)
@@ -19,6 +19,7 @@

from nemo.collections.llm import fn
from nemo.lightning import io
from nemo.lightning.pytorch.strategies.utils import fsdp2_strategy_parallelize
from nemo.utils import logging


@@ -95,13 +96,18 @@ def configure_model(self):
self.model = AutoModelForImageTextToText.from_config(
config, torch_dtype=dtype, trust_remote_code=self.trust_remote_code
)

# Apply FSDP2 and TP to the model
if self.device_mesh is not None:
fsdp2_strategy_parallelize(self.model, device_mesh=self.device_mesh)

self.model.train()

def forward(self, batch):
"""Runs forward with the model"""
return self.model(**batch)

def training_step(self, batch):
def training_step(self, batch, batch_idx=None):
"""Run one training step"""
labels = batch.pop('labels').to(self.model.device)
loss_mask = batch.pop('loss_mask', None)
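
Both training_step hunks above are truncated right after the loss_mask lookup. As orientation only, here is a minimal, self-contained sketch of how a token-level loss mask is commonly applied to a language-modeling loss; it illustrates the general pattern rather than the NeMo code hidden by the truncation, and masked_lm_loss is a made-up name for this example.

import torch
import torch.nn.functional as F


def masked_lm_loss(logits: torch.Tensor, labels: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    """Cross-entropy over [batch, seq, vocab] logits, counting only positions where loss_mask == 1."""
    per_token = F.cross_entropy(
        logits.flatten(0, 1),  # [batch * seq, vocab]
        labels.flatten(),      # [batch * seq]
        reduction="none",
    )
    per_token = per_token * loss_mask.flatten().to(per_token.dtype)
    # Average over unmasked tokens; clamp avoids division by zero on fully masked batches.
    return per_token.sum() / loss_mask.sum().clamp(min=1)


# Toy 2x4 batch with the first two positions of each sequence masked out.
logits = torch.randn(2, 4, 10)
labels = torch.randint(0, 10, (2, 4))
loss_mask = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
print(masked_lm_loss(logits, labels, loss_mask))
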
51 changes: 38 additions & 13 deletions nemo/lightning/pytorch/strategies/utils.py
@@ -345,7 +345,7 @@ def _convert(state_dict, k, sh_key, v, prepend_offsets, prefix="", allow_shape_m

# Taken and modified from torchtitan
# https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py
def fsdp2_strategy_parallelize(model, device_mesh: DeviceMesh = None):
def fsdp2_strategy_parallelize(model, device_mesh: DeviceMesh = None, model_type: str = None):
"""Apply parallelisms and activation checkpointing to the model.
NOTE: The passed-in model preferably should be on meta device. Otherwise,
the model must fit on GPU or CPU memory.
@@ -364,18 +364,43 @@ def fsdp2_strategy_parallelize(model, device_mesh: DeviceMesh = None):
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id, transformer_block in enumerate(model.model.layers):
# Apply activation checkpointing
# transformer_block = checkpoint_wrapper(transformer_block)
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = int(layer_id) < len(model.model.layers) - 1
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.model.layers[layer_id] = transformer_block
if model_type == "speech_seq2seq":
for layer_id, transformer_block in enumerate(model.model.encoder.layers):
# Apply activation checkpointing
# transformer_block = checkpoint_wrapper(transformer_block)
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = int(layer_id) < len(model.model.encoder.layers) - 1
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.model.encoder.layers[layer_id] = transformer_block

for layer_id, transformer_block in enumerate(model.model.decoder.layers):
# transformer_block = checkpoint_wrapper(transformer_block)
reshard_after_forward = int(layer_id) < len(model.model.decoder.layers) - 1
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.model.decoder.layers[layer_id] = transformer_block
else:
for layer_id, transformer_block in enumerate(model.model.layers):
# Apply activation checkpointing
# transformer_block = checkpoint_wrapper(transformer_block)
# As an optimization, do not reshard after forward for the last
# transformer block since FSDP would prefetch it immediately
reshard_after_forward = int(layer_id) < len(model.model.layers) - 1
fully_shard(
transformer_block,
**fsdp_config,
reshard_after_forward=reshard_after_forward,
)
model.model.layers[layer_id] = transformer_block

model = fully_shard(model, **fsdp_config)

return model
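
The refactored helper now applies the same per-block recipe either to the encoder and decoder stacks (when model_type == "speech_seq2seq") or to a single decoder-only model.model.layers stack, then wraps the root module. Below is a minimal, runnable sketch of that recipe on a toy module, not NeMo code: ToyBlock and shard_toy_model are illustrative names, and the import paths assume PyTorch 2.6+ (in 2.4/2.5, fully_shard and MixedPrecisionPolicy live under torch.distributed._composable.fsdp).

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


class ToyBlock(nn.Module):
    """Stand-in for a transformer block."""

    def __init__(self, dim: int = 256):
        super().__init__()
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.ffn(x)


def shard_toy_model(model: nn.Sequential, mesh: DeviceMesh) -> nn.Module:
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    fsdp_config = {"mesh": mesh, "mp_policy": mp_policy}
    for layer_id, block in enumerate(model):
        # Keep the last block's parameters gathered after forward: FSDP would
        # prefetch them again immediately for backward, so resharding is wasted work.
        fully_shard(block, **fsdp_config, reshard_after_forward=layer_id < len(model) - 1)
    # Root-level wrap so any remaining parameters (embeddings, output head, ...) are sharded too.
    return fully_shard(model, **fsdp_config)


# Usage under torchrun (e.g. --nproc-per-node=2):
#   from torch.distributed.device_mesh import init_device_mesh
#   mesh = init_device_mesh("cuda", (2,))
#   sharded = shard_toy_model(nn.Sequential(*[ToyBlock() for _ in range(4)]), mesh)
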
134 changes: 134 additions & 0 deletions tests/collections/vlm/hf/sft_fsdp2.py
@@ -0,0 +1,134 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from importlib.metadata import version

import fiddle as fdl
import torch
from lightning.pytorch.loggers import WandbLogger
from packaging.version import Version as PkgVersion

from nemo import lightning as nl
from nemo.collections import llm, vlm

DATA_PATH = "/home/TestData/vlm/rdr-items"


def get_torch_version_str():
import torch

if hasattr(torch, '__version__'):
return str(torch.__version__)
else:
return version("torch")


def mk_hf_vlm_dataset(processor, mbs, gbs):
skipped_tokens = vlm.HFAutoModelForImageTextToText.extract_skipped_token_ids(processor)

def collate_fn(examples, processor):
def fmt(sample):
instruction = "Describe accurately the given image."
conversation = [
{
"role": "user",
"content": [{"type": "text", "text": instruction}, {"type": "image", "image": sample["image"]}],
},
{"role": "assistant", "content": [{"type": "text", "text": sample["text"]}]},
]
return {"conversation": conversation, "images": [sample['image']]}

text = []
images = []
for example in map(fmt, examples):
text.append(
processor.apply_chat_template(
example["conversation"],
tokenize=False,
add_generation_prompt=False,
)
)
images += example['images']

# Tokenize the text and process the images
batch = processor(
text=text,
images=images,
padding=True,
return_tensors="pt",
)

batch["pixel_values"] = batch["pixel_values"].to(torch.bfloat16)

labels = batch["input_ids"].clone()
labels[torch.isin(labels, skipped_tokens)] = -100
batch["labels"] = labels
return batch

return vlm.HFDatasetDataModule(
DATA_PATH,
split="train[:10]",
micro_batch_size=mbs,
global_batch_size=gbs,
collate_fn=lambda x: collate_fn(x, processor=processor),
)


if __name__ == '__main__':
if PkgVersion(get_torch_version_str()) >= PkgVersion("2.4"):
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='Qwen/Qwen2-VL-2B-Instruct')
parser.add_argument('--devices', default=2)
parser.add_argument('--mbs', default=1)
parser.add_argument('--gbs', default=1)
parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
parser.add_argument('--max-steps', type=int, default=100)
parser.add_argument('--wandb-project', type=str, default=None)
parser.add_argument('--disable-ckpt', action='store_false')
parser.add_argument('--use-4bit', help="Load model in 4bit", action="store_true")
args = parser.parse_args()

wandb = None
if args.wandb_project is not None:
model = '_'.join(args.model.split('/')[-2:])
wandb = WandbLogger(
project=args.wandb_project,
name=f'{model}_dev{args.devices}_strat_fsdp2',
)
grad_clip = None
use_dist_samp = False
processor = vlm.HFAutoModelForImageTextToText.configure_processor(args.model)

llm.api.finetune(
model=vlm.HFAutoModelForImageTextToText(args.model, load_in_4bit=args.use_4bit),
data=mk_hf_vlm_dataset(processor, args.mbs, args.gbs),
trainer=nl.Trainer(
devices=args.devices,
max_steps=args.max_steps,
accelerator=args.accelerator,
strategy=nl.FSDP2Strategy(data_parallel_size=2, tensor_parallel_size=1),
log_every_n_steps=1,
limit_val_batches=0.0,
num_sanity_val_steps=0,
accumulate_grad_batches=10,
gradient_clip_val=grad_clip,
use_distributed_sampler=use_dist_samp,
logger=wandb,
enable_checkpointing=args.disable_ckpt,
),
optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
log=None,
)
