Merged

nemo_skills/inference/generate.py (1 addition, 0 deletions)

@@ -378,6 +378,7 @@ def setup_llm(self):
                 orig_prompt_filler=self.fill_prompt,  # Needed for prompt filling
                 parallel_thinking=self.cfg.parallel_thinking,
                 main_config=self.cfg,
+                tokenizer=self.tokenizer,
                 inference_override_config=inference_override_config,
             )

nemo_skills/inference/model/__init__.py (4 additions, 1 deletion)

@@ -72,6 +72,7 @@ def get_parallel_thinking_model(
     model,
     orig_prompt_filler,
     parallel_thinking: ParallelThinkingConfig = None,
+    tokenizer=None,
     main_config=None,
     inference_override_config=None,
 ):
@@ -89,7 +90,9 @@

     parallel_thinking_config = ParallelThinkingConfig(**filtered_config)

-    return ParallelThinkingTask(model=model, orig_prompt_filler=orig_prompt_filler, cfg=parallel_thinking_config)
+    return ParallelThinkingTask(
+        model=model, tokenizer=tokenizer, orig_prompt_filler=orig_prompt_filler, cfg=parallel_thinking_config
+    )


 def get_tool_calling_model(
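
Read together with the generate.py hunk above, these changes thread a single tokenizer from the caller down to the task instead of re-deriving it inside. A condensed sketch of the resulting flow; it is abridged, not verbatim repo code:

    # The same tokenizer travels caller -> factory -> task.
    llm = get_parallel_thinking_model(
        model=llm,
        orig_prompt_filler=self.fill_prompt,
        parallel_thinking=self.cfg.parallel_thinking,
        main_config=self.cfg,
        tokenizer=self.tokenizer,  # resolved once by generate.py
        inference_override_config=inference_override_config,
    )
    # ...the factory forwards it to ParallelThinkingTask(tokenizer=tokenizer, ...),
    # which stores it as self.tokenizer and hands it to get_prompt().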

nemo_skills/inference/model/parallel_thinking.py (16 additions, 13 deletions)

@@ -55,6 +55,7 @@ class ParallelThinkingConfig:
     use_completions_api: bool = False
     tokenizer: str | None = None
     chat_template_kwargs: dict | None = None  # extra parameters to pass to the tokenizer's apply_chat_template method
+    start_assistant_response_key: str | None = None  # if set, start the assistant response with the value of this key

     # GenSelect vs GenSynthesis
     mode: str | None = None  # genselect or gensynthesis
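
A hedged usage sketch of the new field: the field names come from the dataclass above, but every value is invented for illustration, and other settings (e.g. the genselect prompt config) are assumed to keep their defaults:

    # Illustrative only: values are made up.
    cfg = ParallelThinkingConfig(
        mode="genselect",
        use_completions_api=True,
        tokenizer="meta-llama/Llama-3.1-8B-Instruct",          # example value
        chat_template_kwargs={"add_generation_prompt": True},  # example value
        start_assistant_response_key="first_solution",         # example value
    )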

@@ -78,24 +79,21 @@ class ParallelThinkingTask:
     to choose the best one or synthesize a new solution.
     """

-    def __init__(self, model: BaseModel, orig_prompt_filler, cfg: ParallelThinkingConfig):
+    def __init__(self, model: BaseModel, tokenizer: str | None, orig_prompt_filler, cfg: ParallelThinkingConfig):
         self.model = model
         self.orig_prompt_filler = orig_prompt_filler
         self.cfg = cfg

-        if self.cfg.use_completions_api:
-            tokenizer = self.cfg.tokenizer or self.model.model_name_or_path
-        else:
-            tokenizer = None
+        self.tokenizer = tokenizer

         # Load GenSelect/GenSynthesis prompt
         if self.cfg.mode == "genselect":
             self.parallel_thinking_prompt = get_prompt(
-                prompt_config=self.cfg.genselect.prompt_config, tokenizer=tokenizer
+                prompt_config=self.cfg.genselect.prompt_config, tokenizer=self.tokenizer
             )
         elif self.cfg.mode == "gensynthesis":
             self.parallel_thinking_prompt = get_prompt(
-                prompt_config=self.cfg.gensynthesis.prompt_config, tokenizer=tokenizer
+                prompt_config=self.cfg.gensynthesis.prompt_config, tokenizer=self.tokenizer
             )
         else:
             raise ValueError(f"Invalid parallel thinking mode: {self.cfg.mode}")
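
With this, deciding which tokenizer to use is no longer __init__'s job. A caller that wants the pre-PR behavior for the completions API has to reproduce the deleted branch itself; a minimal sketch, where the helper name is hypothetical and the body mirrors the removed lines:

    # Hypothetical caller-side helper; logic copied from the deleted branch.
    def resolve_tokenizer(cfg, model):
        if cfg.use_completions_api:
            # the completions API renders chat templates client-side, so it needs a tokenizer
            return cfg.tokenizer or model.model_name_or_path
        return None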

@@ -190,9 +188,6 @@ def _load_solutions(self, input_dir: str) -> Dict[str, List[Dict]]:

         return prompt_to_solutions_dict

-    def _format_solutions_for_parallel_thinking(self, solutions: List[Dict]) -> str:
-        """Format solutions for parallel thinking prompt."""
-
     async def _generate_parallel_thinking_contraction(
         self, prompt: Union[str, List], solutions: List[Dict], **kwargs
     ) -> Dict:
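
The deleted method was an empty stub: a docstring is a valid function body, so despite the -> str annotation it always returned None. A standalone demonstration of the trap such stubs create:

    from typing import Dict, List

    def _format_solutions_for_parallel_thinking(solutions: List[Dict]) -> str:
        """Format solutions for parallel thinking prompt."""

    print(_format_solutions_for_parallel_thinking([]))  # prints None, not a str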

@@ -213,14 +208,21 @@ async def _generate_parallel_thinking_contraction(
             "max_idx": max_idx,
         }

-        parallel_thinking_prompt = self.parallel_thinking_prompt.fill(parallel_thinking_input)
+        parallel_thinking_prompt = self.parallel_thinking_prompt.fill(
+            parallel_thinking_input,
+            start_assistant_response_key=self.cfg.start_assistant_response_key,
+            chat_template_kwargs=self.cfg.chat_template_kwargs,
+        )
+
+        for duplicate_key in ["temperature", "tokens_to_generate", "prompt"]:
+            kwargs.pop(duplicate_key, None)

         return await self.model.generate_async(
+            **kwargs,
             prompt=parallel_thinking_prompt,
             # Overriding the tokens_to_generate, temperature
             tokens_to_generate=self.cfg.tokens_to_generate,
             temperature=self.cfg.temperature,
-            **kwargs,
         )

     def _extract_selected_solution(self, generation: str, max_idx: int) -> Optional[int]:
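
The pop loop exists because Python rejects a keyword supplied both explicitly and via **kwargs, regardless of where the ** expansion sits in the call. A standalone reproduction of the failure and the fix:

    def generate(prompt, temperature=0.0, tokens_to_generate=None, **rest):
        return prompt, temperature, tokens_to_generate

    kwargs = {"temperature": 1.0, "prompt": "hi", "top_p": 0.9}

    # generate(**kwargs, temperature=0.6)
    # -> TypeError: generate() got multiple values for keyword argument 'temperature'

    for duplicate_key in ["temperature", "tokens_to_generate", "prompt"]:
        kwargs.pop(duplicate_key, None)

    print(generate(prompt="question", temperature=0.6, **kwargs))  # works now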

@@ -349,6 +351,7 @@ async def generate_async(self, prompt: Union[str, List], **kwargs):
         if not solutions:
             return {
                 self.cfg.solution_key: "",
+                "generation": "",  # Required by inference/generate.py
                 "solution_list": [],
                 f"{self.cfg.mode}_comparison": "",
                 f"{self.cfg.mode}_num_generated_tokens": 0,

@@ -359,7 +362,7 @@

         # Step 2: Run GenSelect/GenSynthesis
         if self.cfg.mode == "genselect":
-            output_dict = await self._run_genselect(prompt, solutions, local_random)
+            output_dict = await self._run_genselect(prompt, solutions, local_random, **kwargs)
             parallel_thinking_result = output_dict["parallel_thinking_result"]
             result["genselect_comparison"] = parallel_thinking_result["generation"]
             result["genselect_selection_successful"] = parallel_thinking_result["selection_successful"]
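
Forwarding **kwargs means per-request generation options now reach the contraction call in genselect mode as well. _run_genselect's definition is outside this diff, so the sketch below is an assumption about how the forwarding continues, not repo code:

    # Assumed callee shape: it is expected to pass **kwargs on to
    # _generate_parallel_thinking_contraction, which pops the duplicated
    # keys before calling generate_async (see the hunk above).
    async def _run_genselect(self, prompt, solutions, local_random, **kwargs):
        ...
        return await self._generate_parallel_thinking_contraction(prompt, solutions, **kwargs)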