Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions nemo_skills/code_execution/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ async def execute_code(
traceback_verbosity='plain', # could be plain, context, verbose, or minimal
) -> Tuple[Dict, str]:
traceback_verbosity = traceback_verbosity.capitalize()

if session_id is None and language == "ipython": # creating a new session with empty state
session_id = uuid.uuid4()
self.sessions[session_id] = []
Expand Down Expand Up @@ -267,7 +266,7 @@ def strip_ansi_codes(text):
output = {"process_status": "timeout", "stdout": "", "stderr": "Timed out\n"}
# removing last state to not re-execute code with errors
if session_id is not None:
if output['stderr'] or 'Traceback (most recent call last)' in output['stdout']:
if output['process_status'] != "completed":
self.sessions[session_id] = self.sessions[session_id][:-1]
return output, session_id

Expand Down
51 changes: 29 additions & 22 deletions nemo_skills/inference/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import abc
import logging
import os

import httpx
import litellm

Expand All @@ -38,6 +39,7 @@ class BaseModel:
ssh_key_path: Optional[str] = None - Path to the ssh key for tunneling.
Can also be specified through NEMO_SKILLS_SSH_KEY_PATH env var.
"""

# Litellm provider name
MODEL_PROVIDER = "openai"

Expand Down Expand Up @@ -94,18 +96,17 @@ def __init__(
api_key=api_key,
base_url=base_url,
)
httpx_limits = httpx.Limits(
max_keepalive_connections=2048, max_connections=2048
)
httpx_limits = httpx.Limits(max_keepalive_connections=2048, max_connections=2048)
litellm.client_session = httpx.Client(limits=httpx_limits)
litellm.aclient_session = httpx.AsyncClient(limits=httpx_limits)


def __del__(self):
if self._tunnel:
self._tunnel.stop()

def _maybe_apply_stop_phrase_removal(self, result: dict, remove_stop_phrases: bool, stop_phrases: list[str] | None) -> None:
def _maybe_apply_stop_phrase_removal(
self, result: dict, remove_stop_phrases: bool, stop_phrases: list[str] | None
) -> None:
if remove_stop_phrases:
result['generation'] = trim_after_stop_phrases(result['generation'], stop_phrases)

Expand All @@ -129,7 +130,7 @@ async def generate_async(
random_seed: int = None,
stop_phrases: list[str] | None = None,
top_logprobs: int | None = None,
timeout: float | int | None = 10000, # None is 10min
timeout: float | int | None = 10000, # None is 10min
remove_stop_phrases: bool = True,
stream: bool = False,
reasoning_effort: str | None = None,
Expand Down Expand Up @@ -170,10 +171,10 @@ async def generate_async(
result = self._parse_completion_response(response, include_response=include_response, **kwargs)
else:
raise TypeError(f"Unsupported prompt type: {type(prompt)}")

self._maybe_apply_stop_phrase_removal(result, remove_stop_phrases, stop_phrases)
return result

def generate_sync(
self,
prompt: str | list,
Expand All @@ -186,7 +187,7 @@ def generate_sync(
random_seed: int = None,
stop_phrases: list[str] | None = None,
top_logprobs: int | None = None,
timeout: float | int | None = 10000, # None is 10min
timeout: float | int | None = 10000, # None is 10min
remove_stop_phrases: bool = True,
stream: bool = False,
reasoning_effort: str | None = None,
Expand All @@ -213,7 +214,7 @@ def generate_sync(
'tools': tools,
'extra_body': extra_body,
}

if isinstance(prompt, list):
request_params = self._build_chat_request_params(messages=prompt, stream=stream, **kwargs)
response = litellm.completion(**request_params, **self.litellm_kwargs)
Expand All @@ -224,6 +225,7 @@ def generate_sync(

elif isinstance(prompt, str):
request_params = self._build_completion_request_params(prompt=prompt, stream=stream, **kwargs)
request_params['skip_special_tokens'] = False
response = litellm.text_completion(**request_params, **self.litellm_kwargs)
if stream:
result = self._stream_completion_chunks_sync(response)
Expand All @@ -235,7 +237,9 @@ def generate_sync(
self._maybe_apply_stop_phrase_removal(result, remove_stop_phrases, stop_phrases)
return result

def _parse_completion_response(self, response: "openai.types.Completion", include_response: bool = False, **kwargs) -> dict:
def _parse_completion_response(
self, response: "openai.types.Completion", include_response: bool = False, **kwargs
) -> dict:
choice = response.choices[0]
output = choice.text
if output is None:
Expand Down Expand Up @@ -268,11 +272,11 @@ def _parse_chat_completion_response(self, response, include_response: bool = Fal
if output is None:
output = ""
result = {'generation': output, 'num_generated_tokens': response.usage.completion_tokens}

# Add reasoning_content if available
if hasattr(choice.message, 'reasoning_content') and choice.message.reasoning_content:
result['reasoning_content'] = choice.message.reasoning_content

if getattr(choice, 'logprobs', None) and choice.logprobs.content:
result['logprobs'] = [tok.logprob for tok in choice.logprobs.content]
result['tokens'] = [tok.token for tok in choice.logprobs.content]
Expand All @@ -295,20 +299,20 @@ def _process_completion_chunk(self, chunk, emitted_so_far: list):
"""Process a single completion chunk and return data to yield."""
cur_delta = chunk.choices[0].text
emitted_so_far.append(cur_delta)

results_to_yield = []
if cur_delta:
results_to_yield.append({"generation": cur_delta})

# vllm variant
stop_reason = getattr(chunk.choices[0], "stop_reason", None)
# sglang variant
matched_stop = getattr(chunk.choices[0], "matched_stop", None)

# vllm variant - emit stop_reason as is and finish
if stop_reason and isinstance(stop_reason, str):
results_to_yield.append({"generation": stop_reason})

# sglang variant - emit only not-yet-sent part of matched_stop
if matched_stop and isinstance(matched_stop, str):
remaining = matched_stop
Expand All @@ -322,26 +326,30 @@ def _process_completion_chunk(self, chunk, emitted_so_far: list):
break
if remaining:
results_to_yield.append({"generation": remaining})

return results_to_yield

def _process_chat_chunk(self, chunk):
"""Process a single chat chunk and return data to yield."""
if hasattr(chunk.choices[0], "delta"):
cur_delta = chunk.choices[0].delta.content
# Check for reasoning_content in delta
reasoning_delta = getattr(chunk.choices[0].delta, 'reasoning_content', None) if hasattr(chunk.choices[0].delta, 'reasoning_content') else None
reasoning_delta = (
getattr(chunk.choices[0].delta, 'reasoning_content', None)
if hasattr(chunk.choices[0].delta, 'reasoning_content')
else None
)
else:
cur_delta = chunk.choices[0].text
reasoning_delta = None

finish_reason = getattr(chunk.choices[0], "finish_reason", None)
result = {"generation": cur_delta}

# Add reasoning_content to result if available
if reasoning_delta:
result["reasoning_content"] = reasoning_delta

if finish_reason:
result["finish_reason"] = finish_reason
if not cur_delta:
Expand Down Expand Up @@ -378,4 +386,3 @@ async def _stream_chat_chunks_async(self, response):
results = self._process_chat_chunk(chunk)
for result in results:
yield result

16 changes: 8 additions & 8 deletions nemo_skills/inference/model/code_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ async def _generate_single(
if request['timeout'] <= 0:
break

output_dict = self.model.generate_sync(**request, remove_stop_phrases=False)
output_dict = await self.model.generate_async(**request, remove_stop_phrases=False)

output, num_generated_tokens = output_dict['generation'], output_dict.get('num_generated_tokens', 0)
# no need to do anything with this as the code below should just exit, so that's only for logging
Expand Down Expand Up @@ -175,7 +175,7 @@ async def _generate_single(
# .rfind(code_end, 0, -1) searches for the second-to-last occurrence of code_end and checks
# that the last code_begin is not closed to ensure that we are inside the code block
if output.endswith(code_end) and output.rfind(code_begin) > output.rfind(code_end, 0, -1):
code_execution_time_start, execution_dict = await self.execute_generated_code(
code_execution_time_start, execution_dict, session_id = await self.execute_generated_code(
prompt, code_begin, code_end, output, session_id
)
remaining_code_executions = None
Expand Down Expand Up @@ -225,7 +225,7 @@ async def execute_generated_code(self, input_prompt, code_begin, code_end, outpu
traceback_verbosity=self.config.sandbox_traceback_verbosity,
)

return code_execution_time_start, execution_dict
return code_execution_time_start, execution_dict, session_id

async def generate_async(
self,
Expand Down Expand Up @@ -256,7 +256,7 @@ async def generate_async(
"""
if top_logprobs is not None: # TODO: add this
raise NotImplementedError("top_logprobs is not supported yet.")

kwargs = {
'code_begin': code_begin,
'code_end': code_end,
Expand All @@ -276,15 +276,15 @@ async def generate_async(
"stream": stream,
"extra_body": extra_body,
}

request = {key: value for key, value in kwargs.items()}
request['prompt'] = prompt

output = await self._generate_single(**request)
self.model._maybe_apply_stop_phrase_removal(output, remove_stop_phrases, stop_phrases)

return output

async def _stream_single(
self,
prompt: str,
Expand Down