diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py
index e5e40aaf3..29826d3f7 100644
--- a/verifiers/envs/environment.py
+++ b/verifiers/envs/environment.py
@@ -26,6 +26,41 @@ def _pil_to_data_url(img: Image.Image, fmt: str | None = None) -> str:
     b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
     return f"data:image/{fmt.lower()};base64,{b64}"
 
+def format_oai_chat_msg(
+    prompts: List[List[Dict[str, Any]]],
+    images: List[List[Image.Image]]
+) -> List[List[Dict[str, Any]]]:
+    formatted_conversations: List[List[Dict[str, Any]]] = []
+
+    for conv_prompts, conv_images in zip(prompts, images):
+        img_iter = iter(conv_images)
+        new_conv: List[Dict[str, Any]] = []
+
+        for msg in conv_prompts:
+            role = msg["role"]
+            content = msg["content"]
+
+            if isinstance(content, list):
+                new_parts: List[Dict[str, Any]] = []
+                for part in content:
+                    if part.get("type") == "image":
+                        img = next(img_iter)
+                        data_url = _pil_to_data_url(img)
+                        new_parts.append({
+                            "type": "image_url",
+                            "image_url": {"url": data_url}
+                        })
+                    else:
+                        new_parts.append(part.copy())
+                new_conv.append({"role": role, "content": new_parts})
+
+            else:
+                new_conv.append({"role": role, "content": content})
+
+        formatted_conversations.append(new_conv)
+
+    return formatted_conversations
+
 class Environment(ABC):
     """
     Base class for all environments.
@@ -324,8 +359,12 @@ def generate(self,
             results = {col: deepcopy(inputs[col]) for col in inputs.column_names}
         else:
             results = deepcopy(inputs)
+        if results.get('images') is not None:
+            prompts = format_oai_chat_msg(results['prompt'], results['images'])
+        else:
+            prompts = results['prompt']
         rollouts = self.run_rollouts(
-            prompts=results['prompt'],
+            prompts=prompts,
             client=client,
             model=model,
             sampling_args=gen_sampling_args,
@@ -352,6 +391,7 @@ def generate(self,
     def process_chat_format(
         self,
         prompt: List[Dict[str, str]],
+        images: Optional[List[List[Any]]],
         completion: List[Dict[str, str]],
         processing_class: PreTrainedTokenizerBase,
         mask_env_responses: bool = False
@@ -367,59 +407,96 @@ def process_chat_format(
         Returns:
             prompt_ids, prompt_mask, completion_ids, completion_mask
         """
-        # tokenize just the prompt
-        prompt_text = processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
-        assert isinstance(prompt_text, str)
-        if hasattr(processing_class, "tokenizer"):
-            encode = processing_class.tokenizer.encode
-        else:
-            encode = processing_class.encode
-        prompt_ids = encode(prompt_text)
-        prompt_mask = [1] * len(prompt_ids)
-
-        # track completion tokens and masks by processing incrementally
         completion_ids = []
         completion_mask = []
-
-        # previous tokenization (starts with just prompt)
-        prev_ids = prompt_ids
-
-        # process each completion message incrementally
-        for i, msg in enumerate(completion):
-            # create conversation prefix: prompt + completion[:i+1]
-            conversation_prefix = prompt + completion[:i+1]
+        remaining_inputs = {}
+        if images:
+            prompt_text = processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
+            inputs = processing_class(text=prompt_text, images=images, return_tensors="pt")
+            remaining_inputs = {
+                k: v
+                for k, v in inputs.items()
+                if k not in ["input_ids", "attention_mask"]
+            }
+            prev_ids = inputs.input_ids[0].tolist()
+            prompt_ids = prev_ids
+            prompt_mask = [1] * len(prompt_ids)
+
+            for i, msg in enumerate(completion):
+                conversation_prefix = prompt + completion[:i+1]
+                prefix_text = processing_class.apply_chat_template(
+                    conversation_prefix,
+                    tokenize=False,
+                    add_generation_prompt=False,
+                )
+                current_ids = processing_class(text=prefix_text, images=images, return_tensors="pt").input_ids[0].tolist()
+                assert current_ids[:len(prev_ids)] == prev_ids, "Tokenization difference in chat format."
+                new_tokens = current_ids[len(prev_ids):]
+                completion_ids.extend(new_tokens)
+
+                if msg["role"] == "assistant":
+                    msg_mask = [1] * len(new_tokens)
+                elif msg["role"] != "assistant" and mask_env_responses:
+                    msg_mask = [0] * len(new_tokens)
+                else:
+                    msg_mask = [1] * len(new_tokens)
+
+                completion_mask.extend(msg_mask)
+                prev_ids = current_ids
+        else:
+            # tokenize just the prompt
+            prompt_text = processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
+            assert isinstance(prompt_text, str)
+            if hasattr(processing_class, "tokenizer"):
+                encode = processing_class.tokenizer.encode
+            else:
+                encode = processing_class.encode
+            prompt_ids = encode(prompt_text)
+            prompt_mask = [1] * len(prompt_ids)
-        # track completion tokens and masks by processing incrementally
-        completion_ids = []
-        completion_mask = []
+            # previous tokenization (starts with just prompt)
+            prev_ids = prompt_ids
-        # process each completion message incrementally
-        for i, msg in enumerate(completion):
-            # create conversation prefix: prompt + completion[:i+1]
-            conversation_prefix = prompt + completion[:i+1]
+
-            # tokenize the full prefix
-            prefix_text = processing_class.apply_chat_template(
-                conversation_prefix,
-                tokenize=False,
-                add_generation_prompt=False,
-            )
-            assert isinstance(prefix_text, str), f"Expected string from apply_chat_template, got {type(prefix_text)}"
-            current_ids = encode(prefix_text)
-            assert current_ids[:len(prev_ids)] == prev_ids, f"Tokenization difference in chat format. Current ids: {current_ids}, previous ids: {prev_ids}"
+            # track completion tokens and masks by processing incrementally
+            completion_ids = []
+            completion_mask = []
-            # add new tokens to completion tokens
-            new_tokens = current_ids[len(prev_ids):]
-            assert len(new_tokens) > 0, f"No new tokens in chat format. Current ids: {current_ids}, previous ids: {prev_ids}"
-            completion_ids.extend(new_tokens)
-
-            # create mask
-            if msg["role"] == "assistant":
-                msg_mask = [1] * len(new_tokens)
-            elif msg["role"] != "assistant" and mask_env_responses:
-                # mask intermediate 'user' and/or 'tool' messages
-                msg_mask = [0] * len(new_tokens)
-            else:
-                # default to not masking
-                msg_mask = [1] * len(new_tokens)
+            # previous tokenization (starts with just prompt)
+            prev_ids = prompt_ids
-            completion_mask.extend(msg_mask)
-            # Update previous tokenization for next iteration
-            prev_ids = current_ids
-        assert len(completion_ids) == len(completion_mask), f"Length mismatch in chat format. Completion ids: {completion_ids}, completion mask: {completion_mask}"
+            # process each completion message incrementally
+            for i, msg in enumerate(completion):
+                # create conversation prefix: prompt + completion[:i+1]
+                conversation_prefix = prompt + completion[:i+1]
+
+                # tokenize the full prefix
+                prefix_text = processing_class.apply_chat_template(
+                    conversation_prefix,
+                    tokenize=False,
+                    add_generation_prompt=False,
+                )
+                assert isinstance(prefix_text, str), f"Expected string from apply_chat_template, got {type(prefix_text)}"
+                current_ids = encode(prefix_text)
+                assert current_ids[:len(prev_ids)] == prev_ids, f"Tokenization difference in chat format. Current ids: {current_ids}, previous ids: {prev_ids}"
+
+                # add new tokens to completion tokens
+                new_tokens = current_ids[len(prev_ids):]
+                assert len(new_tokens) > 0, f"No new tokens in chat format. Current ids: {current_ids}, previous ids: {prev_ids}"
+                completion_ids.extend(new_tokens)
-        return prompt_ids, prompt_mask, completion_ids, completion_mask
+                # create mask
+                if msg["role"] == "assistant":
+                    msg_mask = [1] * len(new_tokens)
+                elif msg["role"] != "assistant" and mask_env_responses:
+                    # mask intermediate 'user' and/or 'tool' messages
+                    msg_mask = [0] * len(new_tokens)
+                else:
+                    # default to not masking
+                    msg_mask = [1] * len(new_tokens)
+
+                completion_mask.extend(msg_mask)
+                # Update previous tokenization for next iteration
+                prev_ids = current_ids
+            assert len(completion_ids) == len(completion_mask), f"Length mismatch in chat format. Completion ids: {completion_ids}, completion mask: {completion_mask}"
+
+        return prompt_ids, prompt_mask, completion_ids, completion_mask, remaining_inputs
 
     def process_completion_format(
         self,
@@ -455,6 +532,7 @@ def process_completion_format(
     def process_env_results(
         self,
         prompts: List[Union[str, List[Dict[str, Any]]]],
+        images: Optional[List[List[Any]]],
         completions: List[Union[str, List[Dict[str, Any]]]],
         states: List[Dict[str, Any]],
         rewards: List[float],
@@ -478,13 +556,16 @@ def process_env_results(
         all_prompt_masks = []
         all_completion_ids = []
         all_completion_masks = []
+        all_remaining_inputs = []
+
+        images = images or [None] * len(prompts)
 
-        for i, (prompt, completion, state, reward) in enumerate(zip(prompts, completions, states, rewards)):
+        for i, (prompt, images, completion, state, reward) in enumerate(zip(prompts, images, completions, states, rewards)):
             # Format-specific processing
             if is_chat_format:
                 assert isinstance(prompt, list) and isinstance(completion, list)
-                prompt_ids, prompt_mask, completion_ids, completion_mask = self.process_chat_format(
-                    prompt, completion, processing_class, mask_env_responses
+                prompt_ids, prompt_mask, completion_ids, completion_mask, remaining_inputs = self.process_chat_format(
+                    prompt, images, completion, processing_class, mask_env_responses
                 )
             else:
                 assert isinstance(prompt, str) and isinstance(completion, str)
@@ -498,6 +579,7 @@ def process_env_results(
             all_prompt_ids.append(prompt_ids)
             all_prompt_masks.append(prompt_mask)
             all_completion_ids.append(completion_ids)
             all_completion_masks.append(completion_mask)
+            all_remaining_inputs.append(remaining_inputs)
 
         return {
             "prompt_ids": all_prompt_ids,
@@ -505,6 +587,7 @@ def process_env_results(
             "completion_ids": all_completion_ids,
             "completion_mask": all_completion_masks,
             "rewards": rewards,
+            "remaining_inputs": all_remaining_inputs,
         }
 
     # Evaluation and dataset generation
diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py
index 4da54c15f..e05ec6d0a 100644
--- a/verifiers/examples/docvqa.py
+++ b/verifiers/examples/docvqa.py
@@ -3,7 +3,7 @@
 from datasets import load_dataset
 
 import verifiers as vf
-from qwen_vl_utils import smart_resize
+from qwen_vl_utils import process_vision_info
 
 """
 # install qwen stuff
@@ -15,18 +15,33 @@
 """
 
-def preprocess_docvqa(x):
-    return {
-        "question": x["question"],
-        "images": [x["image"].resize(smart_resize(768, 1024))],  # XGA
-        "answer": x["answers"][0],
-    }
+def data_collator(batch: list[dict]) -> list[dict]:
+    processed_samples = []
+    for sample in batch:
+        messages = []
+        messages.append({"role": "system", "content": system_prompt})
+        content_block = []
+        content_block.append({"type": "text", "text": sample["question"]})
+        content_block.append(
+            {
+                "type": "image",
+                "image": sample["image"],  # only one image in this ds
+                "resized_height": 768,  # XGA resolution
+                "resized_width": 1024,
+            }
+        )
+        messages.append({"role": "user", "content": content_block})
"content": content_block}) + processed_images, *_ = process_vision_info( # process with qwen utils + messages.copy() + ) + sample["prompt"] = messages + sample["images"] = processed_images + sample["answer"] = sample["answers"] + processed_samples.append(sample) + return processed_samples dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation") -dataset = dataset.map( - preprocess_docvqa, num_proc=10, remove_columns=dataset.column_names -) parser = vf.XMLParser(["think", "answer"], answer_field="answer") system_prompt = f"""Answer the questions. @@ -97,5 +112,6 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: processing_class=processor, env=vf_env, args=training_args, + data_collator=data_collator, ) trainer.train() diff --git a/verifiers/trainers/async_batch_generator.py b/verifiers/trainers/async_batch_generator.py index 28946cd5e..6e73af494 100644 --- a/verifiers/trainers/async_batch_generator.py +++ b/verifiers/trainers/async_batch_generator.py @@ -11,7 +11,7 @@ class BatchRequest: """Request for batch generation""" batch_id: int - env_inputs: Dict[str, List[Any]] + env_inputs: Dict[str, List[Any] | None] processing_class: Any mask_env_responses: bool max_completion_length: int @@ -239,6 +239,7 @@ def _generate_batch(self, request: BatchRequest) -> BatchResult: # Process results processed_results = self.env.process_env_results( env_results['prompt'], + env_results['images'], env_results['completion'], env_results['state'], env_results['reward'], diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 8d2373572..c33146547 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -68,53 +68,53 @@ def nanstd(tensor: torch.Tensor) -> torch.Tensor: variance *= count / (count - 1) # Bessel's correction return torch.sqrt(variance) -def split_tensor_dict( - tensor_dict: dict[str, Optional[torch.Tensor]], num_chunks: int -) -> list[dict[str, Optional[torch.Tensor]]]: +def shuffle_data_dict(data_dict: dict[str, Any]) -> dict[str, Any]: """ - Splits a dictionary of tensors along the first dimension into `num_chunks` equal parts. - - Example: - >>> x = torch.arange(12).reshape(6, 2) - >>> y = torch.arange(6).reshape(6, 1) - >>> tensor_dict = {"x": x, "y": y} - >>> split_tensor_dict(tensor_dict, 3) - [ - {"x": tensor([[0, 1], [2, 3]]), "y": tensor([[0], [1]])}, - {"x": tensor([[4, 5], [6, 7]]), "y": tensor([[2], [3]])}, - {"x": tensor([[ 8, 9], [10, 11]]), "y": tensor([[4], [5]])} - ] + Shuffles a dictionary of tensors or lists along the first dimension in unison. 
""" - first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None) - chunk_size = first_tensor.shape[0] // num_chunks - return [ - { - key: tensor[i * chunk_size : (i + 1) * chunk_size] if tensor is not None else None - for key, tensor in tensor_dict.items() - } - for i in range(num_chunks) - ] + first_item = next(item for item in data_dict.values() if item is not None) + batch_size = len(first_item) + permutation = torch.randperm(batch_size) + + shuffled_dict = {} + for key, value in data_dict.items(): + if value is None: + shuffled_dict[key] = None + elif isinstance(value, torch.Tensor): + shuffled_dict[key] = value[permutation] + elif isinstance(value, list): + shuffled_dict[key] = [value[i] for i in permutation] + else: + raise TypeError(f"Unsupported type for shuffling: {type(value)}") + return shuffled_dict -def shuffle_tensor_dict(tensor_dict: dict[str, Optional[torch.Tensor]]) -> dict[str, Optional[torch.Tensor]]: +def split_data_dict( + data_dict: dict[str, Any], num_chunks: int +) -> list[dict[str, Any]]: """ - Shuffles a dictionary of tensors along the first dimension in unison. - - Example: - >>> x = torch.arange(6).reshape(3, 2) - >>> y = torch.arange(3).reshape(3, 1) - >>> tensor_dict = {"x": x, "y": y} - >>> shuffle_tensor_dict(tensor_dict) - {'x': tensor([[2, 3], - [0, 1], - [4, 5]]), - 'y': tensor([[1], - [0], - [2]])} + Splits a dictionary of tensors or lists along the first dimension into `num_chunks` equal parts. """ - first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None) - batch_size = first_tensor.shape[0] - permutation = torch.randperm(batch_size) - return {key: tensor[permutation] if tensor is not None else None for key, tensor in tensor_dict.items()} + first_item = next(item for item in data_dict.values() if item is not None) + # Ensure chunk_size is an integer + chunk_size = len(first_item) // num_chunks + if len(first_item) % num_chunks != 0: + logging.warning( + f"The total number of samples ({len(first_item)}) is not divisible by the number of chunks ({num_chunks}). " + f"The last {len(first_item) % num_chunks} samples will be dropped." 
+        )
+
+    chunked_list = []
+    for i in range(num_chunks):
+        chunk = {}
+        start_idx = i * chunk_size
+        end_idx = (i + 1) * chunk_size
+        for key, value in data_dict.items():
+            if value is None:
+                chunk[key] = None
+            else:
+                chunk[key] = value[start_idx:end_idx]
+        chunked_list.append(chunk)
+    return chunked_list
 
 def nanmin(tensor: torch.Tensor) -> torch.Tensor:
     """
@@ -155,6 +155,7 @@ def __init__(
         callbacks: Optional[list[TrainerCallback]] = None,
         optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
         peft_config: Optional[PeftConfig] = None,
+        data_collator: Optional[Any] = None,
         **kwargs,
     ):
         self.logger = logging.getLogger(__name__)
@@ -245,12 +246,12 @@ def filter_by_prompt_length(example):
             self.logger.info(f"Filtered dataset from {original_size} to {filtered_size} examples ({original_size - filtered_size} prompts were too long)")
 
         # dummy data collator
-        def data_collator(features):
+        def default_data_collator(features):
             return features
 
         super().__init__(
             model=model,
             args=args,
-            data_collator=data_collator,
+            data_collator=data_collator if data_collator is not None else default_data_collator,
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             processing_class=processing_class,
@@ -471,15 +472,25 @@ def _get_last_hidden_state(self, unwrapped_model, input_ids, attention_mask, log
     def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None, **model_kwargs) -> torch.Tensor:
         batch_size = batch_size or input_ids.size(0)  # Chunk inputs into smaller batches to reduce memory peak
         all_logps = []
-        if _accepts_logits_to_keep(model):
-            model_kwargs["logits_to_keep"] = logits_to_keep + 1
+        accepts_logits_to_keep = _accepts_logits_to_keep(model)
         for i in range(0, input_ids.size(0), batch_size):
             input_ids_batch = input_ids[i : i + batch_size]
             attention_mask_batch = attention_mask[i : i + batch_size]
-
+            model_kwargs_batch = {}
+            for key, value in model_kwargs.items():
+                if isinstance(value, list):
+                    # 1. Slice the list to get the tensors for this micro-batch
+                    sub_list = value[i : i + batch_size]
+                    # 2. Batch the tensors in the sub-list together
+                    model_kwargs_batch[key] = torch.cat(sub_list, dim=0).to(self.accelerator.device)
+                else:
+                    # Handle non-list arguments (like the 'logits_to_keep' we added)
+                    model_kwargs_batch[key] = value
+            if accepts_logits_to_keep:
+                model_kwargs_batch["logits_to_keep"] = logits_to_keep + 1  # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
             logits = model(
-                input_ids=input_ids_batch, attention_mask=attention_mask_batch, **model_kwargs
+                input_ids=input_ids_batch, attention_mask=attention_mask_batch, **model_kwargs_batch
             ).logits
             logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
             input_ids_batch = input_ids_batch[:, -logits_to_keep:]
@@ -590,7 +601,7 @@ def _ids_to_tensors(self,
             'mask': mask
         }
 
-    def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[Any], List[Any]]:
+    def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[Any] | None, List[Any], List[Any]]:
         """
         Gather batch data from all processes.
@@ -607,14 +618,17 @@
         # Gather batch data from all processes
         prompts = [x['prompt'] for x in batch]
+        images = [x['images'] for x in batch if 'images' in x]
         answers = [x['answer'] for x in batch]
         tasks = [x.get('task', 'default') for x in batch]
 
         all_prompts = gather_object(prompts)
+        all_images = gather_object(images)
+        all_images = all_images if all_images != [] else None
         all_answers = gather_object(answers)
         all_tasks = gather_object(tasks)
 
-        return all_prompts, all_answers, all_tasks
+        return all_prompts, all_images, all_answers, all_tasks
 
     def _prepare_inputs(  # type: ignore
         self, inputs: list[dict[str, Any]]
@@ -661,14 +675,14 @@ def _prepare_inputs(  # type: ignore
             for batch_id in range(self._next_batch_id, target_batch_id + 1):
                 batch_offset = batch_id - batch_id_to_retrieve
-                all_prompts, all_answers, all_tasks = self._gather_batch_data(batch_offset)
+                all_prompts, all_images, all_answers, all_tasks = self._gather_batch_data(batch_offset)
                 local_batch_size = len(all_prompts) // self.accelerator.num_processes
 
                 # Submit batch (main process only)
                 if self.accelerator.is_main_process:
                     request = BatchRequest(
                         batch_id=batch_id,
-                        env_inputs={'prompt': all_prompts, 'answer': all_answers, 'task': all_tasks},
+                        env_inputs={'prompt': all_prompts, 'images': all_images, 'answer': all_answers, 'task': all_tasks},
                         processing_class=self.processing_class,
                         mask_env_responses=self.mask_env_responses,
                         max_completion_length=self.max_completion_length,
@@ -712,6 +726,7 @@ def _prepare_inputs(  # type: ignore
                     'completion_ids': processed_results['completion_ids'],
                     'completion_mask': processed_results['completion_mask'],
                     'rewards': processed_results['rewards'],
+                    'remaining_inputs': processed_results['remaining_inputs'],
                     'all_reward_dict': batch_result.all_reward_dict if hasattr(batch_result, 'all_reward_dict') else {'reward': processed_results['rewards']},
                     'completions': batch_result.completions if hasattr(batch_result, 'completions') else [],
                     'prompts': batch_result.prompts if hasattr(batch_result, 'prompts') else [],
@@ -770,7 +785,10 @@ def _prepare_inputs(  # type: ignore
 
         # Take this process's slice of advantages
         advantages = all_advantages[process_slice]
-
+
+        # slice remaining inputs
+        remaining_inputs = broadcast_data['remaining_inputs'][process_slice]
+
         # Log metrics on main process only
         if self.accelerator.is_main_process:
             self._log_reward_metrics_primary(
@@ -793,7 +811,7 @@ def _prepare_inputs(  # type: ignore
             all_completion_ids=broadcast_data['completion_ids'],
             all_prompt_mask=broadcast_data['prompt_mask']
         )
-
+
         # Concatenate all data for shuffling
         full_batch = {
             "prompt_ids": prompt_ids,
@@ -802,11 +820,12 @@ def _prepare_inputs(  # type: ignore
             "completion_mask": completion_mask,
             "old_per_token_logps": None,
             "advantages": advantages,
+            "remaining_inputs": remaining_inputs,
         }
 
         # Shuffle and split for gradient accumulation
-        full_batch = shuffle_tensor_dict(full_batch)
-        self._buffered_inputs = split_tensor_dict(full_batch, self.gradient_accumulation_steps)
+        full_batch = shuffle_data_dict(full_batch)
+        self._buffered_inputs = split_data_dict(full_batch, self.gradient_accumulation_steps)
         self.accelerator.wait_for_everyone()
         # Return appropriate slice from buffer
         result = self._buffered_inputs[self._step % self.gradient_accumulation_steps]
@@ -843,18 +862,12 @@ def compute_loss(self,
         # Compute the per-token log probabilities for the model
         prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
         completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+        model_kwargs = inputs["remaining_inputs"]
         input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
         logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
-        model_kwarg_keys = (
-            inspect.signature(model.forward).parameters.keys()
-            if not hasattr(model, "get_base_model")
-            else inspect.signature(
-                model.get_base_model().forward
-            ).parameters.keys()
-        )
-        model_kwargs = {k: inputs[k] for k in model_kwarg_keys if k in inputs}
-        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, **model_kwargs)
+        model_kwargs = {key: [d[key] for d in model_kwargs] for key in model_kwargs[0].keys()}
+        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, batch_size = 1 if model_kwargs != {} else None, **model_kwargs)
 
         # Compute the loss
         advantages = inputs["advantages"]
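
Usage sketch (illustrative only, not part of the patch): how the new shuffle_data_dict and split_data_dict helpers in grpo_trainer.py are expected to behave on a batch dict that mixes tensors, plain Python lists, and None fields, assuming both helpers can be imported from verifiers.trainers.grpo_trainer.

import torch

from verifiers.trainers.grpo_trainer import shuffle_data_dict, split_data_dict

full_batch = {
    "prompt_ids": torch.arange(12).reshape(6, 2),            # tensor field: indexed with the permutation
    "remaining_inputs": [{"sample": i} for i in range(6)],   # list field: reordered element-wise
    "old_per_token_logps": None,                             # None fields pass through unchanged
}

shuffled = shuffle_data_dict(full_batch)   # one random permutation applied to every field
chunks = split_data_dict(shuffled, 3)      # 3 chunks of 2 samples each, e.g. for gradient accumulation

assert len(chunks) == 3
assert chunks[0]["prompt_ids"].shape == (2, 2)
assert len(chunks[0]["remaining_inputs"]) == 2
assert chunks[0]["old_per_token_logps"] is None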