diff --git a/nemo_skills/inference/generate.py b/nemo_skills/inference/generate.py
index 1cc75fc182..fa60f9ccfa 100644
--- a/nemo_skills/inference/generate.py
+++ b/nemo_skills/inference/generate.py
@@ -378,6 +378,7 @@ def setup_llm(self):
             orig_prompt_filler=self.fill_prompt,  # Needed for prompt filling
             parallel_thinking=self.cfg.parallel_thinking,
             main_config=self.cfg,
+            tokenizer=self.tokenizer,
             inference_override_config=inference_override_config,
         )
 
diff --git a/nemo_skills/inference/model/__init__.py b/nemo_skills/inference/model/__init__.py
index 6ecfaa6e25..4851d31f76 100644
--- a/nemo_skills/inference/model/__init__.py
+++ b/nemo_skills/inference/model/__init__.py
@@ -72,6 +72,7 @@ def get_parallel_thinking_model(
     model,
     orig_prompt_filler,
     parallel_thinking: ParallelThinkingConfig = None,
+    tokenizer=None,
     main_config=None,
     inference_override_config=None,
 ):
@@ -89,7 +90,9 @@ def get_parallel_thinking_model(
 
     parallel_thinking_config = ParallelThinkingConfig(**filtered_config)
 
-    return ParallelThinkingTask(model=model, orig_prompt_filler=orig_prompt_filler, cfg=parallel_thinking_config)
+    return ParallelThinkingTask(
+        model=model, tokenizer=tokenizer, orig_prompt_filler=orig_prompt_filler, cfg=parallel_thinking_config
+    )
 
 
 def get_tool_calling_model(
diff --git a/nemo_skills/inference/model/parallel_thinking.py b/nemo_skills/inference/model/parallel_thinking.py
index f75da72830..6a63713444 100644
--- a/nemo_skills/inference/model/parallel_thinking.py
+++ b/nemo_skills/inference/model/parallel_thinking.py
@@ -55,6 +55,7 @@ class ParallelThinkingConfig:
     use_completions_api: bool = False
     tokenizer: str | None = None
     chat_template_kwargs: dict | None = None  # extra parameters to pass to the tokenizer's apply_chat_template method
+    start_assistant_response_key: str | None = None  # if set, start the assistant response with this key
 
     # GenSelect vs GenSynthesis
     mode: str | None = None  # genselect or gensynthesis
@@ -78,24 +79,21 @@ class ParallelThinkingTask:
     to choose the best one or synthesize a new solution.
""" - def __init__(self, model: BaseModel, orig_prompt_filler, cfg: ParallelThinkingConfig): + def __init__(self, model: BaseModel, tokenizer: str | None, orig_prompt_filler, cfg: ParallelThinkingConfig): self.model = model self.orig_prompt_filler = orig_prompt_filler self.cfg = cfg - if self.cfg.use_completions_api: - tokenizer = self.cfg.tokenizer or self.model.model_name_or_path - else: - tokenizer = None + self.tokenizer = tokenizer # Load GenSelect/GenSynthesis prompt if self.cfg.mode == "genselect": self.parallel_thinking_prompt = get_prompt( - prompt_config=self.cfg.genselect.prompt_config, tokenizer=tokenizer + prompt_config=self.cfg.genselect.prompt_config, tokenizer=self.tokenizer ) elif self.cfg.mode == "gensynthesis": self.parallel_thinking_prompt = get_prompt( - prompt_config=self.cfg.gensynthesis.prompt_config, tokenizer=tokenizer + prompt_config=self.cfg.gensynthesis.prompt_config, tokenizer=self.tokenizer ) else: raise ValueError(f"Invalid parallel thinking mode: {self.cfg.mode}") @@ -190,9 +188,6 @@ def _load_solutions(self, input_dir: str) -> Dict[str, List[Dict]]: return prompt_to_solutions_dict - def _format_solutions_for_parallel_thinking(self, solutions: List[Dict]) -> str: - """Format solutions for parallel thinking prompt.""" - async def _generate_parallel_thinking_contraction( self, prompt: Union[str, List], solutions: List[Dict], **kwargs ) -> Dict: @@ -213,14 +208,21 @@ async def _generate_parallel_thinking_contraction( "max_idx": max_idx, } - parallel_thinking_prompt = self.parallel_thinking_prompt.fill(parallel_thinking_input) + parallel_thinking_prompt = self.parallel_thinking_prompt.fill( + parallel_thinking_input, + start_assistant_response_key=self.cfg.start_assistant_response_key, + chat_template_kwargs=self.cfg.chat_template_kwargs, + ) + + for duplicate_key in ["temperature", "tokens_to_generate", "prompt"]: + kwargs.pop(duplicate_key, None) return await self.model.generate_async( - **kwargs, prompt=parallel_thinking_prompt, # Overriding the tokens_to_generate, temperature tokens_to_generate=self.cfg.tokens_to_generate, temperature=self.cfg.temperature, + **kwargs, ) def _extract_selected_solution(self, generation: str, max_idx: int) -> Optional[int]: @@ -349,6 +351,7 @@ async def generate_async(self, prompt: Union[str, List], **kwargs): if not solutions: return { self.cfg.solution_key: "", + "generation": "", # Required by inference/generate.py "solution_list": [], f"{self.cfg.mode}_comparison": "", f"{self.cfg.mode}_num_generated_tokens": 0, @@ -359,7 +362,7 @@ async def generate_async(self, prompt: Union[str, List], **kwargs): # Step 2: Run GenSelect/GenSynthesis if self.cfg.mode == "genselect": - output_dict = await self._run_genselect(prompt, solutions, local_random) + output_dict = await self._run_genselect(prompt, solutions, local_random, **kwargs) parallel_thinking_result = output_dict["parallel_thinking_result"] result["genselect_comparison"] = parallel_thinking_result["generation"] result["genselect_selection_successful"] = parallel_thinking_result["selection_successful"]