diff --git a/examples/inference/vlm/vlm_inference.py b/examples/inference/vlm/vlm_inference.py
index 64cd56188c..9878cd8fde 100644
--- a/examples/inference/vlm/vlm_inference.py
+++ b/examples/inference/vlm/vlm_inference.py
@@ -22,7 +22,7 @@
 from typing import Optional
 
 import torch
-from megatron.core.inference.common_inference_params import CommonInferenceParams
+from megatron.core.inference.sampling_params import SamplingParams
 from qwen_vl_utils import process_vision_info
 
 from megatron.bridge.inference.vlm.base import generate, setup_model_and_tokenizer
@@ -94,7 +94,7 @@ def main(args) -> None:
     text, image_inputs, video_inputs = process_image_inputs(processor, args.image_path, prompt)
 
     # Setup inference parameters
-    inference_params = CommonInferenceParams(
+    inference_params = SamplingParams(
         temperature=args.temperature,
         top_p=args.top_p,
         top_k=args.top_k,
@@ -109,9 +109,8 @@ def main(args) -> None:
         prompts=[text],
         images=[image_inputs] if image_inputs is not None else None,
         processor=processor,
-        max_batch_size=1,
         random_seed=0,
-        inference_params=inference_params,
+        sampling_params=inference_params,
     )
 
     # Print results
diff --git a/src/megatron/bridge/inference/vlm/base.py b/src/megatron/bridge/inference/vlm/base.py
index 5efb24d746..f99489d3f7 100644
--- a/src/megatron/bridge/inference/vlm/base.py
+++ b/src/megatron/bridge/inference/vlm/base.py
@@ -16,6 +16,7 @@
 import torch
 import torch.distributed
 
+from megatron.core.inference.contexts import StaticInferenceContext
 from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
     AbstractModelInferenceWrapper,
 )
@@ -40,6 +41,7 @@ def setup_model_and_tokenizer(
     params_dtype: torch.dtype = torch.bfloat16,
     inference_batch_times_seqlen_threshold: int = 1000,
     inference_max_seq_length: int = 8192,
+    inference_max_batch_size: int = 4,
 ):
     """Set up model and tokenizer from a Megatron checkpoint.
 
@@ -50,7 +52,7 @@ def setup_model_and_tokenizer(
         params_dtype: Data type for model parameters.
         inference_batch_times_seqlen_threshold: Threshold for inference batching.
         inference_max_seq_length: Maximum sequence length for inference (prompt + generated tokens).
-
+        inference_max_batch_size: Maximum batch size for inference.
     Returns:
         A tuple of (inference_wrapped_model, processor).
     """
@@ -106,6 +108,7 @@ def setup_model_and_tokenizer(
         params_dtype=torch.bfloat16,
         inference_batch_times_seqlen_threshold=1000,
         inference_max_seq_length=inference_max_seq_length,
+        inference_max_batch_size=inference_max_batch_size,
     )
 
     return inference_wrapped_model, processor
@@ -120,6 +123,7 @@ def _expose_decoder_from_language_model(model):
     if hasattr(current, "language_model"):
         language_model = current.language_model
         current.decoder = language_model.decoder
+        current.vocab_size = language_model.vocab_size
 
 
 def setup_inference_wrapper(
@@ -128,6 +132,7 @@ def setup_inference_wrapper(
     params_dtype: torch.dtype = torch.bfloat16,
     inference_batch_times_seqlen_threshold: int = 1000,
     inference_max_seq_length: int = 8192,
+    inference_max_batch_size: int = 4,
 ):
     """Set up inference wrapper for the model"""
     config = model.config
@@ -148,7 +153,13 @@ def setup_inference_wrapper(
     else:
         raise ValueError(f"Unknown model config: {config}")
 
-    inference_wrapped_model = wrapper_cls(mcore_model)
+    inference_wrapped_model = wrapper_cls(
+        mcore_model,
+        inference_context=StaticInferenceContext(
+            max_batch_size=inference_max_batch_size,
+            max_sequence_length=inference_max_seq_length,
+        ),
+    )
 
     return inference_wrapped_model
@@ -160,7 +171,6 @@ def generate(
     prompts: List[str],
     images: List[Union[Image, List[Image]]],
     processor=None,
-    max_batch_size: int = 4,
     random_seed: Optional[int] = None,
     sampling_params: Optional[SamplingParams] = None,
 ) -> dict:
@@ -172,7 +182,6 @@ def generate(
         image_processor: image processor for the input image,
         prompts (list[str]): The list of prompts to generate text for.
         images (list): The list of images to generate text for.
-        max_batch_size (int, optional): The maximum batch size. Defaults to 4.
         random_seed (Optional[int], optional): The random seed. Defaults to None.
         sampling_params (Optional["SamplingParams"], optional): The sampling parameters defined in
             Mcore's SamplingParams. Defaults to None.
@@ -195,9 +204,7 @@ def generate(
         tokenizer=tokenizer,
         image_processor=image_processor,
     )
-    mcore_engine = VLMEngine(
-        text_generation_controller=text_generation_controller, max_batch_size=max_batch_size, random_seed=random_seed
-    )
+    mcore_engine = VLMEngine(text_generation_controller=text_generation_controller, random_seed=random_seed)
 
     if sampling_params is None:
         sampling_params = SamplingParams(num_tokens_to_generate=50)
diff --git a/src/megatron/bridge/inference/vlm/qwenvl_inference_wrapper.py b/src/megatron/bridge/inference/vlm/qwenvl_inference_wrapper.py
index 7b1e128fbc..277881ea3b 100644
--- a/src/megatron/bridge/inference/vlm/qwenvl_inference_wrapper.py
+++ b/src/megatron/bridge/inference/vlm/qwenvl_inference_wrapper.py
@@ -31,8 +31,8 @@ class QwenVLInferenceWrapper(AbstractModelInferenceWrapper):
         model (Qwen2VLModel): The Qwen2VL model
     """
 
-    def __init__(self, model):
-        super().__init__(model)
+    def __init__(self, model, inference_context=None):
+        super().__init__(model, inference_context=inference_context)
 
     def prep_inference_input(
         self,
diff --git a/src/megatron/bridge/inference/vlm/vlm_engine.py b/src/megatron/bridge/inference/vlm/vlm_engine.py
index 0cd172322a..12f7862333 100644
--- a/src/megatron/bridge/inference/vlm/vlm_engine.py
+++ b/src/megatron/bridge/inference/vlm/vlm_engine.py
@@ -18,26 +18,12 @@
 from megatron.core.inference.engines.mcore_engine import MCoreEngine
 from megatron.core.inference.inference_request import InferenceRequest
 from megatron.core.inference.sampling_params import SamplingParams
-from megatron.core.inference.scheduler import Scheduler
-from megatron.core.inference.text_generation_controllers.text_generation_controller import TextGenerationController
 from PIL.Image import Image
 
 
 class VLMEngine(MCoreEngine):
     """VLM inference engine extending MCoreEngine with image support."""
 
-    def __init__(
-        self,
-        text_generation_controller: TextGenerationController,
-        max_batch_size: Optional[int] = None,
-        random_seed: Optional[int] = None,
-    ):
-        self.controller = text_generation_controller
-        self.inference_wrapped_model = self.controller.inference_wrapped_model
-        self.config = self.inference_wrapped_model.config
-        self.random_seed = random_seed or 1234
-        self.scheduler = Scheduler(max_batch_size=max_batch_size)
-
     # pylint: disable=C0115,C0116
     def generate(
         self,
diff --git a/tests/unit_tests/inference/vlm/test_base.py b/tests/unit_tests/inference/vlm/test_base.py
index ace72e715a..99fc4b7d3a 100644
--- a/tests/unit_tests/inference/vlm/test_base.py
+++ b/tests/unit_tests/inference/vlm/test_base.py
@@ -247,6 +247,7 @@ class MockObject:
         # Build the nested structure: model.module.language_model.decoder
         mock_language_model = MockObject()
         mock_language_model.decoder = mock_decoder
+        mock_language_model.vocab_size = 151936
 
         mock_module = MockObject()
         mock_module.language_model = mock_language_model
diff --git a/tests/unit_tests/inference/vlm/test_vlm_engine.py b/tests/unit_tests/inference/vlm/test_vlm_engine.py
index 0203bbff4b..d2c96d3c39 100644
--- a/tests/unit_tests/inference/vlm/test_vlm_engine.py
+++ b/tests/unit_tests/inference/vlm/test_vlm_engine.py
@@ -14,6 +14,8 @@
 
 from unittest.mock import MagicMock
 
+from megatron.core.inference.contexts import StaticInferenceContext
+
 from megatron.bridge.inference.vlm.vlm_engine import VLMEngine
 
 
@@ -21,8 +23,12 @@ class TestVLMEngine:
     def test_generate(self):
         mock_controller = MagicMock()
         mock_controller.tokenize_prompt.return_value = ([1, 2, 3], "image_dict")
-        # Fix for TypeError: '>' not supported between instances of 'int' and 'MagicMock'
-        mock_controller.inference_wrapped_model.context.max_batch_size = 128
+        # MCoreEngine/StaticInferenceEngine expects inference_context to be a StaticInferenceContext
+        # (and uses inference_wrapper_config.inference_max_requests for scheduler batch size).
+        mock_controller.inference_wrapped_model.inference_context = StaticInferenceContext(
+            max_batch_size=128, max_sequence_length=8192
+        )
+        mock_controller.inference_wrapped_model.inference_wrapper_config = MagicMock(inference_max_requests=128)
 
         engine = VLMEngine(mock_controller, max_batch_size=4)
         engine.scheduler = MagicMock()
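
Reviewer note: a minimal sketch of how the updated API fits together after this change. The keyword arguments match the hunks above; everything else (the checkpoint path, the leading positional arguments of `setup_model_and_tokenizer` and `generate`, and the `processor.tokenizer` / `processor.image_processor` attributes) is an assumption, since it is not visible in this diff.

```python
# Sketch only: positional arguments and paths are assumptions, not confirmed
# by this diff; keyword arguments match the changed signatures above.
from megatron.core.inference.sampling_params import SamplingParams

from megatron.bridge.inference.vlm.base import generate, setup_model_and_tokenizer

# inference_max_batch_size now flows into the StaticInferenceContext built by
# setup_inference_wrapper, replacing the removed engine-level max_batch_size.
inference_wrapped_model, processor = setup_model_and_tokenizer(
    "/path/to/megatron/checkpoint",  # assumed: checkpoint path positional arg
    inference_max_seq_length=8192,
    inference_max_batch_size=4,
)

# CommonInferenceParams is replaced by Mcore's SamplingParams.
sampling_params = SamplingParams(
    temperature=1.0,
    top_p=0.95,
    top_k=0,
    num_tokens_to_generate=50,
)

results = generate(
    inference_wrapped_model,       # assumed positional args, per the docstring
    processor.tokenizer,           # assumed: HF processor exposes .tokenizer
    processor.image_processor,     # assumed: HF processor exposes .image_processor
    prompts=["Describe this image."],
    images=None,                   # or a list of PIL Images
    processor=processor,
    random_seed=0,
    sampling_params=sampling_params,  # renamed from inference_params
)
```

With batching now bounded by the wrapped model's inference context rather than a `generate()` argument, callers that previously tuned `max_batch_size` per call should set `inference_max_batch_size` once at setup time.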