-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
dbedae0
commit 4711c75
Showing
5 changed files
with
305 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from nemo.collections.vlm.inference.base import setup_model_and_tokenizer, generate |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Optional | ||
|
||
import pytorch_lightning as pl | ||
import torch | ||
import torch.distributed | ||
from megatron.core.inference.common_inference_params import CommonInferenceParams | ||
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig | ||
|
||
import nemo.lightning as nl | ||
from nemo.collections import vlm | ||
from nemo.collections.vlm.inference.vlm_inference_wrapper import VLMInferenceWrapper | ||
from nemo.collections.vlm.inference.vlm_inference_controller import VLMTextGenerationController | ||
from nemo.collections.vlm.inference.vlm_engine import VLMEngine | ||
|
||
def _setup_trainer_and_restore_model(path: str, trainer: nl.Trainer, model: pl.LightningModule): | ||
fabric = trainer.to_fabric() | ||
model = fabric.load_model(path, model) | ||
return model | ||
|
||
def setup_model_and_tokenizer( | ||
path: str, | ||
trainer: Optional[nl.Trainer] = None, | ||
params_dtype: torch.dtype = torch.bfloat16, | ||
inference_batch_times_seqlen_threshold: int = 1000, | ||
): | ||
# model: io.TrainerContext = io.load_context(path=path, subpath="model") | ||
# trainer = trainer or io.load_context(path=path, subpath="trainer") | ||
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" | ||
from transformers import AutoProcessor | ||
|
||
processor = AutoProcessor.from_pretrained(model_id) | ||
tokenizer = processor.tokenizer | ||
config = vlm.MLlamaConfig11BInstruct() | ||
model = vlm.MLlamaModel(config, tokenizer=tokenizer) | ||
_setup_trainer_and_restore_model(path=path, trainer=trainer, model=model) | ||
|
||
mcore_model = model.module.cuda() | ||
mcore_model = mcore_model.to(params_dtype) | ||
inference_wrapped_model = VLMInferenceWrapper( | ||
mcore_model, | ||
InferenceWrapperConfig( | ||
hidden_size=config.language_model_config.hidden_size, | ||
params_dtype=params_dtype, | ||
inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold, | ||
padded_vocab_size=tokenizer.vocab_size, | ||
), | ||
) | ||
|
||
return inference_wrapped_model, processor | ||
|
||
|
||
def generate( | ||
model: VLMInferenceWrapper, | ||
processor, | ||
prompts: list[str], | ||
images, | ||
max_batch_size: int = 4, | ||
random_seed: Optional[int] = None, | ||
inference_params: Optional[CommonInferenceParams] = None, | ||
) -> dict: | ||
text_generation_controller = VLMTextGenerationController(inference_wrapped_model=model, processor=processor) | ||
mcore_engine = VLMEngine( | ||
text_generation_controller=text_generation_controller, max_batch_size=max_batch_size, random_seed=random_seed | ||
) | ||
|
||
common_inference_params = inference_params or CommonInferenceParams(num_tokens_to_generate=50) | ||
|
||
results = mcore_engine.generate( | ||
prompts=prompts, | ||
images=images, | ||
common_inference_params=common_inference_params, | ||
) | ||
|
||
return results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import List | ||
from PIL.Image import Image | ||
|
||
import torch | ||
|
||
from megatron.core.inference.common_inference_params import CommonInferenceParams | ||
from megatron.core.inference.inference_request import InferenceRequest | ||
from megatron.core.inference.engines.mcore_engine import MCoreEngine | ||
|
||
class VLMEngine(MCoreEngine): | ||
def generate( | ||
self, | ||
prompts: List[str], | ||
images: List[Image] = None, | ||
common_inference_params: CommonInferenceParams = None, | ||
) -> dict: | ||
# TODO :M core- get rng state tracker | ||
if self.random_seed: | ||
torch.random.manual_seed(self.random_seed) | ||
|
||
for i in range(len(prompts)): | ||
prompt = prompts[i] | ||
image = images[i] if images is not None else None | ||
prompt_tokens, image_dict = self.text_generation_controller.tokenize_prompt(prompt, image) | ||
|
||
self.scheduler.add_request( | ||
prompt=prompt, | ||
prompt_tokens=prompt_tokens, | ||
encoder_prompt=image_dict, | ||
inference_parameters=common_inference_params, | ||
) | ||
|
||
self.run_engine() | ||
|
||
result: List[InferenceRequest] = self.scheduler.completed_request_pool.values() | ||
return result |
64 changes: 64 additions & 0 deletions
64
nemo/collections/vlm/inference/vlm_inference_controller.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import OrderedDict | ||
|
||
import torch | ||
|
||
from megatron.core.inference.inference_request import InferenceRequest | ||
from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import ( | ||
SimpleTextGenerationController, | ||
) | ||
|
||
class TokenizerWrapper: | ||
def __init__(self, tokenizer): | ||
self.eod = tokenizer.eos_token_id | ||
self.vocab_size = None | ||
self._tokenizer = tokenizer | ||
|
||
def detokenize(self, tokens): | ||
return self._tokenizer.decode(tokens, skip_special_tokens=True) | ||
|
||
class VLMTextGenerationController(SimpleTextGenerationController): | ||
def __init__(self, inference_wrapped_model, processor): | ||
super().__init__(inference_wrapped_model, TokenizerWrapper(processor.tokenizer)) | ||
self.processor = processor | ||
|
||
def tokenize_prompt(self, prompt: str, image): | ||
num_tiles = None if image is None \ | ||
else self.processor.image_processor.preprocess(image, return_tensors='pt', do_rescale=False)["num_tiles"] | ||
batch = self.processor(image, prompt, add_special_tokens=False, return_tensors="pt") | ||
image_dict = dict( | ||
pixel_values=batch["pixel_values"].cuda(non_blocking=True) if "pixel_values" in batch else None, | ||
aspect_ratio_ids=batch["aspect_ratio_ids"].cuda(non_blocking=True) if "aspect_ratio_ids" in batch else None, | ||
num_tiles=num_tiles, | ||
) | ||
return batch["input_ids"].tolist()[0], image_dict | ||
|
||
def prep_model_for_inference( | ||
self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest] | ||
): | ||
"""Preparing batch for inference, using respective wrapper's prep_model_for_inference method | ||
Args: | ||
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length] | ||
active_requests (OrderedDict[int, InferenceRequest]): The input active requests | ||
""" | ||
images = list( | ||
map(lambda request: request.encoder_prompt, active_requests.values()) | ||
) | ||
|
||
self.inference_wrapped_model.prep_model_for_inference( | ||
prompts_tokens=prompts_tokens, image_dict=images | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from argparse import Namespace | ||
from typing import List, Dict | ||
|
||
import torch | ||
|
||
from megatron.core import tensor_parallel | ||
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( | ||
AbstractModelInferenceWrapper, | ||
) | ||
|
||
|
||
class VLMInferenceWrapper(AbstractModelInferenceWrapper): | ||
"""Constructor for the model inference wrapper | ||
The wrapper prepares the model for inference, provides the required input | ||
data, and runs the forward pass | ||
Args: | ||
model (T5Model): The T5 model (MCore or legacy) | ||
args (Namespace): The command line arguments that were passed | ||
""" | ||
|
||
def __init__(self, model, args: Namespace): | ||
super().__init__(model, args) | ||
|
||
def prep_model_for_inference( | ||
self, prompts_tokens: torch.Tensor, image_dict: List[Dict] = None | ||
): | ||
|
||
super().prep_model_for_inference(prompts_tokens=prompts_tokens) | ||
self.pixel_values = image_dict[0]["pixel_values"] | ||
self.aspect_ratio_ids = image_dict[0]["aspect_ratio_ids"] | ||
self.num_tiles = image_dict[0]["num_tiles"] | ||
seq_length = prompts_tokens.size(1) | ||
self.position_ids = ( | ||
torch.arange(seq_length, dtype=torch.long, device=prompts_tokens.device) | ||
.unsqueeze(0) | ||
.expand_as(prompts_tokens) | ||
) | ||
|
||
def get_batch_for_context_window( | ||
self, context_start_position: int, context_end_position: int | ||
) -> List: | ||
tokens2use = self.prompts_tokens[:, :context_end_position] | ||
positions2use = self.position_ids[:, :context_end_position] | ||
data_at_step_idx = [tokens2use, positions2use] | ||
|
||
return data_at_step_idx | ||
|
||
def forward_pass_without_pipeline_parallel(self, inference_input: List) -> torch.Tensor: | ||
"""Utility to carry out simple forward pass for TP or no model parallel models | ||
Runs a very simple forward pass for model. Used in the case of models without | ||
any parallelism or only tensor parallelism. | ||
Args: | ||
inference_input (List): A list containg the inputs for the gpt | ||
model [tokens, position ids, attention mask] | ||
Returns: | ||
torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size] | ||
""" | ||
tokens2use, positions2use = inference_input | ||
logits = self.model( | ||
batch_images=self.pixel_values, | ||
batch_masks=[[[5, 512]]] if self.num_tiles is not None else None, | ||
num_chunks=self.num_tiles, | ||
aspect_ratio_ids=self.aspect_ratio_ids, | ||
tokens=tokens2use, | ||
position_ids=positions2use, | ||
) | ||
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) | ||
|
||
return logits |