Skip to content

Commit

Permalink
Merge branch 'main' into jpg2p_jun18
Browse files Browse the repository at this point in the history
BuyuanCui authored Jul 19, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
2 parents b6b0510 + ab8988e commit 26868b0
Showing 13 changed files with 858 additions and 87 deletions.
5 changes: 3 additions & 2 deletions examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml
Original file line number Diff line number Diff line change
@@ -6,10 +6,11 @@ infer:
max_input_len: 4096
max_output_len: 256
max_multimodal_len: 3072
vision_max_batch_size: 1 # use 256 for lita/vita when running inference on a video dataset

model:
type: neva
type: neva #neva, video-neva, lita, vila, vita
precision: bfloat16
visual_model_path: /path/to/visual.nemo
llm_model_path: /path/to/llm.nemo
llm_model_type: llama
llm_model_type: llama
1 change: 1 addition & 0 deletions examples/multimodal/multimodal_llm/neva/neva_export.py
Original file line number Diff line number Diff line change
@@ -27,6 +27,7 @@ def main(cfg):
tensor_parallel_size=cfg.infer.tensor_parallelism,
max_input_len=cfg.infer.max_input_len,
max_output_len=cfg.infer.max_output_len,
vision_max_batch_size=cfg.infer.vision_max_batch_size,
max_batch_size=cfg.infer.max_batch_size,
max_multimodal_len=cfg.infer.max_multimodal_len,
dtype=cfg.model.precision,
1 change: 1 addition & 0 deletions nemo/collections/common/data/lhotse/cutset.py
Original file line number Diff line number Diff line change
@@ -383,6 +383,7 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet:
"lang_field": config.lang_field,
"shuffle_shards": config.shuffle,
"shard_seed": config.shard_seed,
"extra_fields": config.get("extra_fields", None),
}
# The option below is to allow a special case of NeMo manifest iteration as Lhotse CutSet
# without performing any I/O. NeMo manifests typically don't have sampling_rate information required by Lhotse,
160 changes: 154 additions & 6 deletions nemo/collections/common/data/lhotse/nemo_adapters.py
Original file line number Diff line number Diff line change
@@ -16,10 +16,12 @@
import random
import re
import tarfile
from collections.abc import Mapping, Sequence
from io import BytesIO
from pathlib import Path
from typing import Generator, Iterable, List, Literal

import lhotse.serialization
import soundfile
from cytoolz import groupby
from lhotse import AudioSource, Recording, SupervisionSegment
@@ -28,6 +30,7 @@
from lhotse.lazy import LazyIteratorChain, LazyJsonlIterator
from lhotse.serialization import open_best
from lhotse.utils import compute_num_samples

from nemo.collections.common.parts.preprocessing.manifest import get_full_path


