diff --git a/tests/test_multi_image_grpo_chunking.py b/tests/test_multi_image_grpo_chunking.py new file mode 100644 index 0000000000..0bd876b664 --- /dev/null +++ b/tests/test_multi_image_grpo_chunking.py @@ -0,0 +1,193 @@ +"""Static + behavioral checks for the multi-image GRPO chunking and +zoo compatibility guard in unsloth/models/rl_replacements.py.""" + +from __future__ import annotations + +import math +import os +import re + +REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) +SOURCE_PATH = os.path.join(REPO_ROOT, "unsloth", "models", "rl_replacements.py") + + +def _read_source() -> str: + with open(SOURCE_PATH, "r") as fh: + return fh.read() + + +# ---------- Per-chunk slicing fixes (cum_rows, cum_imgs, axes) ---------- + + +def test_cum_rows_materialized_on_cpu(): + src = _read_source() + idx = src.find("cum_rows = torch.cat") + assert idx != -1, "cum_rows assignment must exist" + window = src[idx : idx + 400] + assert "rows_per_sample.cumsum(0)" in window + assert ( + ").cpu()" in window + ), "cum_rows must be moved to CPU once via .cpu() after construction" + + +def test_cum_imgs_slice_indices_use_item(): + src = _read_source() + assert "cum_imgs[start].item()" in src + assert "cum_imgs[end].item()" in src + + +def test_image_sizes_image_axis_branch_present(): + src = _read_source() + assert "image_sizes[img_start:img_end]" in src + assert "_image_sizes_n" in src and "total_images" in src + + +def test_pixel_attention_mask_three_way_check_present(): + src = _read_source() + assert "pixel_attention_mask[img_start:img_end]" in src + assert "pixel_attention_mask[start_pixel_idx:end_pixel_idx]" in src + assert "pixel_attention_mask[start:end]" in src + assert "image_grid_thw.shape[0]" in src + + +def test_image_sizes_chunked_after_branch_decision(): + src = _read_source() + pattern = re.compile( + r"attention_mask_chunks\.append\(attention_mask\[start:end\]\)\s*\n\s*" + r"image_sizes_chunks\.append\(slice_sample_axis\(image_sizes,\s*start,\s*end\)\)", + ) + assert pattern.search(src) is None, ( + "image_sizes_chunks must not be appended unconditionally on the " + "sample axis above the if/else; the axis is chosen per branch" + ) + + +# ---------- Behavioral simulation of chunk math ---------- + + +def _simulate_chunk_indices(num_images, B): + total_samples = len(num_images) + batch_size = max(1, math.ceil(total_samples / B)) + cum_imgs = [0] + for n in num_images: + cum_imgs.append(cum_imgs[-1] + n) + chunks = [] + for start in range(0, total_samples, batch_size): + end = min(start + batch_size, total_samples) + chunks.append((start, end, cum_imgs[start], cum_imgs[end])) + return chunks + + +def test_simulate_multi_image_chunk_image_axis_correct(): + chunks = _simulate_chunk_indices([2, 1, 3, 1], B = 2) + assert chunks == [(0, 2, 0, 3), (2, 4, 3, 7)] + + +def test_simulate_uniform_image_chunking_unchanged(): + chunks = _simulate_chunk_indices([1, 1, 1, 1], B = 2) + assert chunks == [(0, 2, 0, 2), (2, 4, 2, 4)] + + +def test_simulate_pixel_attention_mask_axis_decision(): + def select_axis( + pam_shape0, + pixel_values_shape0, + image_grid_thw_shape0, + input_ids_shape0, + num_images_provided, + ): + if num_images_provided and pam_shape0 == image_grid_thw_shape0: + return "image" + if pam_shape0 == pixel_values_shape0 and pam_shape0 != input_ids_shape0: + return "pixel" + return "sample" + + assert select_axis(3, 9, 3, 2, True) == "image" + assert select_axis(9, 9, 3, 2, True) == "pixel" + assert select_axis(4, 4, 4, 4, False) == "sample" + assert select_axis(2, 2, 2, 2, False) == "sample" + + +# ---------- Zoo compatibility guard ---------- + + +def test_zoo_guard_branch_present(): + src = _read_source() + assert "_unsloth_grpo_zoo_checked" in src + assert "raise RuntimeError" in src + assert "https://github.com/unslothai/unsloth-zoo/pull/613" in src + assert "Multi-image GRPO" in src + + +def test_guard_helper_skips_all_ones_num_images(): + src = _read_source() + helper_match = re.search( + r"def _unsloth_requires_multi_image_zoo\(value\):.*?return any\(int\(n\) != 1 for n in counts\)", + src, + re.DOTALL, + ) + assert helper_match, "guard helper must compute any(int(n) != 1)" + namespace: dict = {} + + class _FakeTensor: + def __init__(self, values): + self._values = list(values) + + def detach(self): + return self + + def cpu(self): + return self + + def reshape(self, *_args, **_kwargs): + return self + + def tolist(self): + return list(self._values) + + namespace["torch"] = type("torch_stub", (), {"Tensor": _FakeTensor})() + exec(helper_match.group(0), namespace) + helper = namespace["_unsloth_requires_multi_image_zoo"] + + assert helper(None) is False + assert helper([1, 1, 1, 1]) is False + assert helper([2, 1]) is True + assert helper([0, 1, 1]) is True + assert helper(_FakeTensor([1, 1, 1])) is False + assert helper(_FakeTensor([2, 1])) is True + + +def test_guard_prefers_inspect_signature_over_getsource(): + src = _read_source() + helper_idx = src.find("_unsloth_requires_multi_image_zoo") + body = src[helper_idx:] + sig_call = body.find("inspect.signature(grpo_accumulated_loss).parameters") + src_call = body.find("inspect.getsource(grpo_accumulated_loss)") + assert sig_call != -1 + assert src_call != -1 + assert ( + sig_call < src_call + ), "signature.parameters must run before the getsource fallback" + + +def test_guard_only_raises_when_both_checks_fail(): + src = _read_source() + pattern = re.compile( + r"_supports_num_images\s*=\s*\(\s*\"num_images\"\s*\n?\s*in\s+inspect\.signature.*?" + r"if not _supports_num_images:.*?_supports_num_images\s*=\s*\"num_images\" in _zoo_src.*?" + r"if not _supports_num_images:\s*\n\s*raise RuntimeError", + re.DOTALL, + ) + assert pattern.search( + src + ), "guard flow must be: signature check, source fallback, then raise" + + +def test_guard_introspection_failure_does_not_silent_no_op(): + src = _read_source() + assert ( + "(TypeError, OSError)" in src + ), "guard must catch inspect.getsource failures explicitly" + assert re.search( + r"_zoo_src\s*=\s*['\"]{2}", src + ), "introspection failure path must default _zoo_src to empty string" diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 0f10847282..41c0b4492b 100755 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -1045,6 +1045,7 @@ def _get_per_token_logps_and_entropies( kwargs.get("pixel_attention_mask", None), kwargs.get("image_sizes", None), ) + num_images = kwargs.get("num_images", None) # Transformers 5.x needs token_type_ids/mm_token_type_ids for some vision models token_type_ids = kwargs.get("token_type_ids", None) mm_token_type_ids = kwargs.get("mm_token_type_ids", None) @@ -1099,64 +1100,136 @@ def _get_per_token_logps_and_entropies( else: max_left_pad = 0 - # input_ids_chunks = torch.chunk(input_ids, chunks = B, dim = 0) - attention_mask_chunks = torch.chunk(attention_mask, chunks = B, dim = 0) - - def chunk_optional(tensor, chunks): - if tensor is None: - return [None] * chunks - return torch.chunk(tensor, chunks = chunks, dim = 0) + def slice_sample_axis(value, start, end): + if value is None: + return None + return value[start:end] import math total_samples = input_ids.shape[0] batch_size = math.ceil(total_samples / B) + if isinstance(num_images, torch.Tensor): + num_images = num_images.detach().cpu().reshape(-1).tolist() + if ( + image_grid_thw is not None + and pixel_values is not None + and num_images is not None + ): + rows_per_image = image_grid_thw.prod(dim = -1) + rows_per_sample = torch.split(rows_per_image, num_images) + rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) + # why: cum_rows is indexed via .item() inside the per-chunk loop; + # keeping it on CPU avoids per-iteration GPU->CPU sync. + cum_rows = torch.cat( + [ + torch.tensor([0], device = rows_per_sample.device), + rows_per_sample.cumsum(0), + ] + ).cpu() + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + else: + cum_rows = None + cum_imgs = None + + def _first_dim_len(value): + if value is None: + return None + if hasattr(value, "shape"): + return value.shape[0] + try: + return len(value) + except TypeError: + return None + + total_images = sum(num_images) if num_images is not None else None + _image_sizes_n = _first_dim_len(image_sizes) input_ids_chunks = [] attention_mask_chunks = [] pixel_values_chunks = [] image_grid_thw_chunks = [] pixel_attention_mask_chunks = [] + image_sizes_chunks = [] + token_type_ids_chunks = [] + mm_token_type_ids_chunks = [] current_pixel_idx = 0 # TRL 0.23.0 batching logic for start in range(0, total_samples, batch_size): - end = start + batch_size + end = min(start + batch_size, total_samples) input_ids_chunks.append(input_ids[start:end]) attention_mask_chunks.append(attention_mask[start:end]) + token_type_ids_chunks.append( + slice_sample_axis(token_type_ids, start, end) + ) + mm_token_type_ids_chunks.append( + slice_sample_axis(mm_token_type_ids, start, end) + ) if image_grid_thw is not None and pixel_values is not None: - grid_slice = image_grid_thw[start:end] + if num_images is None: + grid_slice = image_grid_thw[start:end] + batch_pixel_count = grid_slice.prod(dim = -1).sum().item() + start_pixel_idx = current_pixel_idx + end_pixel_idx = current_pixel_idx + batch_pixel_count + current_pixel_idx = end_pixel_idx + img_start = img_end = None + else: + start_pixel_idx = cum_rows[start].item() + end_pixel_idx = cum_rows[end].item() + img_start = cum_imgs[start].item() + img_end = cum_imgs[end].item() + grid_slice = image_grid_thw[img_start:img_end] image_grid_thw_chunks.append(grid_slice) - batch_pixel_count = grid_slice.prod(dim = -1).sum().item() - - start_pixel_idx = current_pixel_idx - end_pixel_idx = current_pixel_idx + batch_pixel_count - pixel_values_chunks.append( pixel_values[start_pixel_idx:end_pixel_idx] ) - if pixel_attention_mask is not None: + if image_sizes is None: + image_sizes_chunks.append(None) + elif ( + num_images is not None + and _image_sizes_n == total_images + and img_start is not None + ): + image_sizes_chunks.append(image_sizes[img_start:img_end]) + else: + image_sizes_chunks.append( + slice_sample_axis(image_sizes, start, end) + ) + + if pixel_attention_mask is None: + pixel_attention_mask_chunks.append(None) + elif ( + num_images is not None + and img_start is not None + and pixel_attention_mask.shape[0] == image_grid_thw.shape[0] + ): + pixel_attention_mask_chunks.append( + pixel_attention_mask[img_start:img_end] + ) + elif ( + pixel_attention_mask.shape[0] == pixel_values.shape[0] + and pixel_attention_mask.shape[0] != input_ids.shape[0] + ): pixel_attention_mask_chunks.append( pixel_attention_mask[start_pixel_idx:end_pixel_idx] ) else: - pixel_attention_mask_chunks.append(None) - - current_pixel_idx = end_pixel_idx + pixel_attention_mask_chunks.append( + pixel_attention_mask[start:end] + ) else: pixel_values_chunks.append(None) image_grid_thw_chunks.append(None) pixel_attention_mask_chunks.append(None) - - if image_sizes is not None and not isinstance(image_sizes, torch.Tensor): - image_sizes_chunks = [[size] for size in image_sizes] - else: - image_sizes_chunks = chunk_optional(image_sizes, B) + image_sizes_chunks.append( + slice_sample_axis(image_sizes, start, end) + ) temperature = self.temperature logit_softcapping = _unsloth_get_final_logit_softcapping(model.config) @@ -1167,10 +1240,6 @@ def chunk_optional(tensor, chunks): if logit_scale_divide is None: logit_scale_divide = 0 - # Transformers 5.x needs token_type_ids/mm_token_type_ids for some vision models - token_type_ids_chunks = chunk_optional(token_type_ids, B) - mm_token_type_ids_chunks = chunk_optional(mm_token_type_ids, B) - zipped_inputs = zip( input_ids_chunks, attention_mask_chunks, @@ -1375,6 +1444,7 @@ def compute_loss( inputs.get("pixel_attention_mask", None), inputs.get("image_sizes", None), ) + num_images = inputs.get("num_images", None) # Transformers 5.x needs token_type_ids/mm_token_type_ids for some vision models token_type_ids = inputs.get("token_type_ids", None) mm_token_type_ids = inputs.get("mm_token_type_ids", None) @@ -1490,6 +1560,35 @@ def compute_loss( num_processes = num_processes, ) else: + + def _unsloth_requires_multi_image_zoo(value): + if value is None: + return False + if isinstance(value, torch.Tensor): + counts = value.detach().cpu().reshape(-1).tolist() + else: + counts = list(value) + return any(int(n) != 1 for n in counts) + + if _unsloth_requires_multi_image_zoo(num_images) and not getattr( + self, "_unsloth_grpo_zoo_checked", False + ): + _supports_num_images = ( + "num_images" in inspect.signature(grpo_accumulated_loss).parameters + ) + if not _supports_num_images: + try: + _zoo_src = inspect.getsource(grpo_accumulated_loss) + except (TypeError, OSError): + _zoo_src = "" + _supports_num_images = "num_images" in _zoo_src + if not _supports_num_images: + raise RuntimeError( + "Multi-image GRPO requires an unsloth_zoo build whose " + "grpo_accumulated_loss handles num_images. Please upgrade " + "unsloth_zoo (see https://github.com/unslothai/unsloth-zoo/pull/613)." + ) + self._unsloth_grpo_zoo_checked = True if hasattr(self.args, "loss_type"): ( loss, @@ -1504,6 +1603,9 @@ def compute_loss( input_ids = _input_ids, pixel_values = pixel_values, image_grid_thw = image_grid_thw, + pixel_attention_mask = pixel_attention_mask, + image_sizes = image_sizes, + num_images = num_images, logits_to_keep = logits_to_keep, completion_mask = completion_mask, advantages = advantages, @@ -1535,6 +1637,11 @@ def compute_loss( grpo_accumulated_loss( trainer = self, input_ids = _input_ids, + pixel_values = pixel_values, + image_grid_thw = image_grid_thw, + pixel_attention_mask = pixel_attention_mask, + image_sizes = image_sizes, + num_images = num_images, logits_to_keep = logits_to_keep, completion_mask = completion_mask, advantages = advantages,