Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 45 additions & 13 deletions benchmark/mmmu/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def set_seed(seed_value):


def prepare_samples(eval_args: EvalArgs):
print("preparing samples...")
print("Preparing samples...")
# Build prompts
set_seed(eval_args.seed)

Expand All @@ -105,15 +105,40 @@ def prepare_samples(eval_args: EvalArgs):
assert len(value) == 1, "key {} has more than one value".format(key)
eval_args.config[key] = value[0]

# run for each subject
# run for each subject in parallel
sub_dataset_list = []
subjects = list(CAT_SHORT2LONG.values()) # Get a fixed list of subjects

for subject in tqdm(CAT_SHORT2LONG.values()):
sub_dataset = load_dataset(
eval_args.dataset_path, subject, split=eval_args.split
)
sub_dataset_list.append(sub_dataset)
# break
print(f"Loading datasets for {len(subjects)} subjects...")
with ThreadPoolExecutor() as executor:
# Submit all load_dataset tasks
future_to_subject = {
executor.submit(
load_dataset, eval_args.dataset_path, subject, split=eval_args.split
): subject
for subject in subjects
}

# Collect results as they complete
results = {}
for future in tqdm(
as_completed(future_to_subject),
total=len(subjects),
desc="Loading datasets",
):
subject = future_to_subject[future]
try:
results[subject] = future.result()
except Exception as exc:
print(f"{subject} generated an exception: {exc}")

# Ensure datasets are added in the original order for consistency
for subject in subjects:
if subject in results:
sub_dataset_list.append(results[subject])
else:
# Handle cases where a dataset failed to load (optional, depends on desired behavior)
print(f"Warning: Dataset for subject '{subject}' could not be loaded.")

# merge all dataset
dataset = concatenate_datasets(sub_dataset_list)
Expand All @@ -133,28 +158,35 @@ def process_sample(i, sample):
width, height = image.size
if width * height >= eval_args.image_pixels_limit:
return None, True
image_path = f"{images_path}/image_{i}.png"
# Use a unique identifier for the image path to avoid potential collisions if indices reset
image_path = f"{images_path}/image_{sample['id']}.png"
if not os.path.exists(image_path):
image.save(image_path)
sample["image_path"] = image_path
return sample, False

print("Processing samples...")
with ThreadPoolExecutor() as executor:
# Pass the sample itself to process_sample, index is less reliable now
futures = [
executor.submit(process_sample, i, sample)
executor.submit(
process_sample, i, sample
) # Keep index i for tqdm maybe? Or remove it. Let's keep it for now.
for i, sample in enumerate(dataset)
]
for future in tqdm(as_completed(futures), total=len(futures)):
for future in tqdm(
as_completed(futures), total=len(dataset), desc="Processing samples"
):
sample, skipped = future.result()
if skipped:
skip_count += 1
elif sample:
samples.append(sample)

print(
f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
f"Skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
)
print("samples have been prepared")
print("Samples have been prepared")
return samples


Expand Down
13 changes: 6 additions & 7 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,14 @@ def __init__(
)

if enable_multimodal is None:
if self.hf_config.architectures[0] == "Llama4ForConditionalGeneration":
mm_disabled_models = [
"Gemma3ForConditionalGeneration",
"Llama4ForConditionalGeneration",
]
if self.hf_config.architectures[0] in mm_disabled_models:
enable_multimodal = False
logger.info(
"Multimodal is disabled for Llama4. To enable it, set --enable-llama4-multimodal."
)
elif self.hf_config.architectures[0] == "Gemma3ForConditionalGeneration":
enable_multimodal = False
logger.info(
"Multimodal is disabled for Gemma3. To enable it, set --enable-gemma3-multimodal."
f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
)
else:
enable_multimodal = True
Expand Down
264 changes: 150 additions & 114 deletions python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,127 +877,163 @@ def forward(
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key

# Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
@staticmethod
def get_input_positions(
input_tokens: List[int],
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
def get_rope_index(
spatial_merge_size: int,
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
context_len: int = 0,
seq_len: Optional[int] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
model_type: str,
tokens_per_second: Optional[int] = None,
) -> Tuple[List[List[int]], int]:
"""
Get mrope input positions and delta value.

:arg
second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.

"""

if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
video_grid_thw = video_grid_thw.tolist()

input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id
).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []

st = 0
remain_images, remain_videos = image_nums, video_nums

image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
second_per_grid_t = 0
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
if second_per_grid_ts is not None:
second_per_grid_t = second_per_grid_ts[video_index]
else:
second_per_grid_t = 1.0
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
text_len = ed - st

st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)

t_index = (
torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
* second_per_grid_t
* tokens_per_second
).flatten()

h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
mrope_position_deltas = []
if input_ids is not None and (
image_grid_thw is not None or video_grid_thw is not None
):
total_input_ids = input_ids
position_ids = torch.ones(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=input_ids.dtype,
device=input_ids.device,
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w

if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
image_index, video_index = 0, 0
for i, input_ids in enumerate(total_input_ids):
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(
input_ids == vision_start_token_id
).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
second_per_grid_t = 0
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
if second_per_grid_ts is not None:
second_per_grid_t = second_per_grid_ts[video_index]
else:
second_per_grid_t = 1.0
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
text_len = ed - st

st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)

if model_type == "qwen2_5_vl":
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
expanded_range = range_tensor.expand(
-1, llm_grid_h * llm_grid_w
)

time_tensor = (
expanded_range * second_per_grid_t * tokens_per_second
)

time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten()
elif model_type == "qwen2_vl":
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
else:
raise RuntimeError("Unimplemented")
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w

if st < len(input_tokens):
st_idx = (
llm_pos_ids_list[-1].max() + 1
if len(llm_pos_ids_list) > 0
else 0
)
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)

llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, :] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(
llm_positions.max() + 1 - len(total_input_ids[i])
)
mrope_position_deltas = torch.tensor(
mrope_position_deltas, device=input_ids.device
).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
s = input_ids.shape[1]
position_ids = torch.arange(s)
position_ids = (
position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
)

llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]

return llm_positions.tolist(), mrope_position_delta
max_position_ids = position_ids.max(0, keepdim=False)[0].max(
-1, keepdim=True
)[0]
mrope_position_deltas = max_position_ids + 1 - s
return position_ids, mrope_position_deltas

@staticmethod
def get_next_input_positions(
Expand Down
Loading
Loading