@@ -56,16 +59,33 @@ class LazyNeMoIterator:
Example::
>>> cuts = lhotse.CutSet(LazyNeMoIterator("nemo_manifests/train.json"))
We allow attaching custom metadata to cuts from files other than the manifest via ``extra_fields`` argument.
In the example below, we'll iterate file "questions.txt" together with the manifest and attach each line
under ``cut.question`` using the field type ``text_iter``::
>>> cuts = lhotse.CutSet(LazyNeMoIterator(
... "nemo_manifests/train.json",
... extra_fields=[{"type": "text_iter", "name": "question", "path": "questions.txt"}],
... ))
We also support random sampling of lines with field type ``text_sample``::
>>> cuts = lhotse.CutSet(LazyNeMoIterator(
... "nemo_manifests/train.json",
... extra_fields=[{"type": "text_sample", "name": "question", "path": "questions.txt"}],
... ))
"""

def __init__(
self,
path: str | Path,
path: str | Path | list[str],
text_field: str = "text",
lang_field: str = "lang",
metadata_only: bool = False,
shuffle_shards: bool = False,
shard_seed: int | Literal["randomized", "trng"] = "trng",
extra_fields: list[dict[str, str]] | None = None,
) -> None:
self.path = path
self.shuffle_shards = shuffle_shards
@@ -80,8 +100,13 @@ def __init__(
self.text_field = text_field
self.lang_field = lang_field
self.metadata_only = metadata_only
self.extra_fields = extra_fields
validate_extra_fields(self.extra_fields)

def __iter__(self) -> Generator[Cut, None, None]:
seed = resolve_seed(self.shard_seed)
# Propagate the random seed
extra_fields = [ExtraField.from_dict({"seed": seed, **field_cfg}) for field_cfg in self.extra_fields or ()]
for data in self.source:
audio_path = get_full_path(str(data.pop("audio_filepath")), str(self.path))
duration = data.pop("duration")
@@ -104,6 +129,8 @@ def __iter__(self) -> Generator[Cut, None, None]:
)
)
cut.custom = data
for extra_field in extra_fields:
extra_field.attach_to(cut)
yield cut

def __len__(self) -> int:
@@ -180,20 +207,39 @@ class LazyNeMoTarredIterator:
Example of CutSet with inter-shard shuffling enabled::
>>> cuts = lhotse.CutSet(LazyNeMoTarredIterator(
... manifest_path="nemo_manifests/train.json",
... manifest_path=["nemo_manifests/sharded_manifests/manifest_0.json", ...],
... tar_paths=["nemo_manifests/audio_0.tar", ...],
... shuffle_shards=True,
... ))
We allow attaching custom metadata to cuts from files other than the manifest via ``extra_fields`` argument.
In the example below, we'll iterate file "questions.txt" together with the manifest and attach each line
under ``cut.question`` using the field type ``text_iter``::
>>> cuts = lhotse.CutSet(LazyNeMoTarredIterator(
... manifest_path=["nemo_manifests/sharded_manifests/manifest_0.json", ...],
... tar_paths=["nemo_manifests/audio_0.tar", ...],
... extra_fields=[{"type": "text_iter", "name": "question", "path": "questions.txt"}],
... ))
We also support random sampling of lines with field type ``text_sample``::
>>> cuts = lhotse.CutSet(LazyNeMoTarredIterator(
... manifest_path=["nemo_manifests/sharded_manifests/manifest_0.json", ...],
... tar_paths=["nemo_manifests/audio_0.tar", ...],
... extra_fields=[{"type": "text_sample", "name": "question", "path": "questions.txt"}],
... ))
"""

def __init__(
self,
manifest_path: str | Path,
manifest_path: str | Path | list[str],
tar_paths: str | list,
shuffle_shards: bool = False,
shard_seed: int | Literal["trng", "randomized"] = "trng",
text_field: str = "text",
lang_field: str = "lang",
extra_fields: list[dict[str, str]] | None = None,
) -> None:
self.shard_id_to_manifest: dict[int, Iterable[dict]]
self.paths = expand_sharded_filepaths(manifest_path)
@@ -235,6 +281,7 @@ def __init__(
self.shard_seed = shard_seed
self.text_field = text_field
self.lang_field = lang_field
self.extra_fields = extra_fields
self._validate()

def to_shards(self) -> List["LazyNeMoTarredIterator"]:
@@ -266,6 +313,7 @@ def _validate(self) -> None:
f"* JSON manifest(s) indicate(s) IDs: {sorted(shard_ids_manifest)}\n"
f"* Tar path(s) indicate(s) IDs: {sorted(shard_ids_tars)}\n"
)
validate_extra_fields(self.extra_fields)

@property
def shard_ids(self) -> List[int]:
@@ -274,10 +322,13 @@ def shard_ids(self) -> List[int]:
def __iter__(self) -> Generator[Cut, None, None]:
shard_ids = self.shard_ids

seed = resolve_seed(self.shard_seed)
if self.shuffle_shards:
seed = resolve_seed(self.shard_seed)
random.Random(seed).shuffle(shard_ids)

# Propagate the random seed
extra_fields = [ExtraField.from_dict({"seed": seed, **field_cfg}) for field_cfg in self.extra_fields or ()]

