Commit 4711c75: Add vlm generation

1 parent dbedae0

5 files changed, +305 -0 lines
File 1 of 5: @@ -0,0 +1,15 @@

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.vlm.inference.base import setup_model_and_tokenizer, generate
```
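Since this file only re-exports the two public entry points, a caller can drive the whole pipeline with a few lines. A minimal usage sketch; the checkpoint path and image file are hypothetical, and the single-GPU trainer configuration is an assumption (this commit does not show one):

```python
from PIL import Image

import nemo.lightning as nl
from nemo.collections.vlm.inference import setup_model_and_tokenizer, generate

# Single-GPU Megatron trainer; parallelism settings here are assumptions.
trainer = nl.Trainer(
    devices=1,
    accelerator="gpu",
    strategy=nl.MegatronStrategy(tensor_model_parallel_size=1),
)

# Both the checkpoint path and the image file are hypothetical placeholders.
model, processor = setup_model_and_tokenizer(path="/path/to/mllama_checkpoint", trainer=trainer)
image = Image.open("example.jpg")
results = generate(model, processor, prompts=["<|image|>Describe this image."], images=[image])
```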
File 2 of 5: @@ -0,0 +1,88 @@

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import pytorch_lightning as pl
import torch
import torch.distributed
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig

import nemo.lightning as nl
from nemo.collections import vlm
from nemo.collections.vlm.inference.vlm_inference_wrapper import VLMInferenceWrapper
from nemo.collections.vlm.inference.vlm_inference_controller import VLMTextGenerationController
from nemo.collections.vlm.inference.vlm_engine import VLMEngine


def _setup_trainer_and_restore_model(path: str, trainer: nl.Trainer, model: pl.LightningModule):
    # Restore the checkpoint weights into the model via the trainer's Fabric handle.
    fabric = trainer.to_fabric()
    model = fabric.load_model(path, model)
    return model


def setup_model_and_tokenizer(
    path: str,
    trainer: Optional[nl.Trainer] = None,
    params_dtype: torch.dtype = torch.bfloat16,
    inference_batch_times_seqlen_threshold: int = 1000,
):
    # model: io.TrainerContext = io.load_context(path=path, subpath="model")
    # trainer = trainer or io.load_context(path=path, subpath="trainer")
    model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
    from transformers import AutoProcessor

    # Build the HF processor/tokenizer and an MLlama model, then restore weights from `path`.
    processor = AutoProcessor.from_pretrained(model_id)
    tokenizer = processor.tokenizer
    config = vlm.MLlamaConfig11BInstruct()
    model = vlm.MLlamaModel(config, tokenizer=tokenizer)
    _setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)

    mcore_model = model.module.cuda()
    mcore_model = mcore_model.to(params_dtype)
    inference_wrapped_model = VLMInferenceWrapper(
        mcore_model,
        InferenceWrapperConfig(
            hidden_size=config.language_model_config.hidden_size,
            params_dtype=params_dtype,
            inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
            padded_vocab_size=tokenizer.vocab_size,
        ),
    )

    return inference_wrapped_model, processor


def generate(
    model: VLMInferenceWrapper,
    processor,
    prompts: list[str],
    images,
    max_batch_size: int = 4,
    random_seed: Optional[int] = None,
    inference_params: Optional[CommonInferenceParams] = None,
) -> list:
    text_generation_controller = VLMTextGenerationController(inference_wrapped_model=model, processor=processor)
    mcore_engine = VLMEngine(
        text_generation_controller=text_generation_controller, max_batch_size=max_batch_size, random_seed=random_seed
    )

    common_inference_params = inference_params or CommonInferenceParams(num_tokens_to_generate=50)

    results = mcore_engine.generate(
        prompts=prompts,
        images=images,
        common_inference_params=common_inference_params,
    )

    return results
```
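The default only caps generation at 50 new tokens; callers can pass their own `CommonInferenceParams` to control sampling. A short sketch, reusing `model`, `processor`, and `image` from the earlier sketch and assuming the standard MCore dataclass fields (`temperature`, `top_p`, `num_tokens_to_generate`):

```python
from megatron.core.inference.common_inference_params import CommonInferenceParams

# Nucleus sampling at a lower temperature with a longer generation budget.
# MCore typically rejects setting top_k and top_p together, so only top_p is used here.
params = CommonInferenceParams(temperature=0.7, top_p=0.9, num_tokens_to_generate=128)
results = generate(
    model, processor,
    prompts=["<|image|>Describe this image."],
    images=[image],
    inference_params=params,
)
```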
File 3 of 5: @@ -0,0 +1,50 @@

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from PIL.Image import Image

import torch

from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.engines.mcore_engine import MCoreEngine


class VLMEngine(MCoreEngine):
    def generate(
        self,
        prompts: List[str],
        images: Optional[List[Image]] = None,
        common_inference_params: Optional[CommonInferenceParams] = None,
    ) -> List[InferenceRequest]:
        # TODO: MCore - get rng state tracker
        if self.random_seed:
            torch.random.manual_seed(self.random_seed)

        for i in range(len(prompts)):
            prompt = prompts[i]
            image = images[i] if images is not None else None
            prompt_tokens, image_dict = self.text_generation_controller.tokenize_prompt(prompt, image)

            # The preprocessed image tensors ride along as the request's encoder prompt.
            self.scheduler.add_request(
                prompt=prompt,
                prompt_tokens=prompt_tokens,
                encoder_prompt=image_dict,
                inference_parameters=common_inference_params,
            )

        self.run_engine()

        # completed_request_pool is a dict; materialize the values view into a list.
        result: List[InferenceRequest] = list(self.scheduler.completed_request_pool.values())
        return result
```
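The completed pool maps request ids to `InferenceRequest` objects, so the caller reads decoded output off each request. A sketch of consuming the return value; the `generated_text` field name is an assumption about this MCore version's `InferenceRequest` dataclass:

```python
# `results` as returned by VLMEngine.generate (or the top-level generate wrapper).
for request in results:
    # prompt and generated_text are assumed InferenceRequest fields.
    print(f"prompt: {request.prompt!r}")
    print(f"output: {request.generated_text!r}")
```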
File 4 of 5: @@ -0,0 +1,64 @@

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import OrderedDict

import torch

from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
    SimpleTextGenerationController,
)


class TokenizerWrapper:
    """Minimal tokenizer facade exposing the attributes MCore controllers expect."""

    def __init__(self, tokenizer):
        self.eod = tokenizer.eos_token_id
        self.vocab_size = None
        self._tokenizer = tokenizer

    def detokenize(self, tokens):
        return self._tokenizer.decode(tokens, skip_special_tokens=True)


class VLMTextGenerationController(SimpleTextGenerationController):
    def __init__(self, inference_wrapped_model, processor):
        super().__init__(inference_wrapped_model, TokenizerWrapper(processor.tokenizer))
        self.processor = processor

    def tokenize_prompt(self, prompt: str, image):
        # num_tiles comes from a separate image_processor pass; the main processor
        # call below produces the token ids and the vision tensors.
        num_tiles = (
            None
            if image is None
            else self.processor.image_processor.preprocess(image, return_tensors='pt', do_rescale=False)["num_tiles"]
        )
        batch = self.processor(image, prompt, add_special_tokens=False, return_tensors="pt")
        image_dict = dict(
            pixel_values=batch["pixel_values"].cuda(non_blocking=True) if "pixel_values" in batch else None,
            aspect_ratio_ids=batch["aspect_ratio_ids"].cuda(non_blocking=True) if "aspect_ratio_ids" in batch else None,
            num_tiles=num_tiles,
        )
        return batch["input_ids"].tolist()[0], image_dict

    def prep_model_for_inference(
        self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest]
    ):
        """Prepare the batch for inference, using the wrapper's prep_model_for_inference method.

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
            active_requests (OrderedDict[int, InferenceRequest]): The active input requests
        """
        images = [request.encoder_prompt for request in active_requests.values()]

        self.inference_wrapped_model.prep_model_for_inference(
            prompts_tokens=prompts_tokens, image_dict=images
        )
```
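To see what `tokenize_prompt` hands to the scheduler, it helps to run the underlying HF processor directly. A sketch assuming the Llama 3.2 Vision processor and a hypothetical local image; the tensor shapes in the comments are illustrative, not guaranteed:

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
image = Image.open("example.jpg")  # hypothetical image file
prompt = "<|image|>Describe this image."

batch = processor(image, prompt, add_special_tokens=False, return_tensors="pt")
print(batch["input_ids"].shape)      # token ids for the prompt, e.g. torch.Size([1, prompt_len])
print(batch["pixel_values"].shape)   # e.g. [1, 1, 4, 3, 560, 560]: batch, image, tile, C, H, W
print(batch["aspect_ratio_ids"])     # tile-layout id consumed by the vision encoder
```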
File 5 of 5: @@ -0,0 +1,88 @@

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import Namespace
from typing import List, Dict

import torch

from megatron.core import tensor_parallel
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
    AbstractModelInferenceWrapper,
)


class VLMInferenceWrapper(AbstractModelInferenceWrapper):
    """Inference wrapper for vision-language models.

    The wrapper prepares the model for inference, provides the required input
    data, and runs the forward pass.

    Args:
        model: The vision-language model (e.g. MCore MLlamaModel)
        args (Namespace): The command line arguments that were passed
    """

    def __init__(self, model, args: Namespace):
        super().__init__(model, args)

    def prep_model_for_inference(
        self, prompts_tokens: torch.Tensor, image_dict: List[Dict] = None
    ):
        super().prep_model_for_inference(prompts_tokens=prompts_tokens)
        # Only the first request's image tensors are used, which effectively
        # assumes a batch size of 1.
        self.pixel_values = image_dict[0]["pixel_values"]
        self.aspect_ratio_ids = image_dict[0]["aspect_ratio_ids"]
        self.num_tiles = image_dict[0]["num_tiles"]
        seq_length = prompts_tokens.size(1)
        self.position_ids = (
            torch.arange(seq_length, dtype=torch.long, device=prompts_tokens.device)
            .unsqueeze(0)
            .expand_as(prompts_tokens)
        )

    def get_batch_for_context_window(
        self, context_start_position: int, context_end_position: int
    ) -> List:
        tokens2use = self.prompts_tokens[:, :context_end_position]
        positions2use = self.position_ids[:, :context_end_position]
        data_at_step_idx = [tokens2use, positions2use]

        return data_at_step_idx

    def forward_pass_without_pipeline_parallel(self, inference_input: List) -> torch.Tensor:
        """Utility to carry out a simple forward pass for TP or non-model-parallel models.

        Runs a very simple forward pass for the model. Used in the case of models
        without any parallelism or with tensor parallelism only.

        Args:
            inference_input (List): A list containing the inputs for the model
                [tokens, position ids]

        Returns:
            torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
        """
        tokens2use, positions2use = inference_input
        logits = self.model(
            batch_images=self.pixel_values,
            # Hardcoded cross-attention mask: a single image with token span [5, 512].
            batch_masks=[[[5, 512]]] if self.num_tiles is not None else None,
            num_chunks=self.num_tiles,
            aspect_ratio_ids=self.aspect_ratio_ids,
            tokens=tokens2use,
            position_ids=positions2use,
        )
        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)

        return logits
```
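Note that `get_batch_for_context_window` slices from position 0 rather than from `context_start_position`, which suggests the full prefix is replayed at every decode step. A toy illustration of that slicing with plain tensors:

```python
import torch

prompts_tokens = torch.tensor([[101, 7592, 2088, 102]])  # toy token ids
seq_length = prompts_tokens.size(1)
position_ids = (
    torch.arange(seq_length, dtype=torch.long)
    .unsqueeze(0)
    .expand_as(prompts_tokens)
)

# At a decode step with context_end_position == 3, the whole prefix is used:
tokens2use = prompts_tokens[:, :3]   # tensor([[ 101, 7592, 2088]])
positions2use = position_ids[:, :3]  # tensor([[0, 1, 2]])
print(tokens2use, positions2use)
```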
