Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions examples/inference/vlm/vlm_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import Optional

import torch
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.sampling_params import SamplingParams
from qwen_vl_utils import process_vision_info

from megatron.bridge.inference.vlm.base import generate, setup_model_and_tokenizer
Expand Down Expand Up @@ -94,7 +94,7 @@ def main(args) -> None:
text, image_inputs, video_inputs = process_image_inputs(processor, args.image_path, prompt)

# Setup inference parameters
inference_params = CommonInferenceParams(
inference_params = SamplingParams(
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
Expand All @@ -109,9 +109,8 @@ def main(args) -> None:
prompts=[text],
images=[image_inputs] if image_inputs is not None else None,
processor=processor,
max_batch_size=1,
random_seed=0,
inference_params=inference_params,
sampling_params=inference_params,
)

# Print results
Expand Down
21 changes: 14 additions & 7 deletions src/megatron/bridge/inference/vlm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import torch
import torch.distributed
from megatron.core.inference.contexts import StaticInferenceContext
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
AbstractModelInferenceWrapper,
)
Expand All @@ -40,6 +41,7 @@ def setup_model_and_tokenizer(
params_dtype: torch.dtype = torch.bfloat16,
inference_batch_times_seqlen_threshold: int = 1000,
inference_max_seq_length: int = 8192,
inference_max_batch_size: int = 4,
):
"""Set up model and tokenizer from a Megatron checkpoint.

Expand All @@ -50,7 +52,7 @@ def setup_model_and_tokenizer(
params_dtype: Data type for model parameters.
inference_batch_times_seqlen_threshold: Threshold for inference batching.
inference_max_seq_length: Maximum sequence length for inference (prompt + generated tokens).

inference_max_batch_size: Maximum batch size for inference.
Returns:
A tuple of (inference_wrapped_model, processor).
"""
Expand Down Expand Up @@ -106,6 +108,7 @@ def setup_model_and_tokenizer(
params_dtype=torch.bfloat16,
inference_batch_times_seqlen_threshold=1000,
inference_max_seq_length=inference_max_seq_length,
inference_max_batch_size=inference_max_batch_size,
)

return inference_wrapped_model, processor
Expand All @@ -120,6 +123,7 @@ def _expose_decoder_from_language_model(model):
if hasattr(current, "language_model"):
language_model = current.language_model
current.decoder = language_model.decoder
current.vocab_size = language_model.vocab_size


