Changes from all commits (50 commits)
3774b56
use flexible model loading
nph4rd Jun 2, 2025
2fbd45a
start example
nph4rd Jun 2, 2025
ad19f9f
use AutoProcessor class
nph4rd Jun 2, 2025
88405ce
fix processor calls and pin transformers
nph4rd Jun 2, 2025
38062e0
gather images
nph4rd Jun 3, 2025
f91df35
format images
nph4rd Jun 3, 2025
85c3f31
fix rich log
nph4rd Jun 3, 2025
d4e9818
update comments
nph4rd Jun 3, 2025
f38eee8
update example
nph4rd Jun 6, 2025
c40152a
fix wandb logs
nph4rd Jun 7, 2025
6eddfd6
resize
nph4rd Jun 8, 2025
3be9f45
model len
nph4rd Jun 8, 2025
65aeb1b
update example
nph4rd Jun 9, 2025
606027a
fix format and remove unused images
nph4rd Jun 9, 2025
bd76eab
fix image unpacking
nph4rd Jun 9, 2025
a89de9b
change format dataset
nph4rd Jun 9, 2025
6c464f9
opt
nph4rd Jun 9, 2025
5036f69
fix format on text-only
nph4rd Jun 9, 2025
f3959f0
fix _gather_batch_data type
nph4rd Jun 9, 2025
b8732f7
relax transformers condition
nph4rd Jun 9, 2025
e461258
update comment / increase lr
nph4rd Jun 9, 2025
58ac1c7
liger monkey patch
nph4rd Jun 9, 2025
2e25655
generic liger patch
nph4rd Jun 10, 2025
0da5888
increase lr
nph4rd Jun 10, 2025
c9eaa02
return to old naming
nph4rd Jun 10, 2025
8aff5d1
remove todos
nph4rd Jun 10, 2025
9bb135c
restore padding side
nph4rd Jun 10, 2025
a0323d7
remove padding side for completion_mask
nph4rd Jun 10, 2025
6b630ba
fix wandb logging
nph4rd Jun 10, 2025
edd9b29
logging format
nph4rd Jun 10, 2025
f3f403e
use data collator
nph4rd Jun 10, 2025
815e000
format oai-api prompts
nph4rd Jun 10, 2025
6ed448d
post-process images
nph4rd Jun 11, 2025
24ed694
fix text position
nph4rd Jun 11, 2025
ffee1d6
process inputs in environment
nph4rd Jun 11, 2025
e0026a8
increase res and lr
nph4rd Jun 12, 2025
f51c665
fix
nph4rd Jun 12, 2025
9e175bc
Merge pull request #1 from nph4rd/mm-kwargs
nph4rd Jun 12, 2025
4dd16b7
fix eval with data collator
nph4rd Jun 16, 2025
020d9d7
transform eval ds once
nph4rd Jun 16, 2025
ac5682b
change eval steps
nph4rd Jun 16, 2025
2538d62
fix batch size in func call
nph4rd Jun 17, 2025
b1b90e4
liger patch suffix opt
nph4rd Jun 17, 2025
5dbc658
load ref with generic_model_loader
nph4rd Jun 17, 2025
ee2e44c
set use_reentrant false
nph4rd Jun 17, 2025
9ff8b79
reset format_dataset func
nph4rd Jun 18, 2025
49c28bb
format stuff
nph4rd Jun 18, 2025
cc69cc6
rase error
nph4rd Jun 18, 2025
eed2fe2
Merge branch 'main' into multimodal
nph4rd Jun 18, 2025
d41d00f
update example
nph4rd Jun 18, 2025
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -31,7 +31,7 @@ dependencies = [
     "requests>=2.32.3",
     "openai>=1.81.0",
     "datasets>=3.6.0",
-    "transformers",
+    "transformers<4.52.0",
     "nest-asyncio>=1.6.0",
 ]

224 changes: 169 additions & 55 deletions verifiers/envs/environment.py

Large diffs are not rendered by default.
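Note (not part of the rendered diff): the environment.py changes themselves are hidden above. Judging only from the commit messages ("process inputs in environment", "post-process images") and the widened process_env_results call further down, the environment now appears to carry images alongside prompts. Below is a minimal, hypothetical sketch of that idea using the standard Hugging Face AutoProcessor API; the function name and argument layout are assumptions, not the PR's actual code.

# Hypothetical sketch only; not the actual environment.py diff.
# Encodes chat-formatted prompts together with their images.
from transformers import AutoProcessor

def encode_multimodal_prompt(processor: AutoProcessor, prompt: list[dict], images: list):
    # Render the chat messages into a text prompt using the model's chat template.
    text = processor.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
    # Tokenize the text together with the pixel inputs for the vision tower.
    return processor(text=[text], images=images, padding=True, return_tensors="pt")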

130 changes: 130 additions & 0 deletions verifiers/examples/docvqa.py
@@ -0,0 +1,130 @@
import re

from datasets import load_dataset
from qwen_vl_utils import process_vision_info

import verifiers as vf

"""
# install qwen stuff
uv pip install qwen-vl-utils
# inference
CUDA_VISIBLE_DEVICES=0,1,2,3 vf-vllm --model 'Qwen/Qwen2.5-VL-7B-Instruct' --max-model-len 32000 --tensor_parallel_size 4
# train
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --config-file configs/zero3.yaml --num-processes 4 verifiers/examples/docvqa.py
"""


def data_collator(batch: list[dict]) -> list[dict]:
    processed_samples = []
    for sample in batch:
        messages = []
        messages.append({"role": "system", "content": system_prompt})
        content_block = []
        content_block.append({"type": "text", "text": sample["question"]})
        content_block.append(
            {
                "type": "image",
                "image": sample["image"],  # only one image in this ds
                "resized_height": 768,  # XGA resolution
                "resized_width": 1024,
            }
        )
        messages.append({"role": "user", "content": content_block})
        processed_images, *_ = process_vision_info(  # process with qwen utils
            messages.copy()
        )
        sample["prompt"] = messages
        sample["images"] = processed_images
        sample["answer"] = sample["answers"]
        processed_samples.append(sample)
    return processed_samples


dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation[10%:]")
eval_dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation[:10%]")

parser = vf.XMLParser(["think", "answer"], answer_field="answer")
system_prompt = f"""Answer the questions.

Respond in the following format:
{parser.get_format_str()}"""


def correctness_reward_func(completion: list[dict[str, str]], **kwargs) -> float:
    def get_assistant_messages(messages: list[dict[str, str]]) -> list[dict[str, str]]:
        return [msg for msg in messages if msg.get("role") == "assistant"]

    def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None:
        pattern = rf"<{tag}>\s*(.*?)\s*</{tag}>"
        match = re.search(pattern, text, re.DOTALL)
        if match:
            content = match.group(1)
            return content.strip() if strip else content
        return None

    assistant_messages = get_assistant_messages(completion)
    if assistant_messages is None:
        return 0.0
    msgs_scores = []
    for msg in assistant_messages:
        content = msg.get("content", "")
        answer = parse_xml_content(content, "answer")
        if answer is None:
            continue
        gt_answers = kwargs["answer"]
        mean_gt_len = sum([len(gt_answer) for gt_answer in gt_answers]) / len(
            gt_answers
        )
        if len(answer) > 0:
            diff_from_mean = min(mean_gt_len / len(answer), 1.0)  # penalize long answers
        else:
            diff_from_mean = 0.0
        if answer in gt_answers:
            msgs_scores.append(2.0)
        elif answer.lower() in [ans.lower() for ans in gt_answers]:
            msgs_scores.append(1.0)
        elif any(ans.lower() in answer.lower() for ans in gt_answers):
            msgs_scores.append(diff_from_mean)
    if msgs_scores == []:
        return 0.0
    else:
        return sum(msgs_scores) / len(msgs_scores) / 2.0


rubric = vf.Rubric(
    funcs=[
        parser.get_format_reward_func(),
        correctness_reward_func,
    ]
)

vf_env = vf.SingleTurnEnv(
    dataset=dataset,
    eval_dataset=eval_dataset,
    system_prompt=system_prompt,
    parser=parser,
    rubric=rubric,
    data_collator=data_collator,
)

model_name = "Qwen/Qwen2.5-VL-7B-Instruct"
model, processor = vf.get_model_and_tokenizer(model_name)
run_name = "docvqa_" + model_name.split("/")[-1].lower()

training_args = vf.grpo_defaults(run_name=run_name)
training_args.learning_rate = 3e-6
training_args.max_steps = -1
training_args.eval_strategy = "steps"
training_args.eval_steps = 100
training_args.gradient_checkpointing_kwargs = {
    "use_reentrant": False,
}

trainer = vf.GRPOTrainer(
    model=model,
    processing_class=processor,
    env=vf_env,
    args=training_args,
)
trainer.train()
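
A quick, illustrative sanity check for the collator above (not part of the PR; the field accesses assume the structure built in data_collator):

# Illustrative sanity check for data_collator; not part of this PR.
sample = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation[:1]")[0]
processed = data_collator([sample])[0]
print(processed["prompt"][1]["content"][0]["text"])  # the question text block
print(len(processed["images"]), processed["answer"])  # resized image(s) and ground-truth answers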
2 changes: 1 addition & 1 deletion verifiers/inference/vllm_server.py
@@ -85,7 +85,7 @@ async def get_next_worker_connection(connections: list[AnyType]) -> tuple[int, A
 # -------- OpenAI /v1/chat/completions Pydantic Models ---------- #
 class OAChatMessage(BaseModel):
     role: str
-    content: str
+    content: str | list
 
 class OAChatCompletionRequest(BaseModel):
     model: str
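For context, an illustrative example of the kind of message the widened content: str | list field is meant to accept, using OpenAI-style content parts (the part layout is an assumption based on the OpenAI chat format, not taken from this PR):

# Illustrative only: a multimodal chat message with OpenAI-style content parts.
message = OAChatMessage(
    role="user",
    content=[
        {"type": "text", "text": "What is the total on this invoice?"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,<BASE64_IMAGE>"}},
    ],
)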
3 changes: 2 additions & 1 deletion verifiers/trainers/async_batch_generator.py
@@ -11,7 +11,7 @@
 class BatchRequest:
     """Request for batch generation"""
     batch_id: int
-    env_inputs: Dict[str, List[Any]]
+    env_inputs: Dict[str, List[Any] | None]
     processing_class: Any
     mask_env_responses: bool
     max_completion_length: int
@@ -239,6 +239,7 @@ def _generate_batch(self, request: BatchRequest) -> BatchResult:
         # Process results
         processed_results = self.env.process_env_results(
             env_results['prompt'],
+            env_results['images'],
             env_results['completion'],
             env_results['state'],
             env_results['reward'],
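
One consequence of the widened env_inputs: Dict[str, List[Any] | None] type in the first hunk above (illustrative; the dictionary keys are assumptions, not taken from this PR): text-only batches can carry None where multimodal batches carry per-sample image lists.

from typing import Any, Dict, List

# Illustrative only; keys are hypothetical. The point of the widened type is that
# text-only batches can carry None where multimodal batches carry image lists.
multimodal_batch: Dict[str, List[Any] | None] = {
    "prompt": [[{"role": "user", "content": "Describe the image."}]],
    "images": [["<PIL.Image for sample 0>"]],  # placeholder; normally PIL images
}
text_only_batch: Dict[str, List[Any] | None] = {
    "prompt": [[{"role": "user", "content": "2 + 2 = ?"}]],
    "images": None,  # no images for a text-only dataset
}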