6 changes: 6 additions & 0 deletions nemo_skills/inference/generate.py
@@ -397,6 +397,12 @@ def setup_llm(self):
# We don't want to override these key variables which overlap with self.cfg
inference_override_config = {
"remove_thinking": self.cfg.parallel_thinking.remove_thinking, # Removing thinking from solutions is important for parallel_thinking. We don't want to override this with the main generation config
"endpoint_type": self.cfg.parallel_thinking.endpoint_type,
# The following are specific to parallel thinking and we want to defend against any future key overlaps with the main generation config
"mode": self.cfg.parallel_thinking.mode,
"window_size": self.cfg.parallel_thinking.window_size,
"solution_key": self.cfg.parallel_thinking.solution_key,
"filter_incomplete_solutions": self.cfg.parallel_thinking.filter_incomplete_solutions,
}

llm = get_parallel_thinking_model(
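The override dict above exists so that parallel-thinking-specific settings always win over any same-named keys in the main generation config. A minimal sketch of that precedence, purely illustrative (the actual merge happens inside get_parallel_thinking_model, which is not shown in this diff):

# Illustrative only: key names are real, but the merge shown here is an assumption.
main_generation_cfg = {"endpoint_type": "chat", "remove_thinking": False}
inference_override_config = {"endpoint_type": "text", "remove_thinking": True}

effective_cfg = {**main_generation_cfg, **inference_override_config}
assert effective_cfg["endpoint_type"] == "text"  # parallel_thinking setting wins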
127 changes: 76 additions & 51 deletions nemo_skills/inference/model/parallel_thinking.py
@@ -24,7 +24,9 @@
from dataclasses import field
from typing import Dict, List, Optional, Union

from nemo_skills.prompt.utils import get_prompt
from transformers import AutoTokenizer

from nemo_skills.prompt.utils import get_prompt, get_token_count
from nemo_skills.utils import get_logger_name, nested_dataclass, remove_thinking

from .base import BaseModel, EndpointType
@@ -52,11 +54,14 @@ class ParallelThinkingConfig:
remove_thinking: bool = True # Remove thinking tokens from the solution key
thinking_begin: str = "<think>"
thinking_end: str = "</think>"
endpoint_type: EndpointType = EndpointType.chat
endpoint_type: EndpointType = EndpointType.text
tokenizer: str | None = None
chat_template_kwargs: dict | None = None # extra parameters to pass to the tokenizer's apply_chat_template method
start_assistant_response_key: str | None = None # key whose value, if set, is used to begin the assistant response

# Count the number of tokens in the prompt
count_prompt_tokens: bool = False

# GenSelect vs GenSynthesis
mode: str | None = None # genselect or gensynthesis

@@ -98,6 +103,11 @@ def __init__(self, model: BaseModel, tokenizer: str | None, orig_prompt_filler,
else:
raise ValueError(f"Invalid parallel thinking mode: {self.cfg.mode}")

if self.cfg.count_prompt_tokens:
self.hf_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer)
if self.hf_tokenizer is None:
raise ValueError("Tokenizer could not be initialized. Needed for counting prompt tokens.")

# Initialize the solutions if input_dir is provided
if self.cfg.generation_dir is not None:
LOG.info("Loading solutions from %s", self.cfg.generation_dir)
@@ -188,6 +198,46 @@ def _load_solutions(self, input_dir: str) -> Dict[str, List[Dict]]:

return prompt_to_solutions_dict

async def _get_multiple_solutions(
self, prompt: Union[str, List], local_random: random.Random, **kwargs
) -> tuple[List[Dict], int]:
"""Return multiple solutions for the input prompt."""
if self.cfg.generation_dir is not None:
# Already have the solutions in the input directory
# Hashing the prompt to get the key for the solutions
solutions = self.prompt_to_solutions_dict[self.hash_prompt(prompt)]
local_random.shuffle(solutions)
# After shuffling, only take the first window_size solutions
solutions = solutions[: self.cfg.window_size]
else:
# Generate the solutions first
solutions = await self.generate_solutions(prompt, local_random, **kwargs)

# Filter out incomplete solutions if specified
if self.cfg.filter_incomplete_solutions:
# Remove unfinished solutions
filtered_solutions = []
for solution in solutions:
# Check if thinking_begin is in the solution and thinking_end is not in the solution
if (
self.cfg.thinking_begin in solution[self.cfg.solution_key]
and self.cfg.thinking_end not in solution[self.cfg.solution_key]
):
continue
else:
filtered_solutions.append(solution)

if len(filtered_solutions) < len(solutions):
LOG.info(f"Filtered out {len(solutions) - len(filtered_solutions)} incomplete solutions")

solutions = filtered_solutions

total_num_generated_tokens = 0
for solution in solutions:
total_num_generated_tokens += solution["output_dict"].get("num_generated_tokens", 0)

return solutions, total_num_generated_tokens
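The filter above drops any candidate that opened a thinking block but never closed it. A small illustrative example of which candidates survive, using "generation" as a stand-in for cfg.solution_key and the default thinking markers:

# Illustrative data only.
candidates = [
    {"generation": "<think>partial reasoning that never finished"},   # filtered out
    {"generation": "<think>reasoning</think> The answer is 42."},     # kept
    {"generation": "Direct answer with no thinking block at all."},   # kept
]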

async def _generate_parallel_thinking_contraction(
self, prompt: Union[str, List], solutions: List[Dict], **kwargs
) -> Dict:
@@ -214,16 +264,24 @@ async def _generate_parallel_thinking_contraction(
chat_template_kwargs=self.cfg.chat_template_kwargs,
)

for duplicate_key in ["temperature", "tokens_to_generate", "prompt"]:
output_dict = {}
if self.cfg.count_prompt_tokens:
num_input_tokens = get_token_count(tokenizer=self.hf_tokenizer, messages=parallel_thinking_prompt)
output_dict["num_input_tokens"] = num_input_tokens

for duplicate_key in ["temperature", "tokens_to_generate", "prompt", "endpoint_type"]:
kwargs.pop(duplicate_key, None)

return await self.model.generate_async(
prompt=parallel_thinking_prompt,
# Overriding the tokens_to_generate, temperature
tokens_to_generate=self.cfg.tokens_to_generate,
temperature=self.cfg.temperature,
**kwargs,
output_dict.update(
await self.model.generate_async(
prompt=parallel_thinking_prompt,
# Overriding the tokens_to_generate, temperature
tokens_to_generate=self.cfg.tokens_to_generate,
temperature=self.cfg.temperature,
**kwargs,
)
)
return output_dict
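With count_prompt_tokens enabled, the contraction step now returns the prompt token count merged with the underlying generation result. Roughly the following shape; every key except num_input_tokens comes from the wrapped model's generate_async and is an assumption here:

# Hypothetical output_dict shape.
output_dict = {
    "num_input_tokens": 1324,      # added by this change via get_token_count
    "generation": "...selected or synthesized solution...",
    "num_generated_tokens": 987,   # produced by the wrapped model
}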

def _extract_selected_solution(self, generation: str, max_idx: int) -> Optional[int]:
"""Extract the selected solutions index from the GenSelect generation."""
@@ -298,46 +356,6 @@ async def _run_gensynthesis(
"parallel_thinking_result": gensynthesis_result,
}

async def _get_multiple_solutions(
self, prompt: Union[str, List], local_random: random.Random, **kwargs
) -> tuple[List[Dict], int]:
"""Return multiple solutions for the input prompt."""
if self.cfg.generation_dir is not None:
# Already have the solutions in the input directory
# Hashing the prompt to get the key for the solutions
solutions = self.prompt_to_solutions_dict[self.hash_prompt(prompt)]
local_random.shuffle(solutions)
# After shuffling, only take the first window_size solutions
solutions = solutions[: self.cfg.window_size]
else:
# Generate the solutions first
solutions = await self.generate_solutions(prompt, local_random, **kwargs)

# Filter out incomplete solutions if specified
if self.cfg.filter_incomplete_solutions:
# Remove unfinished solutions
filtered_solutions = []
for solution in solutions:
# Check if thinking_begin is in the solution and thinking_end is not in the solution
if (
self.cfg.thinking_begin in solution[self.cfg.solution_key]
and self.cfg.thinking_end not in solution[self.cfg.solution_key]
):
continue
else:
filtered_solutions.append(solution)

if len(filtered_solutions) < len(solutions):
LOG.info(f"Filtered out {len(solutions) - len(filtered_solutions)} incomplete solutions")

solutions = filtered_solutions

total_num_generated_tokens = 0
for solution in solutions:
total_num_generated_tokens += solution["output_dict"].get("num_generated_tokens", 0)

return solutions, total_num_generated_tokens

async def generate_async(self, prompt: Union[str, List], **kwargs):
"""Generate a single solution using parallel thinking."""

@@ -349,9 +367,8 @@ async def generate_async(self, prompt: Union[str, List], **kwargs):
result["total_solution_generated_tokens"] = total_num_generated_tokens

if not solutions:
return {
output_dict = {
self.cfg.solution_key: "",
"generation": "", # Required by inference/generate.py
"solution_list": [],
f"{self.cfg.mode}_comparison": "",
f"{self.cfg.mode}_num_generated_tokens": 0,
@@ -360,6 +377,12 @@ async def generate_async(self, prompt: Union[str, List], **kwargs):
"num_best_solution_generated_tokens": 0,
}

# Required by inference/generate.py
output_dict["generation"] = ""
if self.cfg.count_prompt_tokens:
# The input token count is not meaningful when there are no solutions
output_dict["num_input_tokens"] = None

# Step 2: Run GenSelect/GenSynthesis
if self.cfg.mode == "genselect":
output_dict = await self._run_genselect(prompt, solutions, local_random, **kwargs)
@@ -382,6 +405,8 @@ async def generate_async(self, prompt: Union[str, List], **kwargs):
# TODO: Decide which generated-token count to report - the total or just the best solution's?
# Current implementation returns the total number of generated tokens
result["num_generated_tokens"] = total_gen_tokens
if self.cfg.count_prompt_tokens:
result["num_input_tokens"] = parallel_thinking_result["num_input_tokens"]

result[self.cfg.solution_key] = output_dict[self.cfg.solution_key]
result["solution_list"] = [solution[self.cfg.solution_key] for solution in solutions]