for sid in shard_ids:
manifest_path = self.paths[sid] if len(self.paths) > 1 else self.paths[0]
shard_manifest = {data["audio_filepath"]: data for data in self.shard_id_to_manifest[sid]}
@@ -314,6 +365,8 @@ def __iter__(self) -> Generator[Cut, None, None]:
)
)
cut.custom = _to_custom_attr_dict(data)
for extra_field in extra_fields:
extra_field.attach_to(cut)
yield cut

def __len__(self) -> int:
@@ -323,11 +376,106 @@ def __add__(self, other):
return LazyIteratorChain(self, other)


def expand_sharded_filepaths(path: str | Path) -> list[str]:
class ExtraField:
    """Base class and registry for extra fields that get attached to cuts.

    A subclass declares a string ``TYPE``; defining the subclass registers it in
    ``SUPPORTED_TYPES`` under that key, which lets ``from_dict`` build the right
    subclass from a config dict containing a ``"type"`` entry.
    """

    # Registry key of a concrete subclass; ``None`` means "not registrable".
    TYPE = None
    # Maps TYPE -> subclass. Shared by the whole hierarchy on purpose.
    SUPPORTED_TYPES = {}

    def attach_to(self, cut):
        """Attach this field's value to ``cut``; must be implemented by subclasses."""
        raise NotImplementedError()

    def __init_subclass__(cls, **kwargs):
        # The registry is keyed by TYPE, so the membership test must use TYPE as well
        # (testing ``cls.__name__`` never matched any key). First registration wins, so a
        # subclass that merely inherits TYPE does not silently replace its parent, and a
        # subclass without a TYPE does not register the key ``None``.
        if cls.TYPE is not None and cls.TYPE not in ExtraField.SUPPORTED_TYPES:
            ExtraField.SUPPORTED_TYPES[cls.TYPE] = cls
        super().__init_subclass__(**kwargs)

    @staticmethod
    def from_dict(data: dict) -> "ExtraField":
        """Build the registered subclass selected by ``data["type"]``.

        Every other key of ``data`` is forwarded to the subclass constructor as a
        keyword argument.

        Raises:
            AssertionError: if ``data["type"]`` is not a registered type.
        """
        assert data["type"] in ExtraField.SUPPORTED_TYPES, f"Unknown extra field type: {data['type']}"
        return ExtraField.SUPPORTED_TYPES[data["type"]](**{k: v for k, v in data.items() if k != 'type'})

    @classmethod
    def is_supported(cls, field_type: str) -> bool:
        """Return True when ``field_type`` is a registered type key."""
        return field_type in cls.SUPPORTED_TYPES

    @classmethod
    def supported_types(cls) -> list[str]:
        """Return the registered type keys as a list."""
        return list(cls.SUPPORTED_TYPES)


class TextIteratorExtraField(ExtraField):
    """Extra field that attaches consecutive stripped lines of a text file to cuts.

    Each ``attach_to`` call consumes the next line of ``path`` and stores it on the
    cut under the attribute ``name``.
    """

    TYPE = "text_iter"

    def __init__(self, name: str, path: str, seed=None):
        # ``seed`` is accepted but unused here: sequential iteration needs no randomness.
        self.name = name
        self.path = path
        self.iterator = None

    def _iter_lines(self):
        """Yield stripped lines of ``self.path``, closing the file when iteration ends."""
        # Owning the handle inside the generator means it is closed once the lines are
        # exhausted (or the generator is closed/collected) instead of being leaked.
        with open_best(self.path) as f:
            yield from map(str.strip, f)

    def _maybe_init(self):
        # Lazy: the file is only opened once the first cut is processed.
        if self.iterator is None:
            self.iterator = self._iter_lines()

    def attach_to(self, cut):
        """Set ``cut.<name>`` to the next line of the file and return ``cut``.

        Raises:
            RuntimeError: if the file has fewer lines than the number of cuts.
        """
        self._maybe_init()
        try:
            attached_value = next(self.iterator)
        except StopIteration:
            raise RuntimeError(f"Not enough lines in file {self.path} to attach to cuts under field {self.name}.")
        setattr(cut, self.name, attached_value)
        return cut


