Commit
Add vlm generation
meatybobby committed Oct 21, 2024
1 parent dbedae0 commit 4711c75
Showing 5 changed files with 305 additions and 0 deletions.
15 changes: 15 additions & 0 deletions nemo/collections/vlm/inference/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.vlm.inference.base import setup_model_and_tokenizer, generate
88 changes: 88 additions & 0 deletions nemo/collections/vlm/inference/base.py
@@ -0,0 +1,88 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import pytorch_lightning as pl
import torch
import torch.distributed
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig

import nemo.lightning as nl
from nemo.collections import vlm
from nemo.collections.vlm.inference.vlm_inference_wrapper import VLMInferenceWrapper
from nemo.collections.vlm.inference.vlm_inference_controller import VLMTextGenerationController
from nemo.collections.vlm.inference.vlm_engine import VLMEngine

def _setup_trainer_and_restore_model(path: str, trainer: nl.Trainer, model: pl.LightningModule):
    fabric = trainer.to_fabric()
    model = fabric.load_model(path, model)
    return model

def setup_model_and_tokenizer(
    path: str,
    trainer: Optional[nl.Trainer] = None,
    params_dtype: torch.dtype = torch.bfloat16,
    inference_batch_times_seqlen_threshold: int = 1000,
):
    # model: io.TrainerContext = io.load_context(path=path, subpath="model")
    # trainer = trainer or io.load_context(path=path, subpath="trainer")
    model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(model_id)
    tokenizer = processor.tokenizer
    config = vlm.MLlamaConfig11BInstruct()
    model = vlm.MLlamaModel(config, tokenizer=tokenizer)
    _setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)

    mcore_model = model.module.cuda()
    mcore_model = mcore_model.to(params_dtype)
    inference_wrapped_model = VLMInferenceWrapper(
        mcore_model,
        InferenceWrapperConfig(
            hidden_size=config.language_model_config.hidden_size,
            params_dtype=params_dtype,
            inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
            padded_vocab_size=tokenizer.vocab_size,
        ),
    )

    return inference_wrapped_model, processor


def generate(
    model: VLMInferenceWrapper,
    processor,
    prompts: list[str],
    images,
    max_batch_size: int = 4,
    random_seed: Optional[int] = None,
    inference_params: Optional[CommonInferenceParams] = None,
) -> list[InferenceRequest]:
    text_generation_controller = VLMTextGenerationController(inference_wrapped_model=model, processor=processor)
    mcore_engine = VLMEngine(
        text_generation_controller=text_generation_controller, max_batch_size=max_batch_size, random_seed=random_seed
    )

    common_inference_params = inference_params or CommonInferenceParams(num_tokens_to_generate=50)

    results = mcore_engine.generate(
        prompts=prompts,
        images=images,
        common_inference_params=common_inference_params,
    )

    return results
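
Taken together, setup_model_and_tokenizer and generate form a two-call API. A minimal usage sketch, not part of this commit — the checkpoint path, trainer configuration, and image file below are placeholder assumptions:

from PIL import Image

import nemo.lightning as nl
from nemo.collections.vlm.inference import setup_model_and_tokenizer, generate

# Placeholder trainer and checkpoint path; adjust to your environment.
strategy = nl.MegatronStrategy(tensor_model_parallel_size=1)
trainer = nl.Trainer(devices=1, accelerator="gpu", strategy=strategy)
model, processor = setup_model_and_tokenizer("/path/to/mllama-11b-checkpoint", trainer=trainer)

results = generate(
    model,
    processor,
    prompts=["<|image|>Describe this image."],
    images=[Image.open("example.jpg")],
)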
50 changes: 50 additions & 0 deletions nemo/collections/vlm/inference/vlm_engine.py
@@ -0,0 +1,50 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
from PIL.Image import Image

import torch

from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.engines.mcore_engine import MCoreEngine

class VLMEngine(MCoreEngine):
    def generate(
        self,
        prompts: List[str],
        images: Optional[List[Image]] = None,
        common_inference_params: Optional[CommonInferenceParams] = None,
    ) -> List[InferenceRequest]:
        # TODO: use the MCore RNG state tracker once it is exposed here
        if self.random_seed:
            torch.random.manual_seed(self.random_seed)

        for i in range(len(prompts)):
            prompt = prompts[i]
            image = images[i] if images is not None else None
            prompt_tokens, image_dict = self.text_generation_controller.tokenize_prompt(prompt, image)

            self.scheduler.add_request(
                prompt=prompt,
                prompt_tokens=prompt_tokens,
                encoder_prompt=image_dict,
                inference_parameters=common_inference_params,
            )

        self.run_engine()

        results: List[InferenceRequest] = list(self.scheduler.completed_request_pool.values())
        return results
64 changes: 64 additions & 0 deletions nemo/collections/vlm/inference/vlm_inference_controller.py
@@ -0,0 +1,64 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import OrderedDict

import torch

from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
SimpleTextGenerationController,
)

class TokenizerWrapper:
    def __init__(self, tokenizer):
        self.eod = tokenizer.eos_token_id
        self.vocab_size = None
        self._tokenizer = tokenizer

    def detokenize(self, tokens):
        return self._tokenizer.decode(tokens, skip_special_tokens=True)

class VLMTextGenerationController(SimpleTextGenerationController):
    def __init__(self, inference_wrapped_model, processor):
        super().__init__(inference_wrapped_model, TokenizerWrapper(processor.tokenizer))
        self.processor = processor

    def tokenize_prompt(self, prompt: str, image):
        num_tiles = (
            None
            if image is None
            else self.processor.image_processor.preprocess(image, return_tensors="pt", do_rescale=False)["num_tiles"]
        )
        batch = self.processor(image, prompt, add_special_tokens=False, return_tensors="pt")
        image_dict = dict(
            pixel_values=batch["pixel_values"].cuda(non_blocking=True) if "pixel_values" in batch else None,
            aspect_ratio_ids=batch["aspect_ratio_ids"].cuda(non_blocking=True) if "aspect_ratio_ids" in batch else None,
            num_tiles=num_tiles,
        )
        return batch["input_ids"].tolist()[0], image_dict

    def prep_model_for_inference(
        self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest]
    ):
        """Prepare the batch for inference, delegating to the wrapper's prep_model_for_inference method.

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
            active_requests (OrderedDict[int, InferenceRequest]): The active input requests
        """
        images = list(map(lambda request: request.encoder_prompt, active_requests.values()))

        self.inference_wrapped_model.prep_model_for_inference(
            prompts_tokens=prompts_tokens, image_dict=images
        )
88 changes: 88 additions & 0 deletions nemo/collections/vlm/inference/vlm_inference_wrapper.py
@@ -0,0 +1,88 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import Namespace
from typing import Dict, List, Optional

import torch

from megatron.core import tensor_parallel
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
AbstractModelInferenceWrapper,
)


class VLMInferenceWrapper(AbstractModelInferenceWrapper):
    """Inference wrapper for vision-language models.

    The wrapper prepares the model for inference, provides the required input
    data, and runs the forward pass.

    Args:
        model: The vision-language model (MCore or legacy)
        args (Namespace): The command line arguments that were passed
    """

    def __init__(self, model, args: Namespace):
        super().__init__(model, args)

    def prep_model_for_inference(
        self, prompts_tokens: torch.Tensor, image_dict: Optional[List[Dict]] = None
    ):
        super().prep_model_for_inference(prompts_tokens=prompts_tokens)
        self.pixel_values = image_dict[0]["pixel_values"]
        self.aspect_ratio_ids = image_dict[0]["aspect_ratio_ids"]
        self.num_tiles = image_dict[0]["num_tiles"]
        seq_length = prompts_tokens.size(1)
        self.position_ids = (
            torch.arange(seq_length, dtype=torch.long, device=prompts_tokens.device)
            .unsqueeze(0)
            .expand_as(prompts_tokens)
        )

    def get_batch_for_context_window(
        self, context_start_position: int, context_end_position: int
    ) -> List:
        tokens2use = self.prompts_tokens[:, :context_end_position]
        positions2use = self.position_ids[:, :context_end_position]
        data_at_step_idx = [tokens2use, positions2use]

        return data_at_step_idx

    def forward_pass_without_pipeline_parallel(self, inference_input: List) -> torch.Tensor:
        """Utility to carry out a simple forward pass for TP or non-model-parallel models.

        Runs a very simple forward pass for the model. Used in the case of models without
        any parallelism or with tensor parallelism only.

        Args:
            inference_input (List): A list containing the inputs for the model
                [tokens, position ids]

        Returns:
            torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
        """
        tokens2use, positions2use = inference_input
        logits = self.model(
            batch_images=self.pixel_values,
            batch_masks=[[[5, 512]]] if self.num_tiles is not None else None,  # hardcoded mask span when an image is present
            num_chunks=self.num_tiles,
            aspect_ratio_ids=self.aspect_ratio_ids,
            tokens=tokens2use,
            position_ids=positions2use,
        )
        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)

        return logits