def setup_inference_wrapper(
Expand All @@ -128,6 +132,7 @@ def setup_inference_wrapper(
params_dtype: torch.dtype = torch.bfloat16,
inference_batch_times_seqlen_threshold: int = 1000,
inference_max_seq_length: int = 8192,
inference_max_batch_size: int = 4,
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot Feb 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

inference_max_batch_size is not surfaced through setup_model_and_tokenizer.

setup_inference_wrapper gains the new inference_max_batch_size parameter (default 4), but setup_model_and_tokenizer — the public entry point — never accepts or forwards it (lines 104–110). Any caller who needs generate(max_batch_size=N) with N > 4 will hit a StaticInferenceContext KV-cache overflow at runtime, with no way to prevent it through the public API.

🔧 Suggested fix: surface the parameter in the public function
 def setup_model_and_tokenizer(
     megatron_model_path: str,
     tp: int = 1,
     pp: int = 1,
     params_dtype: torch.dtype = torch.bfloat16,
     inference_batch_times_seqlen_threshold: int = 1000,
     inference_max_seq_length: int = 8192,
+    inference_max_batch_size: int = 4,
 ):
     ...
     inference_wrapped_model = setup_inference_wrapper(
         model[0],
         processor.tokenizer,
         params_dtype=torch.bfloat16,
         inference_batch_times_seqlen_threshold=1000,
         inference_max_seq_length=inference_max_seq_length,
+        inference_max_batch_size=inference_max_batch_size,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/inference/vlm/base.py` at line 133, The public function
setup_model_and_tokenizer currently does not accept or forward the new
inference_max_batch_size parameter, causing callers to be unable to set
generate(max_batch_size=N) and risking StaticInferenceContext KV-cache
overflows; update setup_model_and_tokenizer to add an inference_max_batch_size:
int = 4 parameter and pass it through to setup_inference_wrapper (which already
accepts inference_max_batch_size) so downstream calls to generate() can use the
configured max_batch_size and avoid KV-cache overflow.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@meatybobby could you please check if this is relevant ?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MCore still has backward compatibility for this, but I just removed it to avoid confusion.

):
"""Set up inference wrapper for the model"""
config = model.config
Expand All @@ -148,7 +153,13 @@ def setup_inference_wrapper(
else:
raise ValueError(f"Unknown model config: {config}")

inference_wrapped_model = wrapper_cls(mcore_model)
inference_wrapped_model = wrapper_cls(
mcore_model,
inference_context=StaticInferenceContext(
max_batch_size=inference_max_batch_size,
max_sequence_length=inference_max_seq_length,
),
)

return inference_wrapped_model

Expand All @@ -160,7 +171,6 @@ def generate(
prompts: List[str],
images: List[Union[Image, List[Image]]],
processor=None,
max_batch_size: int = 4,
random_seed: Optional[int] = None,
sampling_params: Optional[SamplingParams] = None,
) -> dict:
Expand All @@ -172,7 +182,6 @@ def generate(
image_processor: image processor for the input image,
prompts (list[str]): The list of prompts to generate text for.
images (list): The list of images to generate text for.
max_batch_size (int, optional): The maximum batch size. Defaults to 4.
random_seed (Optional[int], optional): The random seed. Defaults to None.
sampling_params (Optional["SamplingParams"], optional): The sampling parameters defined in
Mcore's SamplingParams. Defaults to None.
Expand All @@ -195,9 +204,7 @@ def generate(
tokenizer=tokenizer,
image_processor=image_processor,
)
mcore_engine = VLMEngine(
text_generation_controller=text_generation_controller, max_batch_size=max_batch_size, random_seed=random_seed
)
mcore_engine = VLMEngine(text_generation_controller=text_generation_controller, random_seed=random_seed)

if sampling_params is None:
sampling_params = SamplingParams(num_tokens_to_generate=50)
Expand Down
4 changes: 2 additions & 2 deletions src/megatron/bridge/inference/vlm/qwenvl_inference_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class QwenVLInferenceWrapper(AbstractModelInferenceWrapper):
model (Qwen2VLModel): The Qwen2VL model
"""

def __init__(self, model):
super().__init__(model)
def __init__(self, model, inference_context=None):
super().__init__(model, inference_context=inference_context)

def prep_inference_input(
self,
Expand Down
14 changes: 0 additions & 14 deletions src/megatron/bridge/inference/vlm/vlm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,12 @@
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.scheduler import Scheduler
from megatron.core.inference.text_generation_controllers.text_generation_controller import TextGenerationController
from PIL.Image import Image


class VLMEngine(MCoreEngine):
"""VLM inference engine extending MCoreEngine with image support."""

def __init__(
self,
text_generation_controller: TextGenerationController,
max_batch_size: Optional[int] = None,
random_seed: Optional[int] = None,
):
self.controller = text_generation_controller
self.inference_wrapped_model = self.controller.inference_wrapped_model
self.config = self.inference_wrapped_model.config
self.random_seed = random_seed or 1234
self.scheduler = Scheduler(max_batch_size=max_batch_size)

# pylint: disable=C0115,C0116
def generate(
self,
Expand Down
1 change: 1 addition & 0 deletions tests/unit_tests/inference/vlm/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ class MockObject:
# Build the nested structure: model.module.language_model.decoder
mock_language_model = MockObject()
mock_language_model.decoder = mock_decoder
mock_language_model.vocab_size = 151936

mock_module = MockObject()
mock_module.language_model = mock_language_model
Expand Down
10 changes: 8 additions & 2 deletions tests/unit_tests/inference/vlm/test_vlm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,21 @@

from unittest.mock import MagicMock

from megatron.core.inference.contexts import StaticInferenceContext

from megatron.bridge.inference.vlm.vlm_engine import VLMEngine


class TestVLMEngine:
def test_generate(self):
mock_controller = MagicMock()
mock_controller.tokenize_prompt.return_value = ([1, 2, 3], "image_dict")
# Fix for TypeError: '>' not supported between instances of 'int' and 'MagicMock'
mock_controller.inference_wrapped_model.context.max_batch_size = 128
# MCoreEngine/StaticInferenceEngine expects inference_context to be a StaticInferenceContext
# (and uses inference_wrapper_config.inference_max_requests for scheduler batch size).
mock_controller.inference_wrapped_model.inference_context = StaticInferenceContext(
max_batch_size=128, max_sequence_length=8192
)
mock_controller.inference_wrapped_model.inference_wrapper_config = MagicMock(inference_max_requests=128)

engine = VLMEngine(mock_controller, max_batch_size=4)
engine.scheduler = MagicMock()
Expand Down