diff --git a/nemo_skills/inference/generate.py b/nemo_skills/inference/generate.py index 8498287731..91cb2a08c8 100644 --- a/nemo_skills/inference/generate.py +++ b/nemo_skills/inference/generate.py @@ -397,6 +397,12 @@ def setup_llm(self): # We don't want to override these key variables which overlap with self.cfg inference_override_config = { "remove_thinking": self.cfg.parallel_thinking.remove_thinking, # Removing thinking from solutions is important for parallel_thinking. We don't want to override this with the main generation config + "endpoint_type": self.cfg.parallel_thinking.endpoint_type, + # The following are specific to parallel thinking and we want to defend against any future key overlaps with the main generation config + "mode": self.cfg.parallel_thinking.mode, + "window_size": self.cfg.parallel_thinking.window_size, + "solution_key": self.cfg.parallel_thinking.solution_key, + "filter_incomplete_solutions": self.cfg.parallel_thinking.filter_incomplete_solutions, } llm = get_parallel_thinking_model( diff --git a/nemo_skills/inference/model/parallel_thinking.py b/nemo_skills/inference/model/parallel_thinking.py index 959f22b05a..a5eb191bbb 100644 --- a/nemo_skills/inference/model/parallel_thinking.py +++ b/nemo_skills/inference/model/parallel_thinking.py @@ -24,7 +24,9 @@ from dataclasses import field from typing import Dict, List, Optional, Union -from nemo_skills.prompt.utils import get_prompt +from transformers import AutoTokenizer + +from nemo_skills.prompt.utils import get_prompt, get_token_count from nemo_skills.utils import get_logger_name, nested_dataclass, remove_thinking from .base import BaseModel, EndpointType @@ -52,11 +54,14 @@ class ParallelThinkingConfig: remove_thinking: bool = True # Remove thinking tokens from the solution key thinking_begin: str = "" thinking_end: str = "" - endpoint_type: EndpointType = EndpointType.chat + endpoint_type: EndpointType = EndpointType.text tokenizer: str | None = None chat_template_kwargs: dict | None = None # extra parameters to pass to the tokenizer's apply_chat_template method start_assistant_response_key: str | None = None # whether to start assistant response with this key + # Count the number of tokens in the prompt + count_prompt_tokens: bool = False + # GenSelect vs GenSynthesis mode: str | None = None # genselect or gensynthesis @@ -98,6 +103,11 @@ def __init__(self, model: BaseModel, tokenizer: str | None, orig_prompt_filler, else: raise ValueError(f"Invalid parallel thinking mode: {self.cfg.mode}") + if self.cfg.count_prompt_tokens: + self.hf_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer) + if self.hf_tokenizer is None: + raise ValueError("Tokenizer could not be initialized. Needed for counting prompt tokens.") + # Initialize the solutions if input_dir is provided if self.cfg.generation_dir is not None: LOG.info("Loading solutions from %s", self.cfg.generation_dir) @@ -188,6 +198,46 @@ def _load_solutions(self, input_dir: str) -> Dict[str, List[Dict]]: return prompt_to_solutions_dict + async def _get_multiple_solutions( + self, prompt: Union[str, List], local_random: random.Random, **kwargs + ) -> tuple[List[Dict], int]: + """Return multiple solutions for the input prompt.""" + if self.cfg.generation_dir is not None: + # Already have the solutions in the input directory + # Hashing the prompt to get the key for the solutions + solutions = self.prompt_to_solutions_dict[self.hash_prompt(prompt)] + local_random.shuffle(solutions) + # After shuffling, only take the first window_size solutions + solutions = solutions[: self.cfg.window_size] + else: + # Generate the solutions first + solutions = await self.generate_solutions(prompt, local_random, **kwargs) + + # Filter out incomplete solutions if specified + if self.cfg.filter_incomplete_solutions: + # Remove unfinished solutions + filtered_solutions = [] + for solution in solutions: + # Check if thinking_begin is in the solution and thinking_end is not in the solution + if ( + self.cfg.thinking_begin in solution[self.cfg.solution_key] + and self.cfg.thinking_end not in solution[self.cfg.solution_key] + ): + continue + else: + filtered_solutions.append(solution) + + if len(filtered_solutions) < len(solutions): + LOG.info(f"Filtered out {len(solutions) - len(filtered_solutions)} incomplete solutions") + + solutions = filtered_solutions + + total_num_generated_tokens = 0 + for solution in solutions: + total_num_generated_tokens += solution["output_dict"].get("num_generated_tokens", 0) + + return solutions, total_num_generated_tokens + async def _generate_parallel_thinking_contraction( self, prompt: Union[str, List], solutions: List[Dict], **kwargs ) -> Dict: @@ -214,16 +264,24 @@ async def _generate_parallel_thinking_contraction( chat_template_kwargs=self.cfg.chat_template_kwargs, ) - for duplicate_key in ["temperature", "tokens_to_generate", "prompt"]: + output_dict = {} + if self.cfg.count_prompt_tokens: + num_input_tokens = get_token_count(tokenizer=self.hf_tokenizer, messages=parallel_thinking_prompt) + output_dict["num_input_tokens"] = num_input_tokens + + for duplicate_key in ["temperature", "tokens_to_generate", "prompt", "endpoint_type"]: kwargs.pop(duplicate_key, None) - return await self.model.generate_async( - prompt=parallel_thinking_prompt, - # Overriding the tokens_to_generate, temperature - tokens_to_generate=self.cfg.tokens_to_generate, - temperature=self.cfg.temperature, - **kwargs, + output_dict.update( + await self.model.generate_async( + prompt=parallel_thinking_prompt, + # Overriding the tokens_to_generate, temperature + tokens_to_generate=self.cfg.tokens_to_generate, + temperature=self.cfg.temperature, + **kwargs, + ) ) + return output_dict def _extract_selected_solution(self, generation: str, max_idx: int) -> Optional[int]: """Extract the selected solutions index from the GenSelect generation.""" @@ -298,46 +356,6 @@ async def _run_gensynthesis( "parallel_thinking_result": gensynthesis_result, } - async def _get_multiple_solutions( - self, prompt: Union[str, List], local_random: random.Random, **kwargs - ) -> tuple[List[Dict], int]: - """Return multiple solutions for the input prompt.""" - if self.cfg.generation_dir is not None: - # Already have the solutions in the input directory - # Hashing the prompt to get the key for the solutions - solutions = self.prompt_to_solutions_dict[self.hash_prompt(prompt)] - local_random.shuffle(solutions) - # After shuffling, only take the first window_size solutions - solutions = solutions[: self.cfg.window_size] - else: - # Generate the solutions first - solutions = await self.generate_solutions(prompt, local_random, **kwargs) - - # Filter out incomplete solutions if specified - if self.cfg.filter_incomplete_solutions: - # Remove unfinished solutions - filtered_solutions = [] - for solution in solutions: - # Check if thinking_begin is in the solution and thinking_end is not in the solution - if ( - self.cfg.thinking_begin in solution[self.cfg.solution_key] - and self.cfg.thinking_end not in solution[self.cfg.solution_key] - ): - continue - else: - filtered_solutions.append(solution) - - if len(filtered_solutions) < len(solutions): - LOG.info(f"Filtered out {len(solutions) - len(filtered_solutions)} incomplete solutions") - - solutions = filtered_solutions - - total_num_generated_tokens = 0 - for solution in solutions: - total_num_generated_tokens += solution["output_dict"].get("num_generated_tokens", 0) - - return solutions, total_num_generated_tokens - async def generate_async(self, prompt: Union[str, List], **kwargs): """Generate a single solution using parallel thinking.""" @@ -349,9 +367,8 @@ async def generate_async(self, prompt: Union[str, List], **kwargs): result["total_solution_generated_tokens"] = total_num_generated_tokens if not solutions: - return { + output_dict = { self.cfg.solution_key: "", - "generation": "", # Required by inference/generate.py "solution_list": [], f"{self.cfg.mode}_comparison": "", f"{self.cfg.mode}_num_generated_tokens": 0, @@ -360,6 +377,12 @@ async def generate_async(self, prompt: Union[str, List], **kwargs): "num_best_solution_generated_tokens": 0, } + # Required by inference/generate.py + output_dict["generation"] = "" + if self.cfg.count_prompt_tokens: + # The input doesn't make sense for such cases where there are no solutions + output_dict["num_input_tokens"] = None + # Step 2: Run GenSelect/GenSynthesis if self.cfg.mode == "genselect": output_dict = await self._run_genselect(prompt, solutions, local_random, **kwargs) @@ -382,6 +405,8 @@ async def generate_async(self, prompt: Union[str, List], **kwargs): # TODO: Decide what count of generated tokens do we want to report - the total or the best solution? # Current implementation returns the total number of generated tokens result["num_generated_tokens"] = total_gen_tokens + if self.cfg.count_prompt_tokens: + result["num_input_tokens"] = parallel_thinking_result["num_input_tokens"] result[self.cfg.solution_key] = output_dict[self.cfg.solution_key] result["solution_list"] = [solution[self.cfg.solution_key] for solution in solutions]