From 5fc7e452a10b05f4a85f40f9758dafdef2ab1377 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 27 Apr 2026 03:52:47 +0000 Subject: [PATCH 1/7] Multi Image GRPO --- unsloth/models/rl_replacements.py | 67 ++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 20 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 4d36af62c..39a3ccaa2 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -745,6 +745,7 @@ def _get_per_token_logps_and_entropies( kwargs.get("pixel_attention_mask", None), kwargs.get("image_sizes", None), ) + num_images = kwargs.get("num_images", None) # Transformers 5.x needs token_type_ids/mm_token_type_ids for some vision models token_type_ids = kwargs.get("token_type_ids", None) mm_token_type_ids = kwargs.get("mm_token_type_ids", None) @@ -795,13 +796,17 @@ def _get_per_token_logps_and_entropies( else: max_left_pad = 0 - # input_ids_chunks = torch.chunk(input_ids, chunks = B, dim = 0) - attention_mask_chunks = torch.chunk(attention_mask, chunks = B, dim = 0) + def slice_sample_axis(value, start, end): + if value is None: + return None + return value[start:end] - def chunk_optional(tensor, chunks): - if tensor is None: - return [None] * chunks - return torch.chunk(tensor, chunks = chunks, dim = 0) + def to_num_images_list(value): + if value is None: + return None + if isinstance(value, torch.Tensor): + value = value.detach().cpu().reshape(-1).tolist() + return [int(x) for x in value] import math @@ -813,17 +818,34 @@ def chunk_optional(tensor, chunks): pixel_values_chunks = [] image_grid_thw_chunks = [] pixel_attention_mask_chunks = [] + image_sizes_chunks = [] + token_type_ids_chunks = [] + mm_token_type_ids_chunks = [] current_pixel_idx = 0 + current_image_idx = 0 + num_images_list = to_num_images_list(num_images) # TRL 0.23.0 batching logic for start in range(0, total_samples, batch_size): - end = start + batch_size + end = min(start + batch_size, total_samples) input_ids_chunks.append(input_ids[start:end]) attention_mask_chunks.append(attention_mask[start:end]) + image_sizes_chunks.append(slice_sample_axis(image_sizes, start, end)) + token_type_ids_chunks.append(slice_sample_axis(token_type_ids, start, end)) + mm_token_type_ids_chunks.append( + slice_sample_axis(mm_token_type_ids, start, end) + ) if image_grid_thw is not None and pixel_values is not None: - grid_slice = image_grid_thw[start:end] + if num_images_list is None: + grid_slice = image_grid_thw[start:end] + else: + image_count = sum(num_images_list[start:end]) + image_start = current_image_idx + image_end = current_image_idx + image_count + grid_slice = image_grid_thw[image_start:image_end] + current_image_idx = image_end image_grid_thw_chunks.append(grid_slice) batch_pixel_count = grid_slice.prod(dim = -1).sum().item() @@ -836,9 +858,14 @@ def chunk_optional(tensor, chunks): ) if pixel_attention_mask is not None: - pixel_attention_mask_chunks.append( - pixel_attention_mask[start_pixel_idx:end_pixel_idx] - ) + if pixel_attention_mask.shape[0] == pixel_values.shape[0]: + pixel_attention_mask_chunks.append( + pixel_attention_mask[start_pixel_idx:end_pixel_idx] + ) + else: + pixel_attention_mask_chunks.append( + slice_sample_axis(pixel_attention_mask, start, end) + ) else: pixel_attention_mask_chunks.append(None) @@ -849,11 +876,6 @@ def chunk_optional(tensor, chunks): image_grid_thw_chunks.append(None) pixel_attention_mask_chunks.append(None) - if image_sizes is not None and not isinstance(image_sizes, torch.Tensor): - image_sizes_chunks = [[size] for size in image_sizes] - else: - image_sizes_chunks = chunk_optional(image_sizes, B) - temperature = self.temperature logit_softcapping = _unsloth_get_final_logit_softcapping(model.config) logit_scale_multiply = getattr(model.config, "logit_scale", 0) @@ -863,10 +885,6 @@ def chunk_optional(tensor, chunks): if logit_scale_divide is None: logit_scale_divide = 0 - # Transformers 5.x needs token_type_ids/mm_token_type_ids for some vision models - token_type_ids_chunks = chunk_optional(token_type_ids, B) - mm_token_type_ids_chunks = chunk_optional(mm_token_type_ids, B) - zipped_inputs = zip( input_ids_chunks, attention_mask_chunks, @@ -1069,6 +1087,7 @@ def compute_loss( inputs.get("pixel_attention_mask", None), inputs.get("image_sizes", None), ) + num_images = inputs.get("num_images", None) # Transformers 5.x needs token_type_ids/mm_token_type_ids for some vision models token_type_ids = inputs.get("token_type_ids", None) mm_token_type_ids = inputs.get("mm_token_type_ids", None) @@ -1191,6 +1210,9 @@ def compute_loss( input_ids = _input_ids, pixel_values = pixel_values, image_grid_thw = image_grid_thw, + pixel_attention_mask = pixel_attention_mask, + image_sizes = image_sizes, + num_images = num_images, logits_to_keep = logits_to_keep, completion_mask = completion_mask, advantages = advantages, @@ -1222,6 +1244,11 @@ def compute_loss( grpo_accumulated_loss( trainer = self, input_ids = _input_ids, + pixel_values = pixel_values, + image_grid_thw = image_grid_thw, + pixel_attention_mask = pixel_attention_mask, + image_sizes = image_sizes, + num_images = num_images, logits_to_keep = logits_to_keep, completion_mask = completion_mask, advantages = advantages, From 730d3b927d34b948fd56efb12600c66fa308d25f Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 27 Apr 2026 04:21:01 +0000 Subject: [PATCH 2/7] try matching trl semantics --- unsloth/models/rl_replacements.py | 58 +++++++++++++++---------------- 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 39a3ccaa2..9f654652d 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -801,17 +801,26 @@ def slice_sample_axis(value, start, end): return None return value[start:end] - def to_num_images_list(value): - if value is None: - return None - if isinstance(value, torch.Tensor): - value = value.detach().cpu().reshape(-1).tolist() - return [int(x) for x in value] - import math total_samples = input_ids.shape[0] batch_size = math.ceil(total_samples / B) + if isinstance(num_images, torch.Tensor): + num_images = num_images.detach().cpu().reshape(-1).tolist() + if image_grid_thw is not None and pixel_values is not None and num_images is not None: + rows_per_image = image_grid_thw.prod(dim = -1) + rows_per_sample = torch.split(rows_per_image, num_images) + rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) + cum_rows = torch.cat( + [ + torch.tensor([0], device = rows_per_sample.device), + rows_per_sample.cumsum(0), + ] + ) + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + else: + cum_rows = None + cum_imgs = None input_ids_chunks = [] attention_mask_chunks = [] @@ -823,8 +832,6 @@ def to_num_images_list(value): mm_token_type_ids_chunks = [] current_pixel_idx = 0 - current_image_idx = 0 - num_images_list = to_num_images_list(num_images) # TRL 0.23.0 batching logic for start in range(0, total_samples, batch_size): end = min(start + batch_size, total_samples) @@ -838,39 +845,30 @@ def to_num_images_list(value): ) if image_grid_thw is not None and pixel_values is not None: - if num_images_list is None: + if num_images is None: grid_slice = image_grid_thw[start:end] + batch_pixel_count = grid_slice.prod(dim = -1).sum().item() + start_pixel_idx = current_pixel_idx + end_pixel_idx = current_pixel_idx + batch_pixel_count + current_pixel_idx = end_pixel_idx else: - image_count = sum(num_images_list[start:end]) - image_start = current_image_idx - image_end = current_image_idx + image_count - grid_slice = image_grid_thw[image_start:image_end] - current_image_idx = image_end + start_pixel_idx = cum_rows[start].item() + end_pixel_idx = cum_rows[end].item() + img_start, img_end = cum_imgs[start], cum_imgs[end] + grid_slice = image_grid_thw[img_start:img_end] image_grid_thw_chunks.append(grid_slice) - batch_pixel_count = grid_slice.prod(dim = -1).sum().item() - - start_pixel_idx = current_pixel_idx - end_pixel_idx = current_pixel_idx + batch_pixel_count - pixel_values_chunks.append( pixel_values[start_pixel_idx:end_pixel_idx] ) if pixel_attention_mask is not None: - if pixel_attention_mask.shape[0] == pixel_values.shape[0]: - pixel_attention_mask_chunks.append( - pixel_attention_mask[start_pixel_idx:end_pixel_idx] - ) - else: - pixel_attention_mask_chunks.append( - slice_sample_axis(pixel_attention_mask, start, end) - ) + pixel_attention_mask_chunks.append( + pixel_attention_mask[start:end] + ) else: pixel_attention_mask_chunks.append(None) - current_pixel_idx = end_pixel_idx - else: pixel_values_chunks.append(None) image_grid_thw_chunks.append(None) From 9c9b945c4e773c581f081107096aac23dcf45532 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Mon, 27 Apr 2026 12:46:18 +0000 Subject: [PATCH 3/7] attn mask for multi image grpo --- unsloth/models/rl_replacements.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 9f654652d..cfd1b4e50 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -863,9 +863,14 @@ def slice_sample_axis(value, start, end): ) if pixel_attention_mask is not None: - pixel_attention_mask_chunks.append( - pixel_attention_mask[start:end] - ) + if pixel_attention_mask.shape[0] == pixel_values.shape[0]: + pixel_attention_mask_chunks.append( + pixel_attention_mask[start_pixel_idx:end_pixel_idx] + ) + else: + pixel_attention_mask_chunks.append( + pixel_attention_mask[start:end] + ) else: pixel_attention_mask_chunks.append(None) From da9ad139e0c0fdf498c3badafba1c78483308bd5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 Apr 2026 17:14:13 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- unsloth/models/rl_replacements.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index cfd1b4e50..d409cde24 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -807,7 +807,11 @@ def slice_sample_axis(value, start, end): batch_size = math.ceil(total_samples / B) if isinstance(num_images, torch.Tensor): num_images = num_images.detach().cpu().reshape(-1).tolist() - if image_grid_thw is not None and pixel_values is not None and num_images is not None: + if ( + image_grid_thw is not None + and pixel_values is not None + and num_images is not None + ): rows_per_image = image_grid_thw.prod(dim = -1) rows_per_sample = torch.split(rows_per_image, num_images) rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) @@ -839,7 +843,9 @@ def slice_sample_axis(value, start, end): input_ids_chunks.append(input_ids[start:end]) attention_mask_chunks.append(attention_mask[start:end]) image_sizes_chunks.append(slice_sample_axis(image_sizes, start, end)) - token_type_ids_chunks.append(slice_sample_axis(token_type_ids, start, end)) + token_type_ids_chunks.append( + slice_sample_axis(token_type_ids, start, end) + ) mm_token_type_ids_chunks.append( slice_sample_axis(mm_token_type_ids, start, end) ) From 256ff74b126d1ad3d6513d57dcf8264d72eaf4b3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 May 2026 12:56:11 +0000 Subject: [PATCH 5/7] Fix multi-image GRPO chunking and zoo guard in rl_replacements image_sizes is now sliced on the image axis (img_start:img_end) when the processor emits one row per image and num_images is provided; sample-axis slicing is kept as the fallback. This restores correct per-batch image_sizes alignment for multi-image VLM processors. pixel_attention_mask now uses a three-way layout check: image-axis when shape[0] matches image_grid_thw rows, pixel-row when shape[0] matches pixel_values rows and is distinct from total_samples, otherwise sample-axis. Prevents misalignment with image-axis grid slicing for per-image masks and ambiguity when single-image-per-sample shapes coincide. cum_imgs slice indices materialize via .item to match the existing cum_rows pattern in the same loop and avoid 0-dim tensors flowing into a CUDA-tensor slice. cum_rows is materialized on CPU once after construction; the per-chunk loop uses .item on it, so keeping it on device caused a GPU->CPU sync per iteration. Add a one-time fail-loud guard in compute_loss when num_images is provided but the resolved grpo_accumulated_loss source has no num_images handling, pointing users at the corresponding unsloth_zoo upgrade. The active GRPO path goes through grpo_accumulated_loss (the local _get_per_token_logps and _get_per_token_logps_and_entropies return None on the efficient path), so without this guard a stale unsloth_zoo silently mis-slices multi-image batches. --- unsloth/models/rl_replacements.py | 81 ++++++++++++++++++++++++++----- 1 file changed, 69 insertions(+), 12 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 0611b6336..73ce06dd2 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -1119,17 +1119,32 @@ def slice_sample_axis(value, start, end): rows_per_image = image_grid_thw.prod(dim = -1) rows_per_sample = torch.split(rows_per_image, num_images) rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) + # why: cum_rows is indexed via .item() inside the per-chunk loop; + # keeping it on CPU avoids per-iteration GPU->CPU sync. cum_rows = torch.cat( [ torch.tensor([0], device = rows_per_sample.device), rows_per_sample.cumsum(0), ] - ) + ).cpu() cum_imgs = torch.tensor([0] + num_images).cumsum(0) else: cum_rows = None cum_imgs = None + def _first_dim_len(value): + if value is None: + return None + if hasattr(value, "shape"): + return value.shape[0] + try: + return len(value) + except TypeError: + return None + + total_images = sum(num_images) if num_images is not None else None + _image_sizes_n = _first_dim_len(image_sizes) + input_ids_chunks = [] attention_mask_chunks = [] pixel_values_chunks = [] @@ -1146,7 +1161,6 @@ def slice_sample_axis(value, start, end): input_ids_chunks.append(input_ids[start:end]) attention_mask_chunks.append(attention_mask[start:end]) - image_sizes_chunks.append(slice_sample_axis(image_sizes, start, end)) token_type_ids_chunks.append( slice_sample_axis(token_type_ids, start, end) ) @@ -1161,10 +1175,12 @@ def slice_sample_axis(value, start, end): start_pixel_idx = current_pixel_idx end_pixel_idx = current_pixel_idx + batch_pixel_count current_pixel_idx = end_pixel_idx + img_start = img_end = None else: start_pixel_idx = cum_rows[start].item() end_pixel_idx = cum_rows[end].item() - img_start, img_end = cum_imgs[start], cum_imgs[end] + img_start = cum_imgs[start].item() + img_end = cum_imgs[end].item() grid_slice = image_grid_thw[img_start:img_end] image_grid_thw_chunks.append(grid_slice) @@ -1172,22 +1188,49 @@ def slice_sample_axis(value, start, end): pixel_values[start_pixel_idx:end_pixel_idx] ) - if pixel_attention_mask is not None: - if pixel_attention_mask.shape[0] == pixel_values.shape[0]: - pixel_attention_mask_chunks.append( - pixel_attention_mask[start_pixel_idx:end_pixel_idx] - ) - else: - pixel_attention_mask_chunks.append( - pixel_attention_mask[start:end] - ) + if image_sizes is None: + image_sizes_chunks.append(None) + elif ( + num_images is not None + and _image_sizes_n == total_images + and img_start is not None + ): + image_sizes_chunks.append(image_sizes[img_start:img_end]) else: + image_sizes_chunks.append( + slice_sample_axis(image_sizes, start, end) + ) + + if pixel_attention_mask is None: pixel_attention_mask_chunks.append(None) + elif ( + num_images is not None + and img_start is not None + and pixel_attention_mask.shape[0] + == image_grid_thw.shape[0] + ): + pixel_attention_mask_chunks.append( + pixel_attention_mask[img_start:img_end] + ) + elif ( + pixel_attention_mask.shape[0] == pixel_values.shape[0] + and pixel_attention_mask.shape[0] != input_ids.shape[0] + ): + pixel_attention_mask_chunks.append( + pixel_attention_mask[start_pixel_idx:end_pixel_idx] + ) + else: + pixel_attention_mask_chunks.append( + pixel_attention_mask[start:end] + ) else: pixel_values_chunks.append(None) image_grid_thw_chunks.append(None) pixel_attention_mask_chunks.append(None) + image_sizes_chunks.append( + slice_sample_axis(image_sizes, start, end) + ) temperature = self.temperature logit_softcapping = _unsloth_get_final_logit_softcapping(model.config) @@ -1518,6 +1561,20 @@ def compute_loss( num_processes = num_processes, ) else: + if num_images is not None and not getattr( + self, "_unsloth_grpo_zoo_checked", False + ): + try: + _zoo_src = inspect.getsource(grpo_accumulated_loss) + except (TypeError, OSError): + _zoo_src = "" + if _zoo_src and "num_images" not in _zoo_src: + raise RuntimeError( + "Multi-image GRPO requires an unsloth_zoo build whose " + "grpo_accumulated_loss handles num_images. Please upgrade " + "unsloth_zoo (see https://github.com/unslothai/unsloth-zoo/pull/613)." + ) + self._unsloth_grpo_zoo_checked = True if hasattr(self.args, "loss_type"): ( loss, From d159e81c5ea618d5095f412ca0141e32f0730356 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 May 2026 13:21:59 +0000 Subject: [PATCH 6/7] Tighten multi-image GRPO zoo compatibility guard Only raise the zoo upgrade error when at least one entry in num_images is not 1. Upstream TRL emits num_images=[1,1,...] for any vision batch (one image per sample), and old unsloth_zoo builds chunk those correctly because sample-axis and image-axis slicing coincide for all-ones counts. Restricting the check to batches with a real multi-image sample stops single-image VLM GRPO from being needlessly broken on pre-companion zoo installs. Prefer inspect.signature(grpo_accumulated_loss).parameters for the num_images contract. Fall back to inspect.getsource string matching only when the signature does not declare num_images (e.g. the companion zoo wires it through **kwargs). The previous try/except (TypeError, OSError) over getsource turned the guard into a silent no-op when source files were absent; the new flow raises in that case because the signature check will not have proven support either. --- unsloth/models/rl_replacements.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 73ce06dd2..c285350f5 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -1561,14 +1561,29 @@ def compute_loss( num_processes = num_processes, ) else: - if num_images is not None and not getattr( + def _unsloth_requires_multi_image_zoo(value): + if value is None: + return False + if isinstance(value, torch.Tensor): + counts = value.detach().cpu().reshape(-1).tolist() + else: + counts = list(value) + return any(int(n) != 1 for n in counts) + + if _unsloth_requires_multi_image_zoo(num_images) and not getattr( self, "_unsloth_grpo_zoo_checked", False ): - try: - _zoo_src = inspect.getsource(grpo_accumulated_loss) - except (TypeError, OSError): - _zoo_src = "" - if _zoo_src and "num_images" not in _zoo_src: + _supports_num_images = ( + "num_images" + in inspect.signature(grpo_accumulated_loss).parameters + ) + if not _supports_num_images: + try: + _zoo_src = inspect.getsource(grpo_accumulated_loss) + except (TypeError, OSError): + _zoo_src = "" + _supports_num_images = "num_images" in _zoo_src + if not _supports_num_images: raise RuntimeError( "Multi-image GRPO requires an unsloth_zoo build whose " "grpo_accumulated_loss handles num_images. Please upgrade " From 822bff90a332cd50c0b89ba1306af47735cb2459 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 May 2026 13:26:17 +0000 Subject: [PATCH 7/7] Consolidate multi-image GRPO chunking and zoo guard tests --- tests/test_multi_image_grpo_chunking.py | 188 ++++++++++++++++++++++++ 1 file changed, 188 insertions(+) create mode 100644 tests/test_multi_image_grpo_chunking.py diff --git a/tests/test_multi_image_grpo_chunking.py b/tests/test_multi_image_grpo_chunking.py new file mode 100644 index 000000000..fb86819bf --- /dev/null +++ b/tests/test_multi_image_grpo_chunking.py @@ -0,0 +1,188 @@ +"""Static + behavioral checks for the multi-image GRPO chunking and +zoo compatibility guard in unsloth/models/rl_replacements.py.""" + +from __future__ import annotations + +import math +import os +import re + +REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) +SOURCE_PATH = os.path.join(REPO_ROOT, "unsloth", "models", "rl_replacements.py") + + +def _read_source() -> str: + with open(SOURCE_PATH, "r") as fh: + return fh.read() + + +# ---------- Per-chunk slicing fixes (cum_rows, cum_imgs, axes) ---------- + + +def test_cum_rows_materialized_on_cpu(): + src = _read_source() + idx = src.find("cum_rows = torch.cat") + assert idx != -1, "cum_rows assignment must exist" + window = src[idx: idx + 400] + assert "rows_per_sample.cumsum(0)" in window + assert ").cpu()" in window, ( + "cum_rows must be moved to CPU once via .cpu() after construction" + ) + + +def test_cum_imgs_slice_indices_use_item(): + src = _read_source() + assert "cum_imgs[start].item()" in src + assert "cum_imgs[end].item()" in src + + +def test_image_sizes_image_axis_branch_present(): + src = _read_source() + assert "image_sizes[img_start:img_end]" in src + assert "_image_sizes_n" in src and "total_images" in src + + +def test_pixel_attention_mask_three_way_check_present(): + src = _read_source() + assert "pixel_attention_mask[img_start:img_end]" in src + assert "pixel_attention_mask[start_pixel_idx:end_pixel_idx]" in src + assert "pixel_attention_mask[start:end]" in src + assert "image_grid_thw.shape[0]" in src + + +def test_image_sizes_chunked_after_branch_decision(): + src = _read_source() + pattern = re.compile( + r"attention_mask_chunks\.append\(attention_mask\[start:end\]\)\s*\n\s*" + r"image_sizes_chunks\.append\(slice_sample_axis\(image_sizes,\s*start,\s*end\)\)", + ) + assert pattern.search(src) is None, ( + "image_sizes_chunks must not be appended unconditionally on the " + "sample axis above the if/else; the axis is chosen per branch" + ) + + +# ---------- Behavioral simulation of chunk math ---------- + + +def _simulate_chunk_indices(num_images, B): + total_samples = len(num_images) + batch_size = max(1, math.ceil(total_samples / B)) + cum_imgs = [0] + for n in num_images: + cum_imgs.append(cum_imgs[-1] + n) + chunks = [] + for start in range(0, total_samples, batch_size): + end = min(start + batch_size, total_samples) + chunks.append((start, end, cum_imgs[start], cum_imgs[end])) + return chunks + + +def test_simulate_multi_image_chunk_image_axis_correct(): + chunks = _simulate_chunk_indices([2, 1, 3, 1], B=2) + assert chunks == [(0, 2, 0, 3), (2, 4, 3, 7)] + + +def test_simulate_uniform_image_chunking_unchanged(): + chunks = _simulate_chunk_indices([1, 1, 1, 1], B=2) + assert chunks == [(0, 2, 0, 2), (2, 4, 2, 4)] + + +def test_simulate_pixel_attention_mask_axis_decision(): + def select_axis(pam_shape0, pixel_values_shape0, image_grid_thw_shape0, + input_ids_shape0, num_images_provided): + if num_images_provided and pam_shape0 == image_grid_thw_shape0: + return "image" + if pam_shape0 == pixel_values_shape0 and pam_shape0 != input_ids_shape0: + return "pixel" + return "sample" + + assert select_axis(3, 9, 3, 2, True) == "image" + assert select_axis(9, 9, 3, 2, True) == "pixel" + assert select_axis(4, 4, 4, 4, False) == "sample" + assert select_axis(2, 2, 2, 2, False) == "sample" + + +# ---------- Zoo compatibility guard ---------- + + +def test_zoo_guard_branch_present(): + src = _read_source() + assert "_unsloth_grpo_zoo_checked" in src + assert "raise RuntimeError" in src + assert "https://github.com/unslothai/unsloth-zoo/pull/613" in src + assert "Multi-image GRPO" in src + + +def test_guard_helper_skips_all_ones_num_images(): + src = _read_source() + helper_match = re.search( + r"def _unsloth_requires_multi_image_zoo\(value\):.*?return any\(int\(n\) != 1 for n in counts\)", + src, + re.DOTALL, + ) + assert helper_match, "guard helper must compute any(int(n) != 1)" + namespace: dict = {} + + class _FakeTensor: + def __init__(self, values): + self._values = list(values) + + def detach(self): + return self + + def cpu(self): + return self + + def reshape(self, *_args, **_kwargs): + return self + + def tolist(self): + return list(self._values) + + namespace["torch"] = type("torch_stub", (), {"Tensor": _FakeTensor})() + exec(helper_match.group(0), namespace) + helper = namespace["_unsloth_requires_multi_image_zoo"] + + assert helper(None) is False + assert helper([1, 1, 1, 1]) is False + assert helper([2, 1]) is True + assert helper([0, 1, 1]) is True + assert helper(_FakeTensor([1, 1, 1])) is False + assert helper(_FakeTensor([2, 1])) is True + + +def test_guard_prefers_inspect_signature_over_getsource(): + src = _read_source() + helper_idx = src.find("_unsloth_requires_multi_image_zoo") + body = src[helper_idx:] + sig_call = body.find("inspect.signature(grpo_accumulated_loss).parameters") + src_call = body.find("inspect.getsource(grpo_accumulated_loss)") + assert sig_call != -1 + assert src_call != -1 + assert sig_call < src_call, ( + "signature.parameters must run before the getsource fallback" + ) + + +def test_guard_only_raises_when_both_checks_fail(): + src = _read_source() + pattern = re.compile( + r"_supports_num_images\s*=\s*\(\s*\"num_images\"\s*\n?\s*in\s+inspect\.signature.*?" + r"if not _supports_num_images:.*?_supports_num_images\s*=\s*\"num_images\" in _zoo_src.*?" + r"if not _supports_num_images:\s*\n\s*raise RuntimeError", + re.DOTALL, + ) + assert pattern.search(src), ( + "guard flow must be: signature check, source fallback, then raise" + ) + + +def test_guard_introspection_failure_does_not_silent_no_op(): + src = _read_source() + assert "(TypeError, OSError)" in src, ( + "guard must catch inspect.getsource failures explicitly" + ) + assert re.search(r"_zoo_src\s*=\s*['\"]{2}", src), ( + "introspection failure path must default _zoo_src to empty string" + )