Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle streamed function calls #1118

Merged
merged 11 commits on Jan 8, 2024
34 changes: 27 additions & 7 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,22 +287,42 @@ def yes_or_no_filter(context, response):

def _completions_create(self, client, params):
completions = client.chat.completions if "messages" in params else client.completions
# If streaming is enabled, has messages, and does not have functions, then
# iterate over the chunks of the response
if params.get("stream", False) and "messages" in params and "functions" not in params:
# If streaming is enabled and has messages, then iterate over the chunks of the response.
if params.get("stream", False) and "messages" in params:
response_contents = [""] * params.get("n", 1)
finish_reasons = [""] * params.get("n", 1)
completion_tokens = 0

# Set the terminal text color to green
print("\033[32m", end="")

# Prepare for potential function call
full_function_call = None
# Send the chat completion request to OpenAI's API and process the response in chunks
for chunk in completions.create(**params):
if chunk.choices:
for choice in chunk.choices:
content = choice.delta.content
function_call_chunk = choice.delta.function_call
finish_reasons[choice.index] = choice.finish_reason

# Handle function call
if function_call_chunk:
if hasattr(function_call_chunk, "name") and function_call_chunk.name:
if full_function_call is None:
full_function_call = {"name": "", "arguments": ""}
full_function_call["name"] += function_call_chunk.name
completion_tokens += 1
bitnom marked this conversation as resolved.
Show resolved Hide resolved
if hasattr(function_call_chunk, "arguments") and function_call_chunk.arguments:
full_function_call["arguments"] += function_call_chunk.arguments
completion_tokens += 1
if choice.finish_reason == "function_call":
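# Nothing extra to do here: the accumulated name/arguments in full_function_call are passed as function_call when the response message is built below.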
pass
if not content:
continue
# End handle function call

# If content is present, print it to the terminal and update response variables
if content is not None:
print(content, end="", flush=True)
Expand Down Expand Up @@ -336,7 +356,7 @@ def _completions_create(self, client, params):
index=i,
finish_reason=finish_reasons[i],
message=ChatCompletionMessage(
role="assistant", content=response_contents[i], function_call=None
role="assistant", content=response_contents[i], function_call=full_function_call
),
logprobs=None,
)
Expand All @@ -346,17 +366,17 @@ def _completions_create(self, client, params):
index=i,
finish_reason=finish_reasons[i],
message=ChatCompletionMessage(
role="assistant", content=response_contents[i], function_call=None
role="assistant", content=response_contents[i], function_call=full_function_call
),
)

response.choices.append(choice)
else:
# If streaming is not enabled or using functions, send a regular chat completion request
# Functions are not supported, so ensure streaming is disabled
# If streaming is not enabled, send a regular chat completion request
params = params.copy()
params["stream"] = False
davorrunje marked this conversation as resolved.
Show resolved Hide resolved
response = completions.create(**params)

return response

def _update_usage_summary(self, response: ChatCompletion | Completion, use_cache: bool) -> None:
Expand Down
Loading