Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 64 additions & 28 deletions unsloth/models/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,7 @@ def _get_per_token_logps_and_entropies(
kwargs.get("pixel_attention_mask", None),
kwargs.get("image_sizes", None),
)
num_images = kwargs.get("num_images", None)
# Transformers 5.x needs token_type_ids/mm_token_type_ids for some vision models
token_type_ids = kwargs.get("token_type_ids", None)
mm_token_type_ids = kwargs.get("mm_token_type_ids", None)
Expand Down Expand Up @@ -1099,65 +1100,95 @@ def _get_per_token_logps_and_entropies(
else:
max_left_pad = 0

# input_ids_chunks = torch.chunk(input_ids, chunks = B, dim = 0)
attention_mask_chunks = torch.chunk(attention_mask, chunks = B, dim = 0)

def chunk_optional(tensor, chunks):
if tensor is None:
return [None] * chunks
return torch.chunk(tensor, chunks = chunks, dim = 0)
def slice_sample_axis(value, start, end):
if value is None:
return None
return value[start:end]

import math

total_samples = input_ids.shape[0]
batch_size = math.ceil(total_samples / B)
if isinstance(num_images, torch.Tensor):
num_images = num_images.detach().cpu().reshape(-1).tolist()
if (
image_grid_thw is not None
and pixel_values is not None
and num_images is not None
):
rows_per_image = image_grid_thw.prod(dim = -1)
rows_per_sample = torch.split(rows_per_image, num_images)
rows_per_sample = torch.stack([s.sum() for s in rows_per_sample])
cum_rows = torch.cat(
[
torch.tensor([0], device = rows_per_sample.device),
rows_per_sample.cumsum(0),
]
)
cum_imgs = torch.tensor([0] + num_images).cumsum(0)
else:
cum_rows = None
cum_imgs = None

input_ids_chunks = []
attention_mask_chunks = []
pixel_values_chunks = []
image_grid_thw_chunks = []
pixel_attention_mask_chunks = []
image_sizes_chunks = []
token_type_ids_chunks = []
mm_token_type_ids_chunks = []

current_pixel_idx = 0
# TRL 0.23.0 batching logic
for start in range(0, total_samples, batch_size):
end = start + batch_size
end = min(start + batch_size, total_samples)

input_ids_chunks.append(input_ids[start:end])
attention_mask_chunks.append(attention_mask[start:end])
image_sizes_chunks.append(slice_sample_axis(image_sizes, start, end))
token_type_ids_chunks.append(
slice_sample_axis(token_type_ids, start, end)
)
mm_token_type_ids_chunks.append(
slice_sample_axis(mm_token_type_ids, start, end)
)

if image_grid_thw is not None and pixel_values is not None:
grid_slice = image_grid_thw[start:end]
if num_images is None:
grid_slice = image_grid_thw[start:end]
batch_pixel_count = grid_slice.prod(dim = -1).sum().item()
start_pixel_idx = current_pixel_idx
end_pixel_idx = current_pixel_idx + batch_pixel_count
current_pixel_idx = end_pixel_idx
else:
start_pixel_idx = cum_rows[start].item()
end_pixel_idx = cum_rows[end].item()
img_start, img_end = cum_imgs[start], cum_imgs[end]
grid_slice = image_grid_thw[img_start:img_end]
image_grid_thw_chunks.append(grid_slice)

batch_pixel_count = grid_slice.prod(dim = -1).sum().item()

start_pixel_idx = current_pixel_idx
end_pixel_idx = current_pixel_idx + batch_pixel_count

pixel_values_chunks.append(
pixel_values[start_pixel_idx:end_pixel_idx]
)

if pixel_attention_mask is not None:
pixel_attention_mask_chunks.append(
pixel_attention_mask[start_pixel_idx:end_pixel_idx]
)
if pixel_attention_mask.shape[0] == pixel_values.shape[0]:
pixel_attention_mask_chunks.append(
pixel_attention_mask[start_pixel_idx:end_pixel_idx]
)
else:
pixel_attention_mask_chunks.append(
pixel_attention_mask[start:end]
)
else:
pixel_attention_mask_chunks.append(None)

current_pixel_idx = end_pixel_idx

else:
pixel_values_chunks.append(None)
image_grid_thw_chunks.append(None)
pixel_attention_mask_chunks.append(None)

if image_sizes is not None and not isinstance(image_sizes, torch.Tensor):
image_sizes_chunks = [[size] for size in image_sizes]
else:
image_sizes_chunks = chunk_optional(image_sizes, B)

temperature = self.temperature
logit_softcapping = _unsloth_get_final_logit_softcapping(model.config)
logit_scale_multiply = getattr(model.config, "logit_scale", 0)
Expand All @@ -1167,10 +1198,6 @@ def chunk_optional(tensor, chunks):
if logit_scale_divide is None:
logit_scale_divide = 0

# Transformers 5.x needs token_type_ids/mm_token_type_ids for some vision models
token_type_ids_chunks = chunk_optional(token_type_ids, B)
mm_token_type_ids_chunks = chunk_optional(mm_token_type_ids, B)

zipped_inputs = zip(
input_ids_chunks,
attention_mask_chunks,
Expand Down Expand Up @@ -1375,6 +1402,7 @@ def compute_loss(
inputs.get("pixel_attention_mask", None),
inputs.get("image_sizes", None),
)
num_images = inputs.get("num_images", None)
# Transformers 5.x needs token_type_ids/mm_token_type_ids for some vision models
token_type_ids = inputs.get("token_type_ids", None)
mm_token_type_ids = inputs.get("mm_token_type_ids", None)
Expand Down Expand Up @@ -1504,6 +1532,9 @@ def compute_loss(
input_ids = _input_ids,
pixel_values = pixel_values,
image_grid_thw = image_grid_thw,
pixel_attention_mask = pixel_attention_mask,
image_sizes = image_sizes,
num_images = num_images,
logits_to_keep = logits_to_keep,
completion_mask = completion_mask,
advantages = advantages,
Expand Down Expand Up @@ -1535,6 +1566,11 @@ def compute_loss(
grpo_accumulated_loss(
trainer = self,
input_ids = _input_ids,
pixel_values = pixel_values,
image_grid_thw = image_grid_thw,
pixel_attention_mask = pixel_attention_mask,
image_sizes = image_sizes,
num_images = num_images,
logits_to_keep = logits_to_keep,
completion_mask = completion_mask,
advantages = advantages,
Expand Down
Loading