diff --git a/nemo/collections/vlm/inference/__init__.py b/nemo/collections/vlm/inference/__init__.py
new file mode 100644
index 000000000000..ebe251f41d77
--- /dev/null
+++ b/nemo/collections/vlm/inference/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo.collections.vlm.inference.base import setup_model_and_tokenizer, generate
diff --git a/nemo/collections/vlm/inference/base.py b/nemo/collections/vlm/inference/base.py
new file mode 100644
index 000000000000..46830601a5b6
--- /dev/null
+++ b/nemo/collections/vlm/inference/base.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import pytorch_lightning as pl
+import torch
+import torch.distributed
+from megatron.core.inference.common_inference_params import CommonInferenceParams
+from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
+
+import nemo.lightning as nl
+from nemo.collections import vlm
+from nemo.collections.vlm.inference.vlm_inference_wrapper import VLMInferenceWrapper
+from nemo.collections.vlm.inference.vlm_inference_controller import VLMTextGenerationController
+from nemo.collections.vlm.inference.vlm_engine import VLMEngine
+
+def _setup_trainer_and_restore_model(path: str, trainer: nl.Trainer, model: pl.LightningModule):
+    fabric = trainer.to_fabric()
+    model = fabric.load_model(path, model)
+    return model
+
+def setup_model_and_tokenizer(
+    path: str,
+    trainer: Optional[nl.Trainer] = None,
+    params_dtype: torch.dtype = torch.bfloat16,
+    inference_batch_times_seqlen_threshold: int = 1000,
+):
+    # model: io.TrainerContext = io.load_context(path=path, subpath="model")
+    # trainer = trainer or io.load_context(path=path, subpath="trainer")
+    model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+    from transformers import AutoProcessor
+
+    processor = AutoProcessor.from_pretrained(model_id)
+    tokenizer = processor.tokenizer
+    config = vlm.MLlamaConfig11BInstruct()
+    model = vlm.MLlamaModel(config, tokenizer=tokenizer)
+    _setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)
+
+    mcore_model = model.module.cuda()
+    mcore_model = mcore_model.to(params_dtype)
+    inference_wrapped_model = VLMInferenceWrapper(
+        mcore_model,
+        InferenceWrapperConfig(
+            hidden_size=config.language_model_config.hidden_size,
+            params_dtype=params_dtype,
+            inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
+            padded_vocab_size=tokenizer.vocab_size,
+        ),
+    )
+
+    return inference_wrapped_model, processor
+
+
+def generate(
+    model: VLMInferenceWrapper,
+    processor,
+    prompts: list[str],
+    images,
+    max_batch_size: int = 4,
+    random_seed: Optional[int] = None,
+    inference_params: Optional[CommonInferenceParams] = None,
+) -> list:
+    text_generation_controller = VLMTextGenerationController(inference_wrapped_model=model, processor=processor)
+    mcore_engine = VLMEngine(
+        text_generation_controller=text_generation_controller, max_batch_size=max_batch_size, random_seed=random_seed
+    )
+
+    common_inference_params = inference_params or CommonInferenceParams(num_tokens_to_generate=50)
+
+    results = mcore_engine.generate(
+        prompts=prompts,
+        images=images,
+        common_inference_params=common_inference_params,
+    )
+
+    return results
diff --git a/nemo/collections/vlm/inference/vlm_engine.py b/nemo/collections/vlm/inference/vlm_engine.py
new file mode 100644
index 000000000000..0b298b6cc234
--- /dev/null
+++ b/nemo/collections/vlm/inference/vlm_engine.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+from PIL.Image import Image
+
+import torch
+
+from megatron.core.inference.common_inference_params import CommonInferenceParams
+from megatron.core.inference.inference_request import InferenceRequest
+from megatron.core.inference.engines.mcore_engine import MCoreEngine
+
+class VLMEngine(MCoreEngine):
+    def generate(
+        self,
+        prompts: List[str],
+        images: List[Image] = None,
+        common_inference_params: CommonInferenceParams = None,
+    ) -> List[InferenceRequest]:
+        # TODO: MCore - get rng state tracker
+        if self.random_seed:
+            torch.random.manual_seed(self.random_seed)
+
+        for i in range(len(prompts)):
+            prompt = prompts[i]
+            image = images[i] if images is not None else None
+            prompt_tokens, image_dict = self.text_generation_controller.tokenize_prompt(prompt, image)
+
+            self.scheduler.add_request(
+                prompt=prompt,
+                prompt_tokens=prompt_tokens,
+                encoder_prompt=image_dict,
+                inference_parameters=common_inference_params,
+            )
+
+        self.run_engine()
+
+        result: List[InferenceRequest] = list(self.scheduler.completed_request_pool.values())
+        return result
diff --git a/nemo/collections/vlm/inference/vlm_inference_controller.py b/nemo/collections/vlm/inference/vlm_inference_controller.py
new file mode 100644
index 000000000000..a851e501d03b
--- /dev/null
+++ b/nemo/collections/vlm/inference/vlm_inference_controller.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import OrderedDict
+
+import torch
+
+from megatron.core.inference.inference_request import InferenceRequest
+from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
+    SimpleTextGenerationController,
+)
+
+class TokenizerWrapper:
+    def __init__(self, tokenizer):
+        self.eod = tokenizer.eos_token_id
+        self.vocab_size = None
+        self._tokenizer = tokenizer
+
+    def detokenize(self, tokens):
+        return self._tokenizer.decode(tokens, skip_special_tokens=True)
+
+class VLMTextGenerationController(SimpleTextGenerationController):
+    def __init__(self, inference_wrapped_model, processor):
+        super().__init__(inference_wrapped_model, TokenizerWrapper(processor.tokenizer))
+        self.processor = processor
+
+    def tokenize_prompt(self, prompt: str, image):
+        num_tiles = None if image is None \
+            else self.processor.image_processor.preprocess(image, return_tensors='pt', do_rescale=False)["num_tiles"]
+        batch = self.processor(image, prompt, add_special_tokens=False, return_tensors="pt")
+        image_dict = dict(
+            pixel_values=batch["pixel_values"].cuda(non_blocking=True) if "pixel_values" in batch else None,
+            aspect_ratio_ids=batch["aspect_ratio_ids"].cuda(non_blocking=True) if "aspect_ratio_ids" in batch else None,
+            num_tiles=num_tiles,
+        )
+        return batch["input_ids"].tolist()[0], image_dict
+
+    def prep_model_for_inference(
+        self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest]
+    ):
+        """Preparing batch for inference, using respective wrapper's prep_model_for_inference method
+
+        Args:
+            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
+            active_requests (OrderedDict[int, InferenceRequest]): The input active requests
+        """
+        images = list(
+            map(lambda request: request.encoder_prompt, active_requests.values())
+        )
+
+        self.inference_wrapped_model.prep_model_for_inference(
+            prompts_tokens=prompts_tokens, image_dict=images
+        )
diff --git a/nemo/collections/vlm/inference/vlm_inference_wrapper.py b/nemo/collections/vlm/inference/vlm_inference_wrapper.py
new file mode 100644
index 000000000000..b9d4cb49f1fb
--- /dev/null
+++ b/nemo/collections/vlm/inference/vlm_inference_wrapper.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Dict
+
+import torch
+
+from megatron.core import tensor_parallel
+from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
+    AbstractModelInferenceWrapper,
+)
+from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
+
+
+class VLMInferenceWrapper(AbstractModelInferenceWrapper):
+    """Inference wrapper for vision-language models
+
+    The wrapper prepares the model for inference, provides the required input
+    data, and runs the forward pass.
+
+    Args:
+        model: The multimodal model (MCore), e.g. the MCore MLlama model
+        inference_wrapper_config (InferenceWrapperConfig): Config for the inference wrapper
+    """
+
+    def __init__(self, model, inference_wrapper_config: InferenceWrapperConfig):
+        super().__init__(model, inference_wrapper_config)
+
+    def prep_model_for_inference(
+        self, prompts_tokens: torch.Tensor, image_dict: List[Dict] = None
+    ):
+
+        super().prep_model_for_inference(prompts_tokens=prompts_tokens)
+        self.pixel_values = image_dict[0]["pixel_values"]
+        self.aspect_ratio_ids = image_dict[0]["aspect_ratio_ids"]
+        self.num_tiles = image_dict[0]["num_tiles"]
+        seq_length = prompts_tokens.size(1)
+        self.position_ids = (
+            torch.arange(seq_length, dtype=torch.long, device=prompts_tokens.device)
+            .unsqueeze(0)
+            .expand_as(prompts_tokens)
+        )
+
+    def get_batch_for_context_window(
+        self, context_start_position: int, context_end_position: int
+    ) -> List:
+        tokens2use = self.prompts_tokens[:, :context_end_position]
+        positions2use = self.position_ids[:, :context_end_position]
+        data_at_step_idx = [tokens2use, positions2use]
+
+        return data_at_step_idx
+
+    def forward_pass_without_pipeline_parallel(self, inference_input: List) -> torch.Tensor:
+        """Utility to carry out a simple forward pass for TP or no-model-parallel models
+
+        Runs a simple forward pass for the model. Used for models without
+        any parallelism or with tensor parallelism only.
+
+        Args:
+            inference_input (List): A list containing the model inputs
+                [tokens, position ids]
+
+        Returns:
+            torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
+        """
+        tokens2use, positions2use = inference_input
+        logits = self.model(
+            batch_images=self.pixel_values,
+            batch_masks=[[[5, 512]]] if self.num_tiles is not None else None,
+            num_chunks=self.num_tiles,
+            aspect_ratio_ids=self.aspect_ratio_ids,
+            tokens=tokens2use,
+            position_ids=positions2use,
+        )
+        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
+
+        return logits
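
For context, a minimal usage sketch of the API this PR adds (not part of the diff). The checkpoint path, image file, prompt string, and trainer settings below are placeholders/assumptions, and the exact trainer configuration needed will depend on the environment:

```python
# Illustrative only: paths, prompt, image, and trainer settings are assumptions,
# not part of this PR.
from PIL import Image

from megatron.core.inference.common_inference_params import CommonInferenceParams
import nemo.lightning as nl
from nemo.collections.vlm.inference import generate, setup_model_and_tokenizer

# A single-GPU Megatron trainer; settings are illustrative.
trainer = nl.Trainer(
    devices=1,
    accelerator="gpu",
    strategy=nl.MegatronStrategy(tensor_model_parallel_size=1),
)

# Load the MLlama checkpoint (hypothetical path) and the HF processor/tokenizer.
model, processor = setup_model_and_tokenizer(
    path="/path/to/mllama_checkpoint",
    trainer=trainer,
)

# Run generation on one prompt/image pair.
results = generate(
    model,
    processor,
    prompts=["<|image|>Describe this image."],
    images=[Image.open("example.jpg")],  # hypothetical image file
    inference_params=CommonInferenceParams(num_tokens_to_generate=50),
)
for request in results:
    print(request.generated_text)
```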