class TextSampleExtraField(ExtraField):
    """Extra field that attaches a randomly chosen stripped line of a text file to cuts.

    All lines of ``path`` are loaded on first use; every ``attach_to`` call then picks
    one of them with a ``random.Random`` seeded from ``seed``.
    """

    TYPE = "text_sample"

    def __init__(self, name: str, path: str, seed: int | str):
        self.name = name
        self.path = path
        self.seed = seed
        self.population = None
        self.rng = None

    def _maybe_init(self):
        # Lazy: read the file and build the RNG only when the first cut is processed.
        if self.population is None:
            # Read everything inside a ``with`` block so the handle is closed right away
            # instead of being left open for the lifetime of the process.
            with open_best(self.path) as f:
                self.population = list(map(str.strip, f))
            self.rng = random.Random(resolve_seed(self.seed))

    def attach_to(self, cut):
        """Set ``cut.<name>`` to a randomly sampled line and return ``cut``."""
        self._maybe_init()
        attached_value = self.rng.choice(self.population)
        setattr(cut, self.name, attached_value)
        return cut


def validate_extra_fields(extra_fields):
    """Check that ``extra_fields`` is ``None`` or a sequence of well-formed field configs.

    Each item must be a mapping whose ``"type"`` is one of the types registered with
    ``ExtraField`` and which contains a ``"name"`` key.

    Raises:
        AssertionError: on the first violation, with a message describing it.
    """
    if extra_fields is None:
        return
    assert isinstance(
        extra_fields, Sequence
    ), f"The argument provided to 'extra_fields' must be a list of dicts. We received {extra_fields=}"
    for field in extra_fields:
        assert isinstance(
            field, Mapping
        ), f"Each item in 'extra_fields' must be a dict. We received {field=} in {extra_fields=}"
        field_type = field.get("type")
        assert ExtraField.is_supported(field_type), (
            f"Each item in 'extra_fields' must contain a 'type' field with one of "
            f"the supported values ({ExtraField.supported_types()}). "
            f"We got {field_type=} in {extra_fields=}"
        )
        # The trailing space keeps the two sentences from running together in the message.
        assert "name" in field, (
            f"Each item in 'extra_fields' must contain a 'name' field so that the field is available under cut.<name>. "
            f"We found {field=} in {extra_fields=}"
        )


def expand_sharded_filepaths(paths: str | Path | list[str]) -> list[str]:
    """Expand sharded path specification(s) by delegating to the NeMo ASR helper.

    A ``Path`` is converted to ``str`` first; strings and lists of strings are passed
    through unchanged. The helper is called with ``shard_strategy="replicate"``,
    ``world_size=1`` and ``global_rank=0``.
    """
    # local import to avoid circular imports
    from nemo.collections.asr.data.audio_to_text import expand_sharded_filepaths as _expand_sharded_filepaths

    # The stale early ``return`` that referenced the undefined name ``path`` is removed so
    # that this conversion and the list-aware call below are actually reached.
    if isinstance(paths, Path):
        paths = str(paths)

    return _expand_sharded_filepaths(paths, shard_strategy="replicate", world_size=1, global_rank=0)


