185 changes: 134 additions & 51 deletions verifiers/envs/environment.py
@@ -26,6 +26,41 @@ def _pil_to_data_url(img: Image.Image, fmt: str | None = None) -> str:
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
return f"data:image/{fmt.lower()};base64,{b64}"

def format_oai_chat_msg(
prompts: List[List[Dict[str, Any]]],
images: List[List[Image.Image]]
) -> List[List[Dict[str, Any]]]:
formatted_conversations: List[List[Dict[str, Any]]] = []

for conv_prompts, conv_images in zip(prompts, images):
img_iter = iter(conv_images)
new_conv: List[Dict[str, Any]] = []

for msg in conv_prompts:
role = msg["role"]
content = msg["content"]

if isinstance(content, list):
new_parts: List[Dict[str, Any]] = []
for part in content:
if part.get("type") == "image":
img = next(img_iter)
data_url = _pil_to_data_url(img)
new_parts.append({
"type": "image_url",
"image_url": {"url": data_url}
})
else:
new_parts.append(part.copy())
new_conv.append({"role": role, "content": new_parts})

else:
new_conv.append({"role": role, "content": content})

formatted_conversations.append(new_conv)

return formatted_conversations

class Environment(ABC):
"""
Base class for all environments.
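The new module-level helper `format_oai_chat_msg` rewrites `{"type": "image"}` content parts into OpenAI-style `image_url` parts, consuming one PIL image per placeholder from the parallel `images` list. A minimal usage sketch (the conversation and image are illustrative, not from this PR):

```python
from PIL import Image
# assuming: from verifiers.envs.environment import format_oai_chat_msg

# one conversation with a single image placeholder in the user turn
prompts = [[{
    "role": "user",
    "content": [
        {"type": "text", "text": "What color is the square?"},
        {"type": "image"},  # consumed from the parallel images list
    ],
}]]
images = [[Image.new("RGB", (64, 64), "red")]]

formatted = format_oai_chat_msg(prompts, images)
part = formatted[0][0]["content"][1]
# part is now {"type": "image_url", "image_url": {"url": "data:image/...;base64,..."}}
print(part["type"])  # -> "image_url"
```

String-valued message contents pass through unchanged, so text-only conversations are unaffected.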
@@ -324,8 +359,12 @@ def generate(self,
results = {col: deepcopy(inputs[col]) for col in inputs.column_names}
else:
results = deepcopy(inputs)
if results.get('images') is not None:
prompts = format_oai_chat_msg(results['prompt'], results['images'])
else:
prompts = results['prompt']
rollouts = self.run_rollouts(
prompts=results['prompt'],
prompts=prompts,
client=client,
model=model,
sampling_args=gen_sampling_args,
@@ -352,6 +391,7 @@
def process_chat_format(
self,
prompt: List[Dict[str, str]],
images: Optional[List[List[Any]]],
completion: List[Dict[str, str]],
processing_class: PreTrainedTokenizerBase,
mask_env_responses: bool = False
@@ -367,59 +407,96 @@ def process_chat_format(
Returns:
prompt_ids, prompt_mask, completion_ids, completion_mask
"""
# tokenize just the prompt
prompt_text = processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
assert isinstance(prompt_text, str)
if hasattr(processing_class, "tokenizer"):
encode = processing_class.tokenizer.encode
else:
encode = processing_class.encode
prompt_ids = encode(prompt_text)
prompt_mask = [1] * len(prompt_ids)

# track completion tokens and masks by processing incrementally
completion_ids = []
completion_mask = []

# previous tokenization (starts with just prompt)
prev_ids = prompt_ids

# process each completion message incrementally
for i, msg in enumerate(completion):
# create conversation prefix: prompt + completion[:i+1]
conversation_prefix = prompt + completion[:i+1]
remaining_inputs = {}
if images:
prompt_text = processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
inputs = processing_class(text=prompt_text, images=images, return_tensors="pt")
remaining_inputs = {
k: v
for k, v in inputs.items()
if k not in ["input_ids", "attention_mask"]
}
prev_ids = inputs.input_ids[0].tolist()
prompt_ids = prev_ids
prompt_mask = [1] * len(prompt_ids)

for i, msg in enumerate(completion):
conversation_prefix = prompt + completion[:i+1]
prefix_text = processing_class.apply_chat_template(
conversation_prefix,
tokenize=False,
add_generation_prompt=False,
)
current_ids = processing_class(text=prefix_text, images=images, return_tensors="pt").input_ids[0].tolist()
assert current_ids[:len(prev_ids)] == prev_ids, "Tokenization difference in chat format."
new_tokens = current_ids[len(prev_ids):]
completion_ids.extend(new_tokens)

if msg["role"] == "assistant":
msg_mask = [1] * len(new_tokens)
elif msg["role"] != "assistant" and mask_env_responses:
msg_mask = [0] * len(new_tokens)
else:
msg_mask = [1] * len(new_tokens)

completion_mask.extend(msg_mask)
prev_ids = current_ids
else:
# tokenize just the prompt
prompt_text = processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
assert isinstance(prompt_text, str)
if hasattr(processing_class, "tokenizer"):
encode = processing_class.tokenizer.encode
else:
encode = processing_class.encode
prompt_ids = encode(prompt_text)
prompt_mask = [1] * len(prompt_ids)

# tokenize the full prefix
prefix_text = processing_class.apply_chat_template(
conversation_prefix,
tokenize=False,
add_generation_prompt=False,
)
assert isinstance(prefix_text, str), f"Expected string from apply_chat_template, got {type(prefix_text)}"
current_ids = encode(prefix_text)
assert current_ids[:len(prev_ids)] == prev_ids, f"Tokenization difference in chat format. Current ids: {current_ids}, previous ids: {prev_ids}"
# track completion tokens and masks by processing incrementally
completion_ids = []
completion_mask = []

# add new tokens to completion tokens
new_tokens = current_ids[len(prev_ids):]
assert len(new_tokens) > 0, f"No new tokens in chat format. Current ids: {current_ids}, previous ids: {prev_ids}"
completion_ids.extend(new_tokens)

# create mask
if msg["role"] == "assistant":
msg_mask = [1] * len(new_tokens)
elif msg["role"] != "assistant" and mask_env_responses:
# mask intermediate 'user' and/or 'tool' messages
msg_mask = [0] * len(new_tokens)
else:
# default to not masking
msg_mask = [1] * len(new_tokens)
# previous tokenization (starts with just prompt)
prev_ids = prompt_ids

completion_mask.extend(msg_mask)
# Update previous tokenization for next iteration
prev_ids = current_ids
assert len(completion_ids) == len(completion_mask), f"Length mismatch in chat format. Completion ids: {completion_ids}, completion mask: {completion_mask}"
# process each completion message incrementally
for i, msg in enumerate(completion):
# create conversation prefix: prompt + completion[:i+1]
conversation_prefix = prompt + completion[:i+1]

# tokenize the full prefix
prefix_text = processing_class.apply_chat_template(
conversation_prefix,
tokenize=False,
add_generation_prompt=False,
)
assert isinstance(prefix_text, str), f"Expected string from apply_chat_template, got {type(prefix_text)}"
current_ids = encode(prefix_text)
assert current_ids[:len(prev_ids)] == prev_ids, f"Tokenization difference in chat format. Current ids: {current_ids}, previous ids: {prev_ids}"

# add new tokens to completion tokens
new_tokens = current_ids[len(prev_ids):]
assert len(new_tokens) > 0, f"No new tokens in chat format. Current ids: {current_ids}, previous ids: {prev_ids}"
completion_ids.extend(new_tokens)

return prompt_ids, prompt_mask, completion_ids, completion_mask
# create mask
if msg["role"] == "assistant":
msg_mask = [1] * len(new_tokens)
elif msg["role"] != "assistant" and mask_env_responses:
# mask intermediate 'user' and/or 'tool' messages
msg_mask = [0] * len(new_tokens)
else:
# default to not masking
msg_mask = [1] * len(new_tokens)

completion_mask.extend(msg_mask)
# Update previous tokenization for next iteration
prev_ids = current_ids
assert len(completion_ids) == len(completion_mask), f"Length mismatch in chat format. Completion ids: {completion_ids}, completion mask: {completion_mask}"

return prompt_ids, prompt_mask, completion_ids, completion_mask, remaining_inputs
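Both the image and text-only branches above rely on the same incremental trick: re-apply the chat template to `prompt + completion[:i+1]`, assert that the previous tokenization is a strict prefix of the new one, and take the tail as the current message's tokens so each turn can be masked individually. A self-contained sketch of the idea with a plain tokenizer (the model name and conversation are illustrative; environment turns are masked as with `mask_env_responses=True`):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # any chat-templated tokenizer

prompt = [{"role": "user", "content": "What is 2 + 2?"}]
completion = [
    {"role": "assistant", "content": "4"},
    {"role": "user", "content": "And 3 + 3?"},  # environment response
    {"role": "assistant", "content": "6"},
]

prev_ids = tok.encode(tok.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True))
completion_ids, completion_mask = [], []
for i, msg in enumerate(completion):
    text = tok.apply_chat_template(prompt + completion[: i + 1], tokenize=False)
    current_ids = tok.encode(text)
    assert current_ids[: len(prev_ids)] == prev_ids  # template must only append
    new_tokens = current_ids[len(prev_ids):]
    completion_ids.extend(new_tokens)
    # assistant tokens are trained on; environment turns are masked out here
    completion_mask.extend([int(msg["role"] == "assistant")] * len(new_tokens))
    prev_ids = current_ids

assert len(completion_ids) == len(completion_mask)
```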

def process_completion_format(
self,
Expand Down Expand Up @@ -455,6 +532,7 @@ def process_completion_format(
def process_env_results(
self,
prompts: List[Union[str, List[Dict[str, Any]]]],
images: Optional[List[List[Any]]],
completions: List[Union[str, List[Dict[str, Any]]]],
states: List[Dict[str, Any]],
rewards: List[float],
@@ -478,13 +556,16 @@
all_prompt_masks = []
all_completion_ids = []
all_completion_masks = []
all_remaining_inputs = []

images = images or [None] * len(prompts)

for i, (prompt, completion, state, reward) in enumerate(zip(prompts, completions, states, rewards)):
for i, (prompt, images, completion, state, reward) in enumerate(zip(prompts, images, completions, states, rewards)):
# Format-specific processing
if is_chat_format:
assert isinstance(prompt, list) and isinstance(completion, list)
prompt_ids, prompt_mask, completion_ids, completion_mask = self.process_chat_format(
prompt, completion, processing_class, mask_env_responses
prompt_ids, prompt_mask, completion_ids, completion_mask, remaining_inputs = self.process_chat_format(
prompt, images, completion, processing_class, mask_env_responses
)
else:
assert isinstance(prompt, str) and isinstance(completion, str)
@@ -498,13 +579,15 @@
all_prompt_masks.append(prompt_mask)
all_completion_ids.append(completion_ids)
all_completion_masks.append(completion_mask)
all_remaining_inputs.append(remaining_inputs)

return {
"prompt_ids": all_prompt_ids,
"prompt_mask": all_prompt_masks,
"completion_ids": all_completion_ids,
"completion_mask": all_completion_masks,
"rewards": rewards,
"remaining_inputs": all_remaining_inputs,
}

# Evaluation and dataset generation
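For multimodal processors, the call `processing_class(text=..., images=..., return_tensors="pt")` emits extra tensors alongside `input_ids` and `attention_mask`; those extras are what `process_chat_format` now returns as `remaining_inputs` and `process_env_results` collects per sample. Roughly, for one sample (key names depend on the processor and are illustrative, not guaranteed by this PR):

```python
# sketch: consuming the new fifth return value, given an Environment instance `env`,
# a multimodal `processor`, and one sample's prompt/images/completion
prompt_ids, prompt_mask, completion_ids, completion_mask, remaining_inputs = env.process_chat_format(
    prompt, images, completion, processor, mask_env_responses=False
)
# text-only sample -> remaining_inputs == {}
# image sample with a Qwen2-VL-style processor -> e.g.
#   remaining_inputs["pixel_values"]    # float tensor of image patches
#   remaining_inputs["image_grid_thw"]  # per-image grid dimensions
```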
36 changes: 26 additions & 10 deletions verifiers/examples/docvqa.py
@@ -3,7 +3,7 @@
from datasets import load_dataset

import verifiers as vf
from qwen_vl_utils import smart_resize
from qwen_vl_utils import process_vision_info

"""
# install qwen stuff
@@ -15,18 +15,33 @@
"""


def preprocess_docvqa(x):
return {
"question": x["question"],
"images": [x["image"].resize(smart_resize(768, 1024))], # XGA
"answer": x["answers"][0],
}
def data_collator(batch: list[dict]) -> list[dict]:
processed_samples = []
for sample in batch:
messages = []
messages.append({"role": "system", "content": system_prompt})
content_block = []
content_block.append({"type": "text", "text": sample["question"]})
content_block.append(
{
"type": "image",
"image": sample["image"], # only one image in this ds
"resized_height": 768, # XGA resolution
"resized_width": 1024,
}
)
messages.append({"role": "user", "content": content_block})
processed_images, *_ = process_vision_info( # process with qwen utils
messages.copy()
)
sample["prompt"] = messages
sample["images"] = processed_images
sample["answer"] = sample["answers"]
processed_samples.append(sample)
return processed_samples


dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation")
dataset = dataset.map(
preprocess_docvqa, num_proc=10, remove_columns=dataset.column_names
)

parser = vf.XMLParser(["think", "answer"], answer_field="answer")
system_prompt = f"""Answer the questions.
@@ -97,5 +112,6 @@ def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None:
processing_class=processor,
env=vf_env,
args=training_args,
data_collator=data_collator,
)
trainer.train()
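The example replaces the old `map`-time `preprocess_docvqa` (which called `smart_resize`) with a collator that builds Qwen-format messages per batch and lets `qwen_vl_utils.process_vision_info` perform the XGA (1024×768) resize. A quick sketch of what one collated sample carries (field contents are illustrative):

```python
# assuming `dataset`, `system_prompt`, and `data_collator` as defined above
sample = data_collator([dataset[0]])[0]
sample["prompt"]   # [system message, user message with text + image content parts]
sample["images"]   # [PIL.Image.Image] resized by process_vision_info
sample["answer"]   # the dataset's "answers" field, passed through for the rubric
```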
3 changes: 2 additions & 1 deletion verifiers/trainers/async_batch_generator.py
@@ -11,7 +11,7 @@
class BatchRequest:
"""Request for batch generation"""
batch_id: int
env_inputs: Dict[str, List[Any]]
env_inputs: Dict[str, List[Any] | None]
processing_class: Any
mask_env_responses: bool
max_completion_length: int
@@ -239,6 +239,7 @@ def _generate_batch(self, request: BatchRequest) -> BatchResult:
# Process results
processed_results = self.env.process_env_results(
env_results['prompt'],
env_results['images'],
env_results['completion'],
env_results['state'],
env_results['reward'],
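One thing worth double-checking: `generate` guards with `results.get('images')`, but this call indexes `env_results['images']` directly, so text-only environments must still populate the key. A defensive variant (a suggestion, not what the PR does):

```python
# sketch of the guard, mirroring generate()'s use of results.get('images')
images = env_results.get('images')  # None when the environment produced no images
processed_results = self.env.process_env_results(
    env_results['prompt'],
    images,
    env_results['completion'],
    env_results['state'],
    env_results['reward'],
    # ... remaining arguments as in the original call
)
```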