diff --git a/pyproject.toml b/pyproject.toml index cd3ba6399..6e0c5e5a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "requests>=2.32.3", "openai>=1.81.0", "datasets>=3.6.0", - "transformers", + "transformers<4.52.0", "nest-asyncio>=1.6.0", ] diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 139935d4d..97ba296d4 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -3,7 +3,9 @@ from asyncio import Semaphore from abc import ABC, abstractmethod from copy import deepcopy -from typing import Any, Dict, List, Literal, Tuple, Optional, Union +from typing import Any, Dict, List, Literal, Tuple, Optional, Union, Callable +import base64 +import io import concurrent.futures from transformers.tokenization_utils_base import PreTrainedTokenizerBase @@ -11,10 +13,54 @@ from datasets import Dataset from openai import OpenAI +from PIL import Image + from verifiers import RewardFunc from verifiers.parsers import Parser from verifiers.rubrics import Rubric +def _pil_to_data_url(img: Image.Image, fmt: str | None = None) -> str: + buf = io.BytesIO() + fmt = (fmt or img.format or "PNG").upper() + img.save(buf, format=fmt) + b64 = base64.b64encode(buf.getvalue()).decode("utf-8") + return f"data:image/{fmt.lower()};base64,{b64}" + +def format_oai_chat_msg( + prompts: List[List[Dict[str, Any]]], + images: List[List[Image.Image]] +) -> List[Any]: + formatted_conversations = [] + + for conv_prompts, conv_images in zip(prompts, images): + img_iter = iter(conv_images) + new_conv = [] + + for msg in conv_prompts: + role = msg["role"] + content = msg["content"] + + if isinstance(content, list): + new_parts = [] + for part in content: + if part.get("type") == "image": + img = next(img_iter) + data_url = _pil_to_data_url(img) + new_parts.append({ + "type": "image_url", + "image_url": {"url": data_url} + }) + else: + new_parts.append(part.copy()) + new_conv.append({"role": role, "content": new_parts}) + + else: + new_conv.append({"role": role, "content": content}) + + formatted_conversations.append(new_conv) + + return formatted_conversations + class Environment(ABC): """ Base class for all environments. 
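The `_pil_to_data_url` / `format_oai_chat_msg` helpers above convert in-memory PIL images into OpenAI-style `image_url` content parts backed by base64 data URLs. Below is a minimal, self-contained sketch of that conversion (the helper name mirrors the diff; the round-trip check at the end is purely illustrative and not part of the patch):

```python
# Minimal sketch of the PIL -> OpenAI "image_url" conversion added above.
# Requires Pillow; the round-trip assertion is only for illustration.
import base64
import io

from PIL import Image


def pil_to_data_url(img: Image.Image, fmt: str | None = None) -> str:
    buf = io.BytesIO()
    fmt = (fmt or img.format or "PNG").upper()
    img.save(buf, format=fmt)
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f"data:image/{fmt.lower()};base64,{b64}"


if __name__ == "__main__":
    img = Image.new("RGB", (32, 32), color=(255, 0, 0))
    data_url = pil_to_data_url(img)
    # OpenAI chat-completions content part in the shape the diff emits
    part = {"type": "image_url", "image_url": {"url": data_url}}
    assert part["image_url"]["url"].startswith("data:image/png;base64,")
    # Decode back to confirm the payload is a valid image
    payload = data_url.split(",", 1)[1]
    restored = Image.open(io.BytesIO(base64.b64decode(payload)))
    assert restored.size == (32, 32)
```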
@@ -31,6 +77,7 @@ def __init__(self, sampling_args: Dict[str, Any] = {}, max_concurrent: int = 128, message_type: Literal['chat', 'completion'] = 'chat', + data_collator: Callable | None = None, **kwargs: Any): self.client = client self.model = model @@ -38,7 +85,8 @@ def __init__(self, self.system_prompt = system_prompt self.few_shot = few_shot self.max_concurrent = max_concurrent - + self.data_collator = data_collator + # Ensure asyncio.to_thread doesn't hit default 32 thread limit try: loop = asyncio.get_running_loop() @@ -66,6 +114,13 @@ def __init__(self, ) self.dataset = dataset self.eval_dataset = eval_dataset + if self.data_collator is not None and self.eval_dataset is not None: + processed_dataset = self.data_collator(list(self.eval_dataset)) + if not processed_dataset: + self.eval_dataset = {} + else: + keys = processed_dataset[0].keys() + self.eval_dataset = {key: [sample.get(key) for sample in processed_dataset] for key in keys} self.parser = parser self.rubric = rubric self.sampling_args = { @@ -134,7 +189,7 @@ def get_dataset(self, n: int = -1, seed: int = 0, **kwargs: Any) -> Dataset | No return self.dataset.shuffle(seed=seed).select(range(n)) # type: ignore return self.dataset - def get_eval_dataset(self, n: int = -1, seed: int = 0, **kwargs: Any) -> Dataset | None: + def get_eval_dataset(self, n: int = -1, seed: int = 0, **kwargs: Any) -> Dataset | dict[Any, list[Any]] | None: if n > 0 and self.eval_dataset is not None: return self.eval_dataset.shuffle(seed=seed).select(range(n)) # type: ignore return self.eval_dataset @@ -346,8 +401,13 @@ def generate(self, results['task'] = ['default'] * len(results['prompt']) if 'info' not in results: results['info'] = [{}] * len(results['prompt']) + + if results.get('images') is not None: + prompts = format_oai_chat_msg(results['prompt'], results['images']) + else: + prompts = results['prompt'] rollouts = self.run_rollouts( - prompts=results['prompt'], + prompts=prompts, answers=results['answer'], tasks=results['task'], infos=results['info'], @@ -378,10 +438,11 @@ def generate(self, def process_chat_format( self, prompt: List[Dict[str, str]], + images: Optional[List[List[Any]]], completion: List[Dict[str, str]], - processing_class: PreTrainedTokenizerBase, + processing_class: Any, mask_env_responses: bool = False - ) -> Tuple[List[int], List[int], List[int], List[int]]: + ) -> Tuple[List[int], List[int], List[int], List[int], dict[str, Any]]: """ Process chat format conversations using incremental prefixes. 
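`process_chat_format` builds per-token completion masks by re-tokenizing growing conversation prefixes and taking the token-id suffix as the new message's tokens. Here is a text-only sketch of that incremental-prefix idea with a plain Hugging Face tokenizer; the checkpoint name and messages are illustrative, and real chat templates can shift the last prompt token (which is why the patch only compares `prev_ids[:-1]`):

```python
# Illustrative sketch of incremental-prefix tokenization: tokenize the prompt,
# then prompt + completion[:i+1], and diff the token ids to attribute new
# tokens (and masks) to the i-th completion message.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # example checkpoint

prompt = [{"role": "user", "content": "What is 2 + 2?"}]
completion = [
    {"role": "assistant", "content": "4"},
    {"role": "user", "content": "And 3 + 3?"},   # environment response, maskable
    {"role": "assistant", "content": "6"},
]

prev_ids = tok.encode(
    tok.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
)
completion_ids, completion_mask = [], []
for i, msg in enumerate(completion):
    text = tok.apply_chat_template(prompt + completion[: i + 1], tokenize=False)
    current_ids = tok.encode(text)
    new_tokens = current_ids[len(prev_ids):]
    completion_ids.extend(new_tokens)
    # with mask_env_responses=True, non-assistant turns contribute mask 0
    completion_mask.extend([1 if msg["role"] == "assistant" else 0] * len(new_tokens))
    prev_ids = current_ids

assert len(completion_ids) == len(completion_mask)
```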
@@ -393,55 +454,96 @@ def process_chat_format( Returns: prompt_ids, prompt_mask, completion_ids, completion_mask """ - # tokenize just the prompt - prompt_text = processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) - assert isinstance(prompt_text, str) - prompt_ids = processing_class.encode(prompt_text) - prompt_mask = [1] * len(prompt_ids) - - # track completion tokens and masks by processing incrementally completion_ids = [] completion_mask = [] - - # previous tokenization (starts with just prompt) - prev_ids = prompt_ids - - # process each completion message incrementally - for i, msg in enumerate(completion): - # create conversation prefix: prompt + completion[:i+1] - conversation_prefix = prompt + completion[:i+1] + remaining_inputs = {} + if images: + assert not isinstance(processing_class, PreTrainedTokenizerBase) + prompt_text = processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) + assert isinstance(prompt_text, str) + inputs = processing_class(text=prompt_text, images=images, return_tensors="pt") + remaining_inputs = { + k: v + for k, v in inputs.items() + if k not in ["input_ids", "attention_mask"] + } + prev_ids = inputs.input_ids[0].tolist() + prompt_ids = prev_ids + prompt_mask = [1] * len(prompt_ids) + + for i, msg in enumerate(completion): + conversation_prefix = prompt + completion[:i+1] + prefix_text = processing_class.apply_chat_template( + conversation_prefix, + tokenize=False, + add_generation_prompt=False, + ) + assert isinstance(prefix_text, str), f"Expected string from apply_chat_template, got {type(prefix_text)}" + current_ids = processing_class(text=prefix_text, images=images, return_tensors="pt").input_ids[0].tolist() + assert current_ids[:len(prev_ids)-1] == prev_ids[:-1], f"Tokenization difference in chat format. Current ids: {current_ids[:len(prev_ids)-1]}, previous ids: {prev_ids[:-1]}" + new_tokens = current_ids[len(prev_ids):] + completion_ids.extend(new_tokens) + + if msg["role"] == "assistant": + msg_mask = [1] * len(new_tokens) + elif msg["role"] != "assistant" and mask_env_responses: + msg_mask = [0] * len(new_tokens) + else: + msg_mask = [1] * len(new_tokens) + + completion_mask.extend(msg_mask) + prev_ids = current_ids + else: + assert isinstance(processing_class, PreTrainedTokenizerBase) + # tokenize just the prompt + prompt_text = processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) + assert isinstance(prompt_text, str) + prompt_ids = processing_class.encode(prompt_text) + prompt_mask = [1] * len(prompt_ids) - # tokenize the full prefix - prefix_text = processing_class.apply_chat_template( - conversation_prefix, - tokenize=False, - add_generation_prompt=False, - ) - assert isinstance(prefix_text, str), f"Expected string from apply_chat_template, got {type(prefix_text)}" - current_ids = processing_class.encode(prefix_text) - assert current_ids[:len(prev_ids)-1] == prev_ids[:-1], f"Tokenization difference in chat format. Current ids: {current_ids[:len(prev_ids)-1]}, previous ids: {prev_ids[:-1]}" + # track completion tokens and masks by processing incrementally + completion_ids = [] + completion_mask = [] - # add new tokens to completion tokens - new_tokens = current_ids[len(prev_ids):] - assert len(new_tokens) > 0, f"No new tokens in chat format. 
Current ids: {current_ids}, previous ids: {prev_ids}" - completion_ids.extend(new_tokens) - - # create mask - if msg["role"] == "assistant": - msg_mask = [1] * len(new_tokens) - elif msg["role"] != "assistant" and mask_env_responses: - # mask intermediate 'user' and/or 'tool' messages - msg_mask = [0] * len(new_tokens) - else: - # default to not masking - msg_mask = [1] * len(new_tokens) + # previous tokenization (starts with just prompt) + prev_ids = prompt_ids - completion_mask.extend(msg_mask) - # update previous tokenization for next iteration - prev_ids = current_ids - assert len(completion_ids) == len(completion_mask), f"Length mismatch in chat format. Completion ids: {completion_ids}, completion mask: {completion_mask}" + # process each completion message incrementally + for i, msg in enumerate(completion): + # create conversation prefix: prompt + completion[:i+1] + conversation_prefix = prompt + completion[:i+1] + + # tokenize the full prefix + prefix_text = processing_class.apply_chat_template( + conversation_prefix, + tokenize=False, + add_generation_prompt=False, + ) + assert isinstance(prefix_text, str), f"Expected string from apply_chat_template, got {type(prefix_text)}" + current_ids = processing_class.encode(prefix_text) + assert current_ids[:len(prev_ids)-1] == prev_ids[:-1], f"Tokenization difference in chat format. Current ids: {current_ids[:len(prev_ids)-1]}, previous ids: {prev_ids[:-1]}" + + # add new tokens to completion tokens + new_tokens = current_ids[len(prev_ids):] + assert len(new_tokens) > 0, f"No new tokens in chat format. Current ids: {current_ids}, previous ids: {prev_ids}" + completion_ids.extend(new_tokens) - return prompt_ids, prompt_mask, completion_ids, completion_mask + # create mask + if msg["role"] == "assistant": + msg_mask = [1] * len(new_tokens) + elif msg["role"] != "assistant" and mask_env_responses: + # mask intermediate 'user' and/or 'tool' messages + msg_mask = [0] * len(new_tokens) + else: + # default to not masking + msg_mask = [1] * len(new_tokens) + + completion_mask.extend(msg_mask) + # update previous tokenization for next iteration + prev_ids = current_ids + assert len(completion_ids) == len(completion_mask), f"Length mismatch in chat format. 
Completion ids: {completion_ids}, completion mask: {completion_mask}" + + return prompt_ids, prompt_mask, completion_ids, completion_mask, remaining_inputs def process_completion_format( self, @@ -467,16 +569,17 @@ def process_completion_format( # Tokenize completion completion_ids = processing_class.encode(completion) completion_mask = [1] * len(completion_ids) - + return prompt_ids, prompt_mask, completion_ids, completion_mask def process_env_results( self, prompts: List[Union[str, List[Dict[str, Any]]]], + images: Optional[List[List[Any]]], completions: List[Union[str, List[Dict[str, Any]]]], states: List[Dict[str, Any]], rewards: List[float], - processing_class: PreTrainedTokenizerBase, + processing_class: Any, max_completion_length: int = -1, mask_truncated_completions: bool = False, mask_env_responses: bool = False, @@ -496,19 +599,25 @@ def process_env_results( all_prompt_masks = [] all_completion_ids = [] all_completion_masks = [] + all_remaining_inputs = [] + + input_images = images or [None] * len(prompts) - for i, (prompt, completion, state, reward) in enumerate(zip(prompts, completions, states, rewards)): + for i, (prompt, images, completion, state, reward) in enumerate(zip(prompts, input_images, completions, states, rewards)): # Format-specific processing if is_chat_format: assert isinstance(prompt, list) and isinstance(completion, list) - prompt_ids, prompt_mask, completion_ids, completion_mask = self.process_chat_format( - prompt, completion, processing_class, mask_env_responses + prompt_ids, prompt_mask, completion_ids, completion_mask, remaining_inputs = self.process_chat_format( + prompt, images, completion, processing_class, mask_env_responses ) else: + if images is not None: + raise NotImplementedError("Multi-modal training is not supported with completion formats yet") assert isinstance(prompt, str) and isinstance(completion, str) prompt_ids, prompt_mask, completion_ids, completion_mask = self.process_completion_format( prompt, completion, processing_class ) + remaining_inputs = [None] * len(prompt_ids) if mask_truncated_completions and max_completion_length > 0 and len(completion_ids) > max_completion_length: completion_ids = completion_ids[:max_completion_length] completion_mask = [0] * len(completion_ids) @@ -516,6 +625,7 @@ def process_env_results( all_prompt_masks.append(prompt_mask) all_completion_ids.append(completion_ids) all_completion_masks.append(completion_mask) + all_remaining_inputs.append(remaining_inputs) return { "prompt_ids": all_prompt_ids, @@ -523,6 +633,7 @@ def process_env_results( "completion_ids": all_completion_ids, "completion_mask": all_completion_masks, "rewards": rewards, + "remaining_inputs": all_remaining_inputs, } # Evaluation and dataset generation @@ -552,7 +663,10 @@ def evaluate(self, else: inputs = self.eval_dataset if num_samples > 0: - inputs = inputs.select(range(num_samples)) + if isinstance(inputs, dict): + inputs = {key: value_list[:num_samples] for key, value_list in inputs.items()} + elif isinstance(inputs, Dataset): + inputs = inputs.select(range(num_samples)) results = self.generate( inputs, client, model, sampling_args, max_concurrent, **kwargs diff --git a/verifiers/examples/docvqa.py b/verifiers/examples/docvqa.py new file mode 100644 index 000000000..6b4c8da29 --- /dev/null +++ b/verifiers/examples/docvqa.py @@ -0,0 +1,130 @@ +import re + +from datasets import load_dataset +from qwen_vl_utils import process_vision_info + +import verifiers as vf + +""" +# install qwen stuff +uv pip install qwen-vl-utils +# inference 
+CUDA_VISIBLE_DEVICES=0,1,2,3 vf-vllm --model 'Qwen/Qwen2.5-VL-7B-Instruct' --max-model-len 32000 --tensor_parallel_size 4 +# train +CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --config-file configs/zero3.yaml --num-processes 4 verifiers/examples/docvqa.py +""" + + +def data_collator(batch: list[dict]) -> list[dict]: + processed_samples = [] + for sample in batch: + messages = [] + messages.append({"role": "system", "content": system_prompt}) + content_block = [] + content_block.append({"type": "text", "text": sample["question"]}) + content_block.append( + { + "type": "image", + "image": sample["image"], # only one image in this ds + "resized_height": 768, # XGA resolution + "resized_width": 1024, + } + ) + messages.append({"role": "user", "content": content_block}) + processed_images, *_ = process_vision_info( # process with qwen utils + messages.copy() + ) + sample["prompt"] = messages + sample["images"] = processed_images + sample["answer"] = sample["answers"] + processed_samples.append(sample) + return processed_samples + + +dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation[10%:]") +eval_dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation[:10%]") + +parser = vf.XMLParser(["think", "answer"], answer_field="answer") +system_prompt = f"""Answer the questions. + +Respond in the following format: +{parser.get_format_str()}""" + + +def correctness_reward_func(completion: list[dict[str, str]], **kwargs) -> float: + def get_assistant_messages(messages: list[dict[str, str]]) -> list[dict[str, str]]: + return [msg for msg in messages if msg.get("role") == "assistant"] + + def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: + pattern = rf"<{tag}>\s*(.*?)\s*" + match = re.search(pattern, text, re.DOTALL) + if match: + content = match.group(1) + return content.strip() if strip else content + return None + + assistant_messages = get_assistant_messages(completion) + if assistant_messages is None: + return 0.0 + msgs_scores = [] + for msg in assistant_messages: + content = msg.get("content", "") + answer = parse_xml_content(content, "answer") + if answer is None: + continue + gt_answers = kwargs["answer"] + mean_gt_len = sum([len(gt_answer) for gt_answer in gt_answers]) / len( + gt_answers + ) + if len(answer) > 0: + diff_from_mean = min(mean_gt_len / len(answer), 1.0) # penalize long answers + else: + diff_from_mean = 0.0 + if answer in gt_answers: + msgs_scores.append(2.0) + elif answer.lower() in [ans.lower() for ans in gt_answers]: + msgs_scores.append(1.0) + elif any(ans.lower() in answer.lower() for ans in gt_answers): + msgs_scores.append(diff_from_mean) + if msgs_scores == []: + return 0.0 + else: + return sum(msgs_scores) / len(msgs_scores) / 2.0 + + +rubric = vf.Rubric( + funcs=[ + parser.get_format_reward_func(), + correctness_reward_func, + ] +) + +vf_env = vf.SingleTurnEnv( + dataset=dataset, + eval_dataset=eval_dataset, + system_prompt=system_prompt, + parser=parser, + rubric=rubric, + data_collator=data_collator, +) + +model_name = "Qwen/Qwen2.5-VL-7B-Instruct" +model, processor = vf.get_model_and_tokenizer(model_name) +run_name = "docvqa_" + model_name.split("/")[-1].lower() + +training_args = vf.grpo_defaults(run_name=run_name) +training_args.learning_rate = 3e-6 +training_args.max_steps = -1 +training_args.eval_strategy = "steps" +training_args.eval_steps = 100 +training_args.gradient_checkpointing_kwargs = { + "use_reentrant": False, +} + +trainer = vf.GRPOTrainer( + model=model, + processing_class=processor, + 
env=vf_env, + args=training_args, +) +trainer.train() diff --git a/verifiers/inference/vllm_server.py b/verifiers/inference/vllm_server.py index dbe991d50..2d5f29abb 100644 --- a/verifiers/inference/vllm_server.py +++ b/verifiers/inference/vllm_server.py @@ -85,7 +85,7 @@ async def get_next_worker_connection(connections: list[AnyType]) -> tuple[int, A # -------- OpenAI /v1/chat/completions Pydantic Models ---------- # class OAChatMessage(BaseModel): role: str - content: str + content: str | list class OAChatCompletionRequest(BaseModel): model: str diff --git a/verifiers/trainers/async_batch_generator.py b/verifiers/trainers/async_batch_generator.py index 28946cd5e..6e73af494 100644 --- a/verifiers/trainers/async_batch_generator.py +++ b/verifiers/trainers/async_batch_generator.py @@ -11,7 +11,7 @@ class BatchRequest: """Request for batch generation""" batch_id: int - env_inputs: Dict[str, List[Any]] + env_inputs: Dict[str, List[Any] | None] processing_class: Any mask_env_responses: bool max_completion_length: int @@ -239,6 +239,7 @@ def _generate_batch(self, request: BatchRequest) -> BatchResult: # Process results processed_results = self.env.process_env_results( env_results['prompt'], + env_results['images'], env_results['completion'], env_results['state'], env_results['reward'], diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 4fb929e6e..2a721133e 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -1,5 +1,6 @@ # adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py +import inspect import logging from collections import defaultdict, deque from contextlib import nullcontext @@ -11,11 +12,10 @@ from torch.utils.data import DataLoader, Sampler from accelerate.utils import broadcast_object_list, gather_object, is_peft_model from peft import PeftConfig, get_peft_model -from torch.utils.data import DataLoader -from transformers import AutoModelForCausalLM from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.modeling_utils import PreTrainedModel from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from transformers.processing_utils import ProcessorMixin from transformers.trainer import Trainer from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import seed_worker @@ -34,6 +34,7 @@ from verifiers.trainers.async_batch_generator import AsyncBatchGenerator, BatchRequest from verifiers.trainers.async_dataloader_wrapper import AsyncDataLoaderWrapper from verifiers.utils.logging_utils import print_prompt_completions_sample +from verifiers.utils.model_utils import generic_model_loader class RepeatSampler(Sampler): """ @@ -132,6 +133,19 @@ def __len__(self) -> int: return self.num_samples * self.mini_repeat_count * self.repeat_count +def _accepts_logits_to_keep(model) -> bool: + forward = ( + model.get_base_model().forward + if hasattr(model, "get_base_model") + else model.forward + ) + try: + inspect.signature(forward).bind_partial(**{"logits_to_keep": None}) + return True + except TypeError: + return False + + # torch.nanstd doesn't exist, so we define it here def nanstd(tensor: torch.Tensor) -> torch.Tensor: """ @@ -150,53 +164,53 @@ def nanstd(tensor: torch.Tensor) -> torch.Tensor: variance *= count / (count - 1) # Bessel's correction return torch.sqrt(variance) -def split_tensor_dict( - tensor_dict: dict[str, Optional[torch.Tensor]], num_chunks: int -) -> list[dict[str, 
Optional[torch.Tensor]]]: +def shuffle_data_dict(data_dict: dict[str, Any]) -> dict[str, Any]: """ - Splits a dictionary of tensors along the first dimension into `num_chunks` equal parts. - - Example: - >>> x = torch.arange(12).reshape(6, 2) - >>> y = torch.arange(6).reshape(6, 1) - >>> tensor_dict = {"x": x, "y": y} - >>> split_tensor_dict(tensor_dict, 3) - [ - {"x": tensor([[0, 1], [2, 3]]), "y": tensor([[0], [1]])}, - {"x": tensor([[4, 5], [6, 7]]), "y": tensor([[2], [3]])}, - {"x": tensor([[ 8, 9], [10, 11]]), "y": tensor([[4], [5]])} - ] + Shuffles a dictionary of tensors or lists along the first dimension in unison. """ - first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None) - chunk_size = first_tensor.shape[0] // num_chunks - return [ - { - key: tensor[i * chunk_size : (i + 1) * chunk_size] if tensor is not None else None - for key, tensor in tensor_dict.items() - } - for i in range(num_chunks) - ] + first_item = next(item for item in data_dict.values() if item is not None) + batch_size = len(first_item) + permutation = torch.randperm(batch_size) -def shuffle_tensor_dict(tensor_dict: dict[str, Optional[torch.Tensor]]) -> dict[str, Optional[torch.Tensor]]: - """ - Shuffles a dictionary of tensors along the first dimension in unison. + shuffled_dict = {} + for key, value in data_dict.items(): + if value is None: + shuffled_dict[key] = None + elif isinstance(value, torch.Tensor): + shuffled_dict[key] = value[permutation] + elif isinstance(value, list): + shuffled_dict[key] = [value[i] for i in permutation] + else: + raise TypeError(f"Unsupported type for shuffling: {type(value)}") + return shuffled_dict - Example: - >>> x = torch.arange(6).reshape(3, 2) - >>> y = torch.arange(3).reshape(3, 1) - >>> tensor_dict = {"x": x, "y": y} - >>> shuffle_tensor_dict(tensor_dict) - {'x': tensor([[2, 3], - [0, 1], - [4, 5]]), - 'y': tensor([[1], - [0], - [2]])} +def split_data_dict( + data_dict: dict[str, Any], num_chunks: int +) -> list[dict[str, Any]]: """ - first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None) - batch_size = first_tensor.shape[0] - permutation = torch.randperm(batch_size) - return {key: tensor[permutation] if tensor is not None else None for key, tensor in tensor_dict.items()} + Splits a dictionary of tensors or lists along the first dimension into `num_chunks` equal parts. + """ + first_item = next(item for item in data_dict.values() if item is not None) + # Ensure chunk_size is an integer + chunk_size = len(first_item) // num_chunks + if len(first_item) % num_chunks != 0: + logging.warning( + f"The total number of samples ({len(first_item)}) is not divisible by the number of chunks ({num_chunks}). " + f"The last {len(first_item) % num_chunks} samples will be dropped." 
+ ) + + chunked_list = [] + for i in range(num_chunks): + chunk = {} + start_idx = i * chunk_size + end_idx = (i + 1) * chunk_size + for key, value in data_dict.items(): + if value is None: + chunk[key] = None + else: + chunk[key] = value[start_idx:end_idx] + chunked_list.append(chunk) + return chunked_list def nanmin(tensor: torch.Tensor) -> torch.Tensor: """ @@ -233,7 +247,7 @@ def __init__( model: PreTrainedModel, env: Environment, args: GRPOConfig, - processing_class: PreTrainedTokenizerBase, + processing_class: ProcessorMixin, callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional[PeftConfig] = None, @@ -252,9 +266,14 @@ def __init__( # Suppress irrelevant warning model.warnings_issued["estimate_tokens"] = True + self.tokenizer_base = getattr(processing_class, "tokenizer", None) # extract tokenizer in multimodal case # Tokenizer pad token - if processing_class.pad_token is None: # type: ignore - processing_class.pad_token = processing_class.eos_token # type: ignore + if self.tokenizer_base is not None: + if processing_class.tokenizer.pad_token is None: # type: ignore + processing_class.tokenizer.pad_token = processing_class.tokenizer.eos_token # type: ignore + else: + if processing_class.pad_token is None: # type: ignore + processing_class.pad_token = processing_class.eos_token # type: ignore # Training arguments self.per_device_train_batch_size = args.per_device_train_batch_size @@ -293,8 +312,6 @@ def __init__( train_dataset = env.get_dataset() assert train_dataset is not None - eval_dataset = env.get_eval_dataset() - # Filter out prompts that are too long if max_prompt_length is set if self.max_prompt_length is not None: self.logger.info(f"Filtering dataset for prompts with length <= {self.max_prompt_length}") @@ -309,7 +326,12 @@ def filter_by_prompt_length(example): else: # Completion format prompt_text = prompt - prompt_ids = processing_class.encode(prompt_text) # type: ignore + if self.tokenizer_base is not None: + encode = self.tokenizer_base.encode + else: + assert isinstance(processing_class, PreTrainedTokenizerBase) + encode = processing_class.encode + prompt_ids = encode(prompt_text) # type: ignore return len(prompt_ids) <= max_length original_size = len(train_dataset) @@ -319,14 +341,14 @@ def filter_by_prompt_length(example): self.logger.info(f"Filtered dataset from {original_size} to {filtered_size} examples ({original_size - filtered_size} prompts were too long)") # dummy data collator - def data_collator(features): + def default_data_collator(features): return features super().__init__( model=model, args=args, - data_collator=data_collator, + data_collator=env.data_collator if env.data_collator is not None else default_data_collator, train_dataset=train_dataset, - eval_dataset=eval_dataset, + eval_dataset=datasets.Dataset.from_dict({}), # dummy eval ds. This is actually handled by environment processing_class=processing_class, callbacks=callbacks, optimizers=optimizers, @@ -339,7 +361,7 @@ def data_collator(features): elif is_deepspeed_zero3_enabled(): model_id = model.config._name_or_path model_init_kwargs = {"torch_dtype": "auto"} - self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) + self.ref_model = generic_model_loader(model_id, **model_init_kwargs) elif is_peft_model(model): # If PEFT is used, the reference model is not needed since the adapter can be disabled # to revert to the initial model. 
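The replacement of `shuffle_tensor_dict` / `split_tensor_dict` with `shuffle_data_dict` / `split_data_dict` exists so that list-valued fields (notably `remaining_inputs`, which carries per-sample processor outputs such as `pixel_values`) can ride along with the padded tensors. A toy demonstration of the intended "shuffle in unison" semantics, independent of the trainer:

```python
# Toy demonstration of the shuffle semantics introduced above: tensor fields
# and plain-list fields in the same dict are permuted with one permutation,
# and None fields pass through untouched.
import torch

batch = {
    "prompt_ids": torch.arange(8).reshape(4, 2),          # tensor field
    "remaining_inputs": [{"idx": i} for i in range(4)],   # list field
    "old_per_token_logps": None,                          # passthrough
}

perm = torch.randperm(4)
shuffled = {
    k: (None if v is None
        else v[perm] if isinstance(v, torch.Tensor)
        else [v[i] for i in perm])
    for k, v in batch.items()
}

# Rows stay aligned: tensor row k and the dict at position k moved together.
for k in range(4):
    assert shuffled["prompt_ids"][k, 0].item() // 2 == shuffled["remaining_inputs"][k]["idx"]
```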
@@ -550,16 +572,28 @@ def _get_last_hidden_state(self, unwrapped_model, input_ids, attention_mask, log return last_hidden_state # Get the per-token log probabilities for the completions for the model and the reference model - def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None) -> torch.Tensor: + def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None, **model_kwargs) -> torch.Tensor: batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak all_logps = [] + accepts_logits_to_keep = _accepts_logits_to_keep(model) for i in range(0, input_ids.size(0), batch_size): input_ids_batch = input_ids[i : i + batch_size] attention_mask_batch = attention_mask[i : i + batch_size] - + model_kwargs_batch = {} + for key, value in model_kwargs.items(): + if isinstance(value, list): + # 1. Slice the list to get the tensors for this micro-batch + sub_list = value[i : i + batch_size] + # 2. Batch the tensors in the sub-list together + model_kwargs_batch[key] = torch.cat(sub_list, dim=0).to(self.accelerator.device) + else: + # Handle non-list arguments (like the 'logits_to_keep' we added) + model_kwargs_batch[key] = value + if accepts_logits_to_keep: + model_kwargs_batch["logits_to_keep"] = logits_to_keep + 1 # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded logits = model( - input_ids=input_ids_batch, attention_mask=attention_mask_batch, logits_to_keep=logits_to_keep + 1 + input_ids=input_ids_batch, attention_mask=attention_mask_batch, **model_kwargs_batch ).logits logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred input_ids_batch = input_ids_batch[:, -logits_to_keep:] @@ -670,7 +704,7 @@ def _ids_to_tensors(self, 'mask': mask } - def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[Any], List[Any], List[Any]]: + def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[Any] | None, List[Any], List[Any], List[Any]]: """ Gather batch data from all processes. 
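`_get_per_token_logps` now forwards `logits_to_keep` only when the wrapped model's `forward` actually accepts it, using the `_accepts_logits_to_keep` helper defined earlier in this file. A compact sketch of that signature-introspection trick (the two `forward_*` functions are hypothetical stand-ins for model forwards):

```python
# Sketch of the signature check used to decide whether `logits_to_keep` can be
# forwarded: bind_partial raises TypeError when the keyword is not accepted
# and there is no **kwargs catch-all.
import inspect


def accepts_kwarg(fn, name: str) -> bool:
    try:
        inspect.signature(fn).bind_partial(**{name: None})
        return True
    except TypeError:
        return False


def forward_a(input_ids, attention_mask, logits_to_keep=None): ...   # hypothetical
def forward_b(pixel_values, input_ids, attention_mask): ...          # hypothetical


assert accepts_kwarg(forward_a, "logits_to_keep") is True
assert accepts_kwarg(forward_b, "logits_to_keep") is False
```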
@@ -687,14 +721,18 @@ def _gather_batch_data(self, batch_offset: int = 0) -> Tuple[List[Any], List[Any # Gather batch data from all processes prompts = [x['prompt'] for x in batch] + images = [x['images'] for x in batch if 'images' in x] answers = [x['answer'] for x in batch] tasks = [x.get('task', 'default') for x in batch] infos = [x.get('info', {}) for x in batch] all_prompts = gather_object(prompts) + all_images = gather_object(images) + all_images = all_images if all_images != [] else None all_answers = gather_object(answers) all_tasks = gather_object(tasks) all_infos = gather_object(infos) - return all_prompts, all_answers, all_tasks, all_infos + + return all_prompts, all_images, all_answers, all_tasks, all_infos def _prepare_inputs( # type: ignore self, inputs: list[dict[str, Any]] @@ -741,15 +779,14 @@ def _prepare_inputs( # type: ignore for batch_id in range(self._next_batch_id, target_batch_id + 1): batch_offset = batch_id - batch_id_to_retrieve - all_prompts, all_answers, all_tasks, all_infos = self._gather_batch_data(batch_offset) + all_prompts, all_images, all_answers, all_tasks, all_infos = self._gather_batch_data(batch_offset) local_batch_size = len(all_prompts) // self.accelerator.num_processes - # Submit batch (main process only) if self.accelerator.is_main_process: request = BatchRequest( batch_id=batch_id, - env_inputs={'prompt': all_prompts, 'answer': all_answers, 'task': all_tasks, 'info': all_infos}, + env_inputs={'prompt': all_prompts, 'images': all_images, 'answer': all_answers, 'task': all_tasks, 'info': all_infos}, processing_class=self.processing_class, mask_env_responses=self.mask_env_responses, max_completion_length=self.max_completion_length, @@ -793,6 +830,7 @@ def _prepare_inputs( # type: ignore 'completion_ids': processed_results['completion_ids'], 'completion_mask': processed_results['completion_mask'], 'rewards': processed_results['rewards'], + 'remaining_inputs': processed_results['remaining_inputs'], 'all_reward_dict': batch_result.all_reward_dict if hasattr(batch_result, 'all_reward_dict') else {'reward': processed_results['rewards']}, 'completions': batch_result.completions if hasattr(batch_result, 'completions') else [], 'prompts': batch_result.prompts if hasattr(batch_result, 'prompts') else [], @@ -831,9 +869,14 @@ def _prepare_inputs( # type: ignore completion_mask_list.append(torch.tensor(broadcast_data['completion_mask'][i], device=self.accelerator.device)) # Pad sequences - prompt_ids = pad(prompt_ids_list, padding_value=self.processing_class.pad_token_id, padding_side='left') # type: ignore + if self.tokenizer_base is not None: + pad_token_id = self.tokenizer_base.pad_token_id + else: + assert isinstance(self.processing_class, PreTrainedTokenizerBase) + pad_token_id = self.processing_class.pad_token_id + prompt_ids = pad(prompt_ids_list, padding_value=pad_token_id, padding_side='left') # type: ignore prompt_mask = pad(prompt_mask_list, padding_side='left') # type: ignore - completion_ids = pad(completion_ids_list, padding_value=self.processing_class.pad_token_id, padding_side='right') # type: ignore + completion_ids = pad(completion_ids_list, padding_value=pad_token_id, padding_side='right') # type: ignore completion_mask = pad(completion_mask_list) # Truncate if needed @@ -847,7 +890,10 @@ def _prepare_inputs( # type: ignore # Take this process's slice of advantages advantages = all_advantages[process_slice] - + + # slice remaining inputs + remaining_inputs = broadcast_data['remaining_inputs'][process_slice] + # Log metrics on main process 
only if self.accelerator.is_main_process: self._log_reward_metrics_primary( @@ -870,7 +916,7 @@ def _prepare_inputs( # type: ignore all_completion_ids=broadcast_data['completion_ids'], all_prompt_mask=broadcast_data['prompt_mask'] ) - + # Concatenate all data for shuffling full_batch = { "prompt_ids": prompt_ids, @@ -879,11 +925,12 @@ def _prepare_inputs( # type: ignore "completion_mask": completion_mask, "old_per_token_logps": None, "advantages": advantages, + "remaining_inputs": remaining_inputs, } # Shuffle and split for gradient accumulation - full_batch = shuffle_tensor_dict(full_batch) - self._buffered_inputs = split_tensor_dict(full_batch, self.gradient_accumulation_steps) + full_batch = shuffle_data_dict(full_batch) + self._buffered_inputs = split_data_dict(full_batch, self.gradient_accumulation_steps) self.accelerator.wait_for_everyone() # Return appropriate slice from buffer result = self._buffered_inputs[self._step % self.gradient_accumulation_steps] @@ -913,17 +960,19 @@ def _compute_advantages( def compute_loss(self, model: PreTrainedModel, - inputs: Dict[str, torch.Tensor], + inputs: Dict[str, Any], return_outputs: bool = False, num_items_in_batch: int | None = None) -> torch.Tensor: mode = "train" # Compute the per-token log probabilities for the model prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + model_kwargs = inputs["remaining_inputs"] input_ids = torch.cat([prompt_ids, completion_ids], dim=1) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + model_kwargs = {key: [d[key] for d in model_kwargs] for key in model_kwargs[0].keys()} + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep, batch_size = 1 if model_kwargs != {} else None, **model_kwargs) # Compute the loss advantages = inputs["advantages"] @@ -950,12 +999,12 @@ def compute_loss(self, with torch.no_grad(): if self.ref_model is not None: ref_per_token_logps = self._get_per_token_logps( - self.ref_model, input_ids, attention_mask, logits_to_keep + self.ref_model, input_ids, attention_mask, logits_to_keep, batch_size = 1 if model_kwargs != {} else None, **model_kwargs ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): # type: ignore ref_per_token_logps = self._get_per_token_logps( - self.model, input_ids, attention_mask, logits_to_keep + self.model, input_ids, attention_mask, logits_to_keep, batch_size = 1 if model_kwargs != {} else None, **model_kwargs ) per_token_kl = ( torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 @@ -1031,13 +1080,21 @@ def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval" completions = eval_results['completion'] if isinstance(completions[0], str): # Completion format - directly tokenize strings - completion_lengths = [len(self.processing_class.encode(c)) for c in completions] # type: ignore + if self.tokenizer_base is not None: + encode = self.tokenizer_base.encode + else: + assert isinstance(self.processing_class, PreTrainedTokenizerBase) + encode = self.processing_class.encode + completion_lengths = [len(encode(c)) for c in completions] # type: ignore else: # Chat format - use apply_chat_template completion_lengths = [] for comp in 
completions: # Apply chat template to get the full text - tokens = self.processing_class.apply_chat_template(comp, tokenize=True, add_generation_prompt=False) # type: ignore + if hasattr(self.processing_class, "tokenizer"): # if multimodal processor, use tokenizer; ow, it expects mm inputs + tokens = self.processing_class.tokenizer.apply_chat_template(comp, tokenize=True, add_generation_prompt=False) # type: ignore + else: + tokens = self.processing_class.apply_chat_template(comp, tokenize=True, add_generation_prompt=False) # type: ignore # Tokenize and count completion_lengths.append(len(tokens)) @@ -1070,10 +1127,18 @@ def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval" # Log to wandb if available if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: import pandas as pd - + # format prompt for logging + prompt = [] + if prompts: + for messages in prompts: + last_message = messages[-1] + content = last_message.get("content", "") + if isinstance(content, list): + content = content[0]["text"] # extract text only in multimodal case + prompt.append([{'role': 'user', 'content': content}]) table_data = { "step": [str(self.state.global_step)] * len(prompts), - "prompt": prompts, + "prompt": prompt, "completion": completions, } for k, v in reward_dict.items(): @@ -1115,10 +1180,18 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: import pandas as pd - + # format prompt for logging + prompt = [] + if list(self._textual_logs["prompt"]): + for messages in list(self._textual_logs["prompt"]): + last_message = messages[-1] + content = last_message.get("content", "") + if isinstance(content, list): + content = content[0]["text"] # extract text only in multimodal case + prompt.append([{'role': 'user', 'content': content}]) table = { "step": [str(self.state.global_step)] * len(self._textual_logs["prompt"]), - "prompt": list(self._textual_logs["prompt"]), + "prompt": prompt, "completion": list(self._textual_logs["completion"]), **{k: list(v) for k, v in self._textual_logs["rewards"].items()}, } @@ -1207,8 +1280,13 @@ def _log_completion_metrics_primary( # Check for EOS tokens term_lengths = [] + if self.tokenizer_base is not None: + eos_token_id = self.tokenizer_base.eos_token_id + else: + assert isinstance(self.processing_class, PreTrainedTokenizerBase) + eos_token_id = self.processing_class.eos_token_id for comp_ids, comp_mask in zip(all_completion_ids, all_completion_mask): - has_eos = any(token == self.processing_class.eos_token_id for token, mask in zip(comp_ids, comp_mask) if mask) # type: ignore + has_eos = any(token == eos_token_id for token, mask in zip(comp_ids, comp_mask) if mask) # type: ignore if has_eos: term_lengths.append(sum(comp_mask)) diff --git a/verifiers/utils/__init__.py b/verifiers/utils/__init__.py index f18dacc98..f247922b0 100644 --- a/verifiers/utils/__init__.py +++ b/verifiers/utils/__init__.py @@ -1,5 +1,5 @@ from .data_utils import extract_boxed_answer, extract_hash_answer, load_example_dataset -from .model_utils import get_model, get_tokenizer, get_model_and_tokenizer +from .model_utils import get_model, get_tokenizer, get_model_and_tokenizer, generic_model_loader from .logging_utils import setup_logging, print_prompt_completions_sample __all__ = [ @@ -11,4 +11,5 @@ "get_model_and_tokenizer", "setup_logging", "print_prompt_completions_sample", + "generic_model_loader", ] \ No newline at end of file 
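Several logging paths in this patch (the wandb tables in `evaluate`/`log` and the rich console sample printer below) have to cope with message `content` that is now either a string or a list of multimodal parts, and they extract `content[0]["text"]` for display. A small helper sketch of that extraction, written slightly more defensively than the diff (it searches for the text parts rather than assuming the text part comes first); purely illustrative:

```python
# Sketch: pull a printable string out of a chat message whose content may be
# a plain string or a list of multimodal parts.
from typing import Any


def content_to_text(content: str | list[dict[str, Any]]) -> str:
    if isinstance(content, str):
        return content
    texts = [p.get("text", "") for p in content if p.get("type") == "text"]
    return " ".join(t for t in texts if t) or "<non-text content>"


msg = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is the invoice total?"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
    ],
}
assert content_to_text(msg["content"]) == "What is the invoice total?"
assert content_to_text("plain string") == "plain string"
```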
diff --git a/verifiers/utils/logging_utils.py b/verifiers/utils/logging_utils.py index a36b97f1f..454d07537 100644 --- a/verifiers/utils/logging_utils.py +++ b/verifiers/utils/logging_utils.py @@ -79,6 +79,8 @@ def print_prompt_completions_sample( if prompt: last_message = prompt[-1] content = last_message.get("content", "") + if isinstance(content, list): # multimodal case + content = content[0]["text"] formatted_prompt = Text(content, style="bright_yellow") else: formatted_prompt = Text("") diff --git a/verifiers/utils/model_utils.py b/verifiers/utils/model_utils.py index 12679412b..21894439a 100644 --- a/verifiers/utils/model_utils.py +++ b/verifiers/utils/model_utils.py @@ -1,8 +1,11 @@ +import importlib from importlib.util import find_spec +from importlib import import_module from typing import Dict, Any, Union, Tuple, Callable import torch -from transformers import AutoModelForCausalLM, AutoTokenizer # type: ignore +from transformers import AutoModelForCausalLM, AutoModel, AutoProcessor, AutoConfig, PreTrainedModel # type: ignore +from transformers.models.auto.modeling_auto import AutoModelForSeq2SeqLM, AutoModelForVision2Seq import torch.nn as nn @@ -59,7 +62,37 @@ def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn. def is_liger_available() -> bool: return find_spec("liger_kernel") is not None -def get_model(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[str, Any], None] = None) -> Any: +def generic_model_loader(model_id: str, **model_kwargs) -> PreTrainedModel: + cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=True) + for arch in cfg.architectures or []: + try: + cls = getattr(import_module("transformers"), arch) + return cls.from_pretrained( + model_id, + trust_remote_code=True, + **model_kwargs, + ) + except (AttributeError, ImportError, ValueError): + pass + + for auto_cls in ( + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + AutoModelForVision2Seq, + AutoModel, + ): + try: + return auto_cls.from_pretrained( + model_id, + trust_remote_code=True, + **model_kwargs, + ) + except ValueError: + continue + + raise RuntimeError(f"No suitable loader found for model type {cfg.model_type!r}") + +def get_model(model_name: str, use_liger: bool = True, liger_patch_suffix: str | None = None, model_kwargs: Union[Dict[str, Any], None] = None) -> Any: if model_kwargs is None: model_kwargs = dict( torch_dtype=torch.bfloat16, @@ -68,20 +101,38 @@ def get_model(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[ ) if is_liger_available() and use_liger: print("Using Liger kernel") - from liger_kernel.transformers import AutoLigerKernelForCausalLM # type: ignore - return AutoLigerKernelForCausalLM.from_pretrained(model_name, **model_kwargs) + try: + from liger_kernel.transformers import AutoLigerKernelForCausalLM # type: ignore + model = AutoLigerKernelForCausalLM.from_pretrained(model_name, **model_kwargs) + return model + except ValueError: # try monkey patch + print(f"Model {model_name} is not supported with AutoLigerKernelForCausalLM. 
Attempting monkey patch...") + if liger_patch_suffix is None: # try with model tpe + liger_patch_suffix = AutoConfig.from_pretrained(model_name, trust_remote_code=True).model_type + print(f"No liger_patch_suffix provided, attempting with model_type: {liger_patch_suffix}") + patch_func_name = f"apply_liger_kernel_to_{liger_patch_suffix}" + ligermod = importlib.import_module("liger_kernel.transformers") + patch_func = getattr(ligermod, patch_func_name, None) + if callable(patch_func): + patch_func() + model = generic_model_loader(model_name, **model_kwargs) + print(f"Applied Liger-Kernel patch to {model_name}") + return model + else: + raise ValueError(f"Model {model_name} may not be supported with Liger-Kernel in verifiers. Check the Liger-Kernel documentation.") else: - return AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) + return generic_model_loader(model_name, **model_kwargs) -def get_tokenizer(model_name: str) -> Any: - tokenizer = AutoTokenizer.from_pretrained(model_name) +def get_tokenizer(model_name: str, padding_side: str = "left") -> Any: + processor = AutoProcessor.from_pretrained(model_name, padding_side=padding_side) + tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor if not hasattr(tokenizer, "chat_template"): raise ValueError(f"Tokenizer for model {model_name} does not have chat_template attribute, \ and could not find a tokenizer with the same name as the model with suffix \ '-Instruct'. Please provide a tokenizer with the chat_template attribute.") - return tokenizer + return processor -def get_model_and_tokenizer(model_name: str, use_liger: bool = True, model_kwargs: Union[Dict[str, Any], None] = None) -> Tuple[Any, Any]: - model = get_model(model_name, use_liger, model_kwargs) +def get_model_and_tokenizer(model_name: str, use_liger: bool = True, liger_patch_suffix:str | None = None, model_kwargs: Union[Dict[str, Any], None] = None) -> Tuple[Any, Any]: + model = get_model(model_name, use_liger, liger_patch_suffix, model_kwargs) tokenizer = get_tokenizer(model_name) return model, tokenizer \ No newline at end of file
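For reference, a hedged usage sketch of the updated loaders in `model_utils`: `get_model_and_tokenizer` now returns an `AutoProcessor` for multimodal checkpoints (the trainer unwraps it via `processing_class.tokenizer`), the Liger path falls back to dynamically resolving an `apply_liger_kernel_to_<model_type>` patch function, and `generic_model_loader` can be called directly (as the ZeRO-3 reference-model path does). Model names and the `liger_patch_suffix` value are examples, not guarantees of Liger support for those architectures:

```python
# Hedged usage sketch of the updated loading API, assuming the patched
# verifiers package is installed and the checkpoints are available locally
# or downloadable.
import verifiers as vf
from verifiers.utils.model_utils import generic_model_loader

# Multimodal checkpoint: `processor` wraps a tokenizer that GRPOTrainer
# accesses as processing_class.tokenizer.
model, processor = vf.get_model_and_tokenizer(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    use_liger=True,
    liger_patch_suffix="qwen2_5_vl",  # optional; defaults to config.model_type
)

# Text-only checkpoint: generic_model_loader tries architecture-named classes
# first, then the Auto* fallbacks, mirroring the reference-model path above.
ref_model = generic_model_loader("Qwen/Qwen2.5-0.5B-Instruct", torch_dtype="auto")
```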