2 changes: 2 additions & 0 deletions nemo/collections/llm/__init__.py
@@ -175,6 +175,7 @@
Qwen3Config600M,
Qwen3Model,
Qwen25Config1P5B,
Qwen25Config3B,
Qwen25Config7B,
Qwen25Config14B,
Qwen25Config32B,
@@ -333,6 +334,7 @@
"Qwen2Config",
"Qwen2Config500M",
"Qwen2Config1P5B",
"Qwen25Config3B",
"Qwen2Config72B",
"Qwen25Config500M",
"Qwen25Config1P5B",
2 changes: 2 additions & 0 deletions nemo/collections/llm/gpt/model/__init__.py
@@ -135,6 +135,7 @@
Qwen2Config500M,
Qwen2Model,
Qwen25Config1P5B,
Qwen25Config3B,
Qwen25Config7B,
Qwen25Config14B,
Qwen25Config32B,
@@ -270,6 +271,7 @@
"Qwen2Config",
"Qwen2Config500M",
"Qwen2Config1P5B",
"Qwen25Config3B",
"Qwen2Config7B",
"Qwen2Config72B",
"Qwen25Config72B",
16 changes: 16 additions & 0 deletions nemo/collections/llm/gpt/model/qwen2.py
@@ -91,6 +91,21 @@ class Qwen2Config1P5B(Qwen2Config):
ffn_hidden_size: int = 8960


@dataclass
class Qwen25Config3B(Qwen2Config):
"""
Config for Qwen 2.5 3B: https://huggingface.co/Qwen/Qwen2.5-3B
"""

num_layers: int = 36
hidden_size: int = 2048
num_attention_heads: int = 16
num_query_groups: int = 2
ffn_hidden_size: int = 11008
vocab_size: int = 151936
share_embeddings_and_output_weights: bool = True


@dataclass
class Qwen25Config1P5B(Qwen2Config1P5B):
"""
@@ -399,6 +414,7 @@ def config(self) -> "HFQwen2Config":
"Qwen2Config",
"Qwen2Config500M",
"Qwen2Config1P5B",
"Qwen25Config3B",
"Qwen2Config7B",
"Qwen2Config72B",
"Qwen25Config500M",
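For orientation, a minimal sketch of what the new 3B config encodes (assuming an environment with this change installed; the grouped-query-attention arithmetic is derived from the fields added above):

from nemo.collections.llm import Qwen25Config3B  # re-exported by this change

# Qwen 2.5 3B layout: 36 layers, hidden size 2048, FFN 11008, tied input/output
# embeddings over a 151936-token vocabulary, 16 attention heads sharing 2 KV groups.
heads_per_kv_group = Qwen25Config3B.num_attention_heads // Qwen25Config3B.num_query_groups
print(heads_per_kv_group)  # 16 // 2 = 8 query heads per KV head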
3 changes: 3 additions & 0 deletions nemo/collections/llm/recipes/qwen2.py
@@ -27,6 +27,7 @@
Qwen2Config500M,
Qwen2Model,
Qwen25Config1P5B,
Qwen25Config3B,
Qwen25Config7B,
Qwen25Config14B,
Qwen25Config32B,
@@ -56,6 +57,8 @@ def qwen2_model(version: str) -> run.Config[pl.LightningModule]:
config = run.Config(Qwen2Config1P5B)
elif version == "qwen25_1p5b":
config = run.Config(Qwen25Config1P5B)
elif version == "qwen25_3b":
config = run.Config(Qwen25Config3B)
elif version == "qwen2_7b":
config = run.Config(Qwen2Config7B)
elif version == "qwen25_7b":
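A short usage sketch of the updated selector (hedged: nemo_run must be importable as `run` for the recipe module to load):

from nemo.collections.llm.recipes.qwen2 import qwen2_model

# "qwen25_3b" now resolves to run.Config(Qwen25Config3B); existing version strings are unchanged.
model_cfg = qwen2_model(version="qwen25_3b")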
21 changes: 19 additions & 2 deletions nemo/collections/vlm/__init__.py
@@ -68,8 +68,20 @@
# PEFT
from nemo.collections.vlm.peft import LoRA
from nemo.collections.vlm.qwen2vl.data import Qwen2VLDataConfig, Qwen2VLMockDataModule, Qwen2VLPreloadedDataModule
from nemo.collections.vlm.qwen2vl.model.base import Qwen2VLConfig, Qwen2VLModel, Qwen2VLVisionConfig
from nemo.collections.vlm.qwen2vl.model.qwen2vl import Qwen2VLConfig2B, Qwen2VLConfig7B
from nemo.collections.vlm.qwen2vl.model.base import (
Qwen2VLConfig,
Qwen2VLModel,
Qwen2VLVisionConfig,
Qwen25VLVisionConfig,
)
from nemo.collections.vlm.qwen2vl.model.qwen2vl import (
Qwen2VLConfig2B,
Qwen2VLConfig7B,
Qwen25VLConfig3B,
Qwen25VLConfig7B,
Qwen25VLConfig32B,
Qwen25VLConfig72B,
)

# RECIPES
from nemo.collections.vlm.recipes import *
@@ -123,6 +135,11 @@
"Qwen2VLConfig7B",
"Qwen2VLVisionConfig",
"Qwen2VLModel",
"Qwen25VLConfig3B",
"Qwen25VLConfig7B",
"Qwen25VLConfig32B",
"Qwen25VLConfig72B",
"Qwen25VLVisionConfig",
"Qwen2VLDataConfig",
"Gemma3VLConfig",
"Gemma3VLConfig4B",
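The new vision-language symbols are re-exported at the collection level, so an import check looks like this (a sketch, assuming this branch is installed):

# Every name below is added to nemo.collections.vlm.__all__ by this change.
from nemo.collections.vlm import (
    Qwen25VLConfig3B,
    Qwen25VLConfig7B,
    Qwen25VLConfig32B,
    Qwen25VLConfig72B,
    Qwen25VLVisionConfig,
)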
67 changes: 55 additions & 12 deletions nemo/collections/vlm/qwen2vl/data/preloaded.py
@@ -45,7 +45,7 @@
from nemo.lightning.pytorch.plugins import MegatronDataSampler


def process_vision(processor, images, videos):
def process_vision(processor, images, videos, fps=None, model_version="qwen2-vl"):
# pylint: disable=C0115,C0116
assert isinstance(processor, Qwen2VLImageProcessor), "processor needs to be Qwen2VLImageProcessor"
if images is not None:
@@ -58,6 +58,22 @@ def process_vision(processor, images, videos):
if videos is not None:
videos_inputs = processor(images=None, videos=videos, return_tensors='pt')
video_grid_thw = videos_inputs["video_grid_thw"]
if model_version == "qwen25-vl":
if isinstance(fps, (int, float)):
second_per_grid_ts = [processor.temporal_patch_size / fps] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
second_per_grid_ts = [processor.temporal_patch_size / tmp for tmp in fps]
else:
raise ValueError(
f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length "
f"of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
)
second_per_grid_ts = torch.tensor(
second_per_grid_ts,
dtype=videos_inputs['pixel_values_videos'].dtype,
device=videos_inputs['pixel_values_videos'].device,
)
videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
else:
videos_inputs = {}
video_grid_thw = None
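The qwen25-vl branch above turns per-video sampling rates into the number of seconds covered by one temporal grid step. A standalone sketch of that arithmetic (a hypothetical helper; temporal_patch_size=2 matches the default of the Hugging Face Qwen2-VL image processor):

import torch

def seconds_per_temporal_grid(temporal_patch_size, fps, num_videos):
    # Seconds of wall-clock video represented by one temporal grid step, per video.
    if isinstance(fps, (int, float)):
        return torch.tensor([temporal_patch_size / fps] * num_videos)
    if hasattr(fps, "__len__") and len(fps) == num_videos:
        return torch.tensor([temporal_patch_size / f for f in fps])
    raise ValueError("fps must be a single number or provide one entry per video")

# Two videos sampled at 2.0 and 4.0 frames per second with a temporal patch size of 2:
print(seconds_per_temporal_grid(2, [2.0, 4.0], num_videos=2))  # tensor([1.0000, 0.5000])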
@@ -322,6 +338,7 @@ def __init__(
data_config,
tokenizer,
image_processor,
model_version,
sequence_length=None,
):
super().__init__()
@@ -348,6 +365,7 @@ def __init__(

self.image_folder = getattr(data_config, "image_folder", None)
self.video_folder = getattr(data_config, "video_folder", None) or self.image_folder
self.model_version = model_version

def __len__(self):
return len(self.list_data_dict)
@@ -358,7 +376,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
conv = copy.deepcopy(supported_conv_templates[self.conv_template])
chatml = self._apply_prompt_templates(conv, source, use_plain=self.conv_template == "plain")

vision_tensors = self._process_vision(source, self.image_folder, self.video_folder)
vision_tensors = self._process_vision(source, self.image_folder, self.video_folder, self.model_version)
tokens, labels = self._tokenize_and_label(conv, chatml, vision_tensors)

data_dict = dict(
@@ -416,21 +434,26 @@ def _fetch_vision_content(self, images, videos):
for image in images:
image_inputs.append(fetch_image({"image": image}))
video_inputs = []
video_sample_fps_list = []
for video in videos:
video_inputs.append(fetch_video({"video": video}))
video_input, video_sample_fps = fetch_video({"video": video}, return_video_sample_fps=True)
video_sample_fps_list.append(video_sample_fps)
video_inputs.append(video_input)
if len(image_inputs) == 0:
image_inputs = None
if len(video_inputs) == 0:
video_inputs = None
return image_inputs, video_inputs
return image_inputs, video_inputs, video_sample_fps_list

def _process_vision(self, source, image_folder, video_folder):
def _process_vision(self, source, image_folder, video_folder, model_version):
# normalize image and video paths
images, videos = self._normalize_vision_paths(source, image_folder, video_folder)
# leave the I/O and smart_resize to qwen_vl_utils, which is maintained on github by Qwen Team.
image_inputs, video_inputs = self._fetch_vision_content(images, videos)
image_inputs, video_inputs, video_sample_fps_list = self._fetch_vision_content(images, videos)
# call Huggingface processor to get patches and size info, which is maintained by Qwen Team as well.
vision_tensors = process_vision(self.image_processor, image_inputs, video_inputs)
vision_tensors = process_vision(
self.image_processor, image_inputs, video_inputs, video_sample_fps_list, model_version
)
return vision_tensors

def _apply_prompt_templates(self, conv, source, use_plain=False):
@@ -551,16 +574,17 @@ def __init__(
data_config,
tokenizer,
image_processor,
model_version,
sequence_length=None,
):

if data_path.endswith(".json"):
super().__init__(data_path, data_config, tokenizer, image_processor, sequence_length)
super().__init__(data_path, data_config, tokenizer, image_processor, model_version, sequence_length)
elif data_path.endswith(".jsonl"):
# FIXME: implement support for more data formats
super().__init__(None, data_config, tokenizer, image_processor, sequence_length)
super().__init__(None, data_config, tokenizer, image_processor, model_version, sequence_length)
logging.warning("Loading image inputs from Dataset...")
if data_config.media_type == 'image':
if data_config.image_folder is not None:
image_folder = data_config.image_folder
for line in open(data_path, "r"):
record = json.loads(line)
@@ -638,8 +662,15 @@ def collate_fn(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
batch['pixel_values_videos'] = None
if 'video_grid_thw' in instances[0]:
batch['video_grid_thw'] = torch.cat([instance['video_grid_thw'] for instance in instances], dim=0)
if self.model_version == "qwen25-vl":
batch['second_per_grid_ts'] = torch.cat(
[instance['second_per_grid_ts'] for instance in instances], dim=0
)
else:
batch['second_per_grid_ts'] = None
else:
batch['video_grid_thw'] = None
batch['second_per_grid_ts'] = None

tokenizer = self.tokenizer

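Collating the new field is a plain concatenation across the batch; a tiny self-contained sketch of the qwen25-vl branch above:

import torch

# Each instance carries one second_per_grid_ts entry per video it contains.
instances = [
    {"second_per_grid_ts": torch.tensor([1.0])},       # one video
    {"second_per_grid_ts": torch.tensor([0.5, 0.5])},  # two videos
]
batch_ts = torch.cat([inst["second_per_grid_ts"] for inst in instances], dim=0)
print(batch_ts)  # tensor([1.0000, 0.5000, 0.5000])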
Expand Down Expand Up @@ -681,6 +712,7 @@ class Qwen2VLPreloadedDataModule(pl.LightningDataModule):

def __init__(
self,
model_version,
paths: str | List[str],
weights: Optional[List[float]] = None,
data_config: Optional[Qwen2VLDataConfig] = Qwen2VLDataConfig,
@@ -708,6 +740,7 @@ def __init__(
# weights must be None if there is only one dataset
weights = None

self.model_version = model_version
self.paths = paths
self.weights = weights
self.data_config = data_config
@@ -744,10 +777,20 @@ def setup(self, stage: str = "") -> None:
# TODO:
# rng = torch.Generator().manual_seed(self.seed)
self._train_ds = Qwen2VLDataset(
self.paths[0], self.data_config, self.tokenizer, self.image_processor, sequence_length=self.seq_length
self.paths[0],
self.data_config,
self.tokenizer,
self.image_processor,
self.model_version,
sequence_length=self.seq_length,
)
self._validation_ds = Qwen2VLDataset(
self.paths[0], self.data_config, self.tokenizer, self.image_processor, sequence_length=self.seq_length
self.paths[0],
self.data_config,
self.tokenizer,
self.image_processor,
self.model_version,
sequence_length=self.seq_length,
)

def train_dataloader(self) -> TRAIN_DATALOADERS:
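Because model_version is now the first positional parameter of Qwen2VLPreloadedDataModule, existing call sites need updating. A hedged construction sketch; the annotation path is a placeholder and the remaining constructor arguments (tokenizer, image processor, sequence length, and so on) are omitted and assumed to keep their previous meaning:

from nemo.collections.vlm import Qwen2VLDataConfig, Qwen2VLPreloadedDataModule

data_module = Qwen2VLPreloadedDataModule(
    "qwen25-vl",                        # new leading argument; "qwen2-vl" keeps the old behaviour
    paths="/data/qwen25vl_train.json",  # hypothetical preloaded annotation file
    data_config=Qwen2VLDataConfig,
)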
23 changes: 21 additions & 2 deletions nemo/collections/vlm/qwen2vl/model/__init__.py
@@ -12,13 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.vlm.qwen2vl.model.base import Qwen2VLConfig, Qwen2VLModel, Qwen2VLVisionConfig
from nemo.collections.vlm.qwen2vl.model.qwen2vl import Qwen2VLConfig2B, Qwen2VLConfig7B
from nemo.collections.vlm.qwen2vl.model.base import (
Qwen2VLConfig,
Qwen2VLModel,
Qwen2VLVisionConfig,
Qwen25VLVisionConfig,
)
from nemo.collections.vlm.qwen2vl.model.qwen2vl import (
Qwen2VLConfig2B,
Qwen2VLConfig7B,
Qwen2VLConfig72B,
Qwen25VLConfig3B,
Qwen25VLConfig7B,
Qwen25VLConfig32B,
Qwen25VLConfig72B,
)

__all__ = [
"Qwen2VLVisionConfig",
"Qwen2VLConfig",
"Qwen2VLConfig2B",
"Qwen2VLConfig7B",
"Qwen2VLConfig72B",
"Qwen2VLModel",
"Qwen25VLVisionConfig",
"Qwen25VLConfig3B",
"Qwen25VLConfig7B",
"Qwen25VLConfig32B",
"Qwen25VLConfig72B",
]
39 changes: 34 additions & 5 deletions nemo/collections/vlm/qwen2vl/model/api.py
@@ -14,22 +14,51 @@

import lightning.pytorch as pl

from nemo.collections.vlm.qwen2vl.model import Qwen2VLConfig2B, Qwen2VLConfig7B, Qwen2VLConfig72B, Qwen2VLModel
from nemo.collections.vlm.qwen2vl.model import (
Qwen2VLConfig2B,
Qwen2VLConfig7B,
Qwen2VLConfig72B,
Qwen2VLModel,
Qwen25VLConfig3B,
Qwen25VLConfig7B,
Qwen25VLConfig32B,
Qwen25VLConfig72B,
)


def qwen2vl_2b() -> pl.LightningModule:
# pylint: disable=C0115,C0116
return Qwen2VLModel(Qwen2VLConfig2B())
return Qwen2VLModel(Qwen2VLConfig2B(), model_version="qwen2-vl")


def qwen2vl_7b() -> pl.LightningModule:
# pylint: disable=C0115,C0116
return Qwen2VLModel(Qwen2VLConfig7B())
return Qwen2VLModel(Qwen2VLConfig7B(), model_version="qwen2-vl")


def qwen2vl_72b() -> pl.LightningModule:
# pylint: disable=C0115,C0116
return Qwen2VLModel(Qwen2VLConfig72B())
return Qwen2VLModel(Qwen2VLConfig72B(), model_version="qwen2-vl")


__all__ = ["qwen2vl_2b", "qwen2vl_7b", "qwen2vl_72b"]
def qwen25vl_3b() -> pl.LightningModule:
# pylint: disable=C0115,C0116
return Qwen2VLModel(Qwen25VLConfig3B(), model_version="qwen25-vl")


def qwen25vl_7b() -> pl.LightningModule:
# pylint: disable=C0115,C0116
return Qwen2VLModel(Qwen25VLConfig7B(), model_version="qwen25-vl")


def qwen25vl_32b() -> pl.LightningModule:
# pylint: disable=C0115,C0116
return Qwen2VLModel(Qwen25VLConfig32B(), model_version="qwen25-vl")


def qwen25vl_72b() -> pl.LightningModule:
# pylint: disable=C0115,C0116
return Qwen2VLModel(Qwen25VLConfig72B(), model_version="qwen25-vl")


__all__ = ["qwen2vl_2b", "qwen2vl_7b", "qwen2vl_72b", "qwen25vl_3b", "qwen25vl_7b", "qwen25vl_32b", "qwen25vl_72b"]
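The new factory helpers mirror the existing Qwen2-VL ones; a minimal usage sketch:

from nemo.collections.vlm.qwen2vl.model.api import qwen25vl_3b

# Builds a Qwen2VLModel from Qwen25VLConfig3B with model_version="qwen25-vl".
model = qwen25vl_3b()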