21 changes: 1 addition & 20 deletions cosmos_predict2/_src/predict2/inference/get_umt5_emb.py
@@ -19,7 +19,6 @@
from typing import List, Optional, Union

import ftfy
-import pytest
import regex as re
import torch
import torch.distributed.checkpoint as dcp
@@ -31,7 +30,7 @@
from transformers import AutoTokenizer

from cosmos_predict2._src.imaginaire.checkpointer.s3_filesystem import S3StorageReader
-from cosmos_predict2._src.imaginaire.utils import distributed, log, misc
+from cosmos_predict2._src.imaginaire.utils import distributed, log
from cosmos_predict2._src.imaginaire.utils.easy_io import easy_io

"""
@@ -614,21 +613,3 @@ def get_negative_emb():
    emb = get_umt5_embedding(neg_prompt).to(dtype=torch.bfloat16).cpu()
    print(emb.shape)
    easy_io.dump(emb[0], "s3://bucket/cosmos_diffusion_v2/pretrain_weights/umT5_wan_negative_emb.pt")
-
-
-@pytest.mark.L2
-def test_encoder():
-    with misc.timer("load model"):
-        model = UMT5EncoderModel(
-            checkpoint_path="s3://bucket/cosmos_diffusion_v2/pretrain_weights/models_t5_umt5-xxl-enc-bf16.pth"
-        )
-    emb = model(texts=["hello world", "hello", "world"])
-    assert len(emb) == 3
-    assert emb[0].shape == (512, 4096)
-    assert emb[1].shape == (512, 4096)
-    assert emb[2].shape == (512, 4096)
-
-
-if __name__ == "__main__":
-    test_encoder()
-    # get_negative_emb()
2 changes: 1 addition & 1 deletion cosmos_predict2/_src/predict2/networks/minimal_v4_dit.py
@@ -38,7 +38,7 @@
CheckpointPolicy = None

from torchvision import transforms
-from transformer_engine.pytorch.attention import apply_rotary_pos_emb
+from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb

from cosmos_predict2._src.imaginaire.utils import log
from cosmos_predict2._src.predict2.conditioner import DataType
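Note on the import move above: with transformer-engine pinned at 1.13, `apply_rotary_pos_emb` was imported from `transformer_engine.pytorch.attention`, while the 2.8.0 pin in this PR exposes it under `transformer_engine.pytorch.attention.rope`. A minimal, hypothetical compatibility shim (not part of this diff) that tolerates either layout could look like:

```python
# Hypothetical shim, not part of this diff: try the Transformer Engine 2.x
# module path first and fall back to the 1.x location.
try:
    from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb
except ImportError:
    from transformer_engine.pytorch.attention import apply_rotary_pos_emb
```

Because pyproject.toml in this PR pins transformer-engine==2.8.0, the file imports the new path directly instead of guarding it.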
12 changes: 4 additions & 8 deletions cosmos_predict2/_src/reason1/networks/qwen2_vl.py
@@ -31,15 +31,11 @@
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available

-try:
-    from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
-
-    if is_flash_attn_available():
-        from transformers.modeling_flash_attention_utils import _flash_attention_forward, flash_attn_varlen_func
-except ImportError:
-    print("Transformer version too old, flash_attn_supports_top_left_mask is not available.")
-    is_flash_attn_available = False
+if is_flash_attn_available():
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward
+    from transformers.modeling_flash_attention_utils import _flash_varlen_fn as flash_attn_varlen_func
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import PreTrainedModel
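The hunk above imports the varlen kernel under the name current transformers exposes (`_flash_varlen_fn`, aliased back to `flash_attn_varlen_func`), whereas the removed try/except imported `flash_attn_varlen_func` directly from older releases. A hypothetical sketch (not this PR's approach) that resolves whichever name the installed transformers provides:

```python
# Hypothetical fallback, not part of this diff: resolve the varlen kernel under
# either the new private name (_flash_varlen_fn) or the old public one.
import transformers.modeling_flash_attention_utils as _fa_utils
from transformers.modeling_flash_attention_utils import is_flash_attn_available

flash_attn_varlen_func = None
if is_flash_attn_available():
    flash_attn_varlen_func = getattr(_fa_utils, "_flash_varlen_fn", None) or getattr(
        _fa_utils, "flash_attn_varlen_func", None
    )
```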
12 changes: 7 additions & 5 deletions pyproject.toml
@@ -54,6 +54,7 @@ dependencies = [
"loguru>=0.7.3",
"mediapy>=1.2.4",
"megatron-core>=0.12.1",
+"ml-dtypes==0.5.3",
"modelscope>=1.26.0",
"nltk>=3.9.1",
"numpy>=1.26.4",
@@ -65,7 +66,7 @@
"pillow>=11.1.0",
"protobuf>=4.25.3",
"pycocotools>=2.0.10",
-"pynvml>=12.0.0",
+"pytest>=8.4.2",
"pyyaml>=6.0.2",
"qwen-vl-utils[decord]>=0.0.11",
"retinaface-py>=0.0.2",
@@ -78,18 +79,19 @@
"transformers>=4.51.3",
"triton>=3.2.0",
"tyro>=0.9.32",
+"wandb>=0.22.2",
"webdataset>=0.2.111",
]

[project.optional-dependencies]
# This must mirror dependency-groups
cu128_torch271 = [
"apex==0.1.0+cu128.torch271",
-"flash-attn==2.8.3+cu12torch2.7cxx11abiTRUE",
+"flash-attn==2.8.1+cu12torch2.7cxx11abiTRUE",
"natten==0.21.0+cu128.torch271",
"torch==2.7.1+cu128",
"torchvision==0.22.1+cu128",
-"transformer-engine==1.13+cu128.torch271",
+"transformer-engine==2.8.0+cu128.torch271",
# Torch dependencies
# Dependencies determined from `uv pip install "torch==2.7.1+cu128" "torchvision==0.22.1+cu128" --index-url https://download.pytorch.org/whl`
# Issue: https://github.com/astral-sh/uv/issues/14237
@@ -124,7 +126,7 @@ cu128_torch271 = [
"natten==0.21.0+cu128.torch271",
"torch==2.7.1+cu128",
"torchvision==0.22.1+cu128",
-"transformer-engine==1.13+cu128.torch271",
+"transformer-engine==2.8.0+cu128.torch271",
# Torch dependencies
# Dependencies determined from `uv pip install "torch==2.7.1+cu128" "torchvision==0.22.1+cu128" --index-url https://download.pytorch.org/whl`
# Issue: https://github.com/astral-sh/uv/issues/14237
@@ -171,7 +173,7 @@ no-build-package = [

[tool.uv.sources]
apex = [{ index = "cosmos" }]
-flash-attn = [{ url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl"}]
+flash-attn = [{ url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.1/flash_attn-2.8.1+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl"}]
natten = [{ index = "cosmos" }]
transformer-engine = [{ index = "cosmos" }]
torch = [{ index = "cosmos" }]