def _to_custom_attr_dict(d: dict, _excluded_fields: set[str] = {"duration", "audio_filepath"}) -> dict:
Original file line number Diff line number Diff line change
@@ -470,11 +470,10 @@ def __init__(
def create_vision_encoder_and_processor(self, mm_cfg):
# Initialize vision encoder and freeze it
if mm_cfg.vision_encoder.get("from_hf", False):
if (
"clip" in mm_cfg.vision_encoder.from_pretrained
or "vit" in mm_cfg.vision_encoder.from_pretrained
or "clip" in mm_cfg.vision_encoder.get("model_type", "")
):
from transformers import AutoConfig

config = AutoConfig.from_pretrained(mm_cfg.vision_encoder.from_pretrained)
if config.architectures[0] == "CLIPVisionModel":
vision_encoder = CLIPVisionModel.from_pretrained(
mm_cfg.vision_encoder.from_pretrained,
torch_dtype=torch.bfloat16,
@@ -484,9 +483,7 @@ def create_vision_encoder_and_processor(self, mm_cfg):
for param in vision_encoder.parameters():
param.requires_grad = False
vision_encoder = vision_encoder.eval()
elif "siglip" in mm_cfg.vision_encoder.from_pretrained or "siglip" in mm_cfg.vision_encoder.get(
"model_type", ""
):
elif config.architectures[0] == "SiglipVisionModel":
vision_encoder = SiglipVisionModel.from_pretrained(
mm_cfg.vision_encoder.from_pretrained,
torch_dtype=torch.bfloat16,
13 changes: 5 additions & 8 deletions nemo/collections/multimodal/parts/utils.py
Original file line number Diff line number Diff line change
@@ -534,17 +534,14 @@ def expand2square(pil_img, background_color):

def create_image_processor(mm_cfg):
if mm_cfg.vision_encoder.get("from_hf", False):
if (
"clip" in mm_cfg.vision_encoder.from_pretrained
or "vit" in mm_cfg.vision_encoder.from_pretrained
or "clip" in mm_cfg.vision_encoder.get("model_type", "")
):
from transformers import AutoConfig

config = AutoConfig.from_pretrained(mm_cfg.vision_encoder.from_pretrained)
if config.architectures[0] == "CLIPVisionModel":
image_processor = CLIPImageProcessor.from_pretrained(
mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
)
elif "siglip" in mm_cfg.vision_encoder.from_pretrained or "siglip" in mm_cfg.vision_encoder.get(
"model_type", ""
):
elif config.architectures[0] == "SiglipVisionModel":
image_processor = SiglipImageProcessor.from_pretrained(
mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
)
21 changes: 20 additions & 1 deletion nemo/deploy/multimodal/query_multimodal.py
Original file line number Diff line number Diff line change
@@ -56,12 +56,31 @@ def setup_media(self, input_media):
vr = VideoReader(input_media)
frames = [f.asnumpy() for f in vr]
return np.array(frames)
elif self.model_type == "neva":
elif self.model_type == "lita" or self.model_type == "vita":
vr = VideoReader(input_media)
frames = [f.asnumpy() for f in vr]
subsample_len = self.frame_len(frames)
sub_frames = self.get_subsampled_frames(frames, subsample_len)
return np.array(sub_frames)
elif self.model_type == "neva" or self.model_type == "vila":
media = Image.open(input_media).convert('RGB')
return np.expand_dims(np.array(media), axis=0)
else:
raise RuntimeError(f"Invalid model type {self.model_type}")

def frame_len(self, frames, max_frames=256):
    """Return how many frames to keep so that the count does not exceed ``max_frames``.

    Args:
        frames: sized collection of video frames.
        max_frames: upper bound on the number of frames to keep (default 256, the
            previously hard-coded limit).

    Returns:
        ``len(frames)`` when it is within the limit; otherwise ``len(frames)`` divided
        by the smallest integer stride that brings it under the limit, rounded.

    Raises:
        ValueError: if ``max_frames`` is smaller than 1.
    """
    if max_frames < 1:
        raise ValueError(f"max_frames must be a positive integer, got {max_frames}")
    num_frames = len(frames)
    if num_frames <= max_frames:
        return num_frames
    # Smallest integer stride for which num_frames / stride fits within max_frames.
    subsample = int(np.ceil(float(num_frames) / max_frames))
    return int(np.round(float(num_frames) / subsample))

def get_subsampled_frames(self, frames, subsample_len):
    """Pick ``subsample_len`` frames at evenly spaced positions, endpoints included.

    Positions are spread linearly between the first and last index of ``frames`` and
    rounded to integers; the frames at those indices are returned as a list.
    """
    last_index = len(frames) - 1
    positions = np.linspace(0, last_index, subsample_len)
    picked = np.round(positions).astype(int)
    return [frames[k] for k in picked]

def query(
self,
input_text,
Loading

0 comments on commit 26868b0

Please sign in to comment.