Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from ._terminations import MaxMessageTermination, StopMessageTermination, TextMentionTermination, TokenUsageTermination

# Public API of the task sub-package: the termination conditions users compose.
__all__ = [
    "MaxMessageTermination",
    "TextMentionTermination",
    "StopMessageTermination",
    "TokenUsageTermination",
]
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,59 @@ async def __call__(self, messages: Sequence[ChatMessage]) -> StopMessage | None:

async def reset(self) -> None:
    # Clear the terminated flag so this condition instance can be reused
    # for a fresh conversation run.
    self._terminated = False


class TokenUsageTermination(TerminationCondition):
    """Terminate the conversation if a token usage limit is reached.

    Args:
        max_total_token: The maximum total number of tokens allowed in the conversation.
        max_prompt_token: The maximum number of prompt tokens allowed in the conversation.
        max_completion_token: The maximum number of completion tokens allowed in the conversation.

    Raises:
        ValueError: If none of max_total_token, max_prompt_token, or max_completion_token is provided.
    """

    def __init__(
        self,
        max_total_token: int | None = None,
        max_prompt_token: int | None = None,
        max_completion_token: int | None = None,
    ) -> None:
        # At least one cap must be configured, otherwise the condition can never fire.
        if all(cap is None for cap in (max_total_token, max_prompt_token, max_completion_token)):
            raise ValueError(
                "At least one of max_total_token, max_prompt_token, or max_completion_token must be provided"
            )
        self._max_total_token = max_total_token
        self._max_prompt_token = max_prompt_token
        self._max_completion_token = max_completion_token
        # Running counters, accumulated across every __call__ until reset().
        self._total_token_count = 0
        self._prompt_token_count = 0
        self._completion_token_count = 0

    @property
    def terminated(self) -> bool:
        # Fires as soon as any configured cap has been met or exceeded.
        caps_and_counts = (
            (self._max_total_token, self._total_token_count),
            (self._max_prompt_token, self._prompt_token_count),
            (self._max_completion_token, self._completion_token_count),
        )
        return any(cap is not None and count >= cap for cap, count in caps_and_counts)

    async def __call__(self, messages: Sequence[ChatMessage]) -> StopMessage | None:
        if self.terminated:
            raise TerminatedException("Termination condition has already been reached")
        # Fold each message's reported usage into the running counters;
        # messages without usage information are skipped.
        for message in messages:
            usage = message.model_usage
            if usage is None:
                continue
            self._prompt_token_count += usage.prompt_tokens
            self._completion_token_count += usage.completion_tokens
            self._total_token_count += usage.prompt_tokens + usage.completion_tokens
        if not self.terminated:
            return None
        content = f"Token usage limit reached, total token count: {self._total_token_count}, prompt token count: {self._prompt_token_count}, completion token count: {self._completion_token_count}."
        return StopMessage(content=content, source="TokenUsageTermination")

    async def reset(self) -> None:
        # Zero every counter; the configured caps are kept.
        self._total_token_count = self._prompt_token_count = self._completion_token_count = 0
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import pytest
from autogen_agentchat.messages import StopMessage, TextMessage
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination, TextMentionTermination
from autogen_agentchat.task import (
MaxMessageTermination,
StopMessageTermination,
TextMentionTermination,
TokenUsageTermination,
)
from autogen_core.components.models import RequestUsage


@pytest.mark.asyncio
Expand Down Expand Up @@ -51,6 +57,51 @@ async def test_mention_termination() -> None:
)


@pytest.mark.asyncio
async def test_token_usage_termination() -> None:
    cond = TokenUsageTermination(max_total_token=10)

    # An empty message batch never triggers termination.
    assert await cond([]) is None
    await cond.reset()

    # A single message that exceeds the total cap triggers termination.
    over_cap = TextMessage(
        content="Hello", source="user", model_usage=RequestUsage(prompt_tokens=10, completion_tokens=10)
    )
    assert await cond([over_cap]) is not None
    await cond.reset()

    # Usage well below the cap leaves the condition unfired.
    under_cap = [
        TextMessage(
            content="Hello", source="user", model_usage=RequestUsage(prompt_tokens=1, completion_tokens=1)
        ),
        TextMessage(
            content="World", source="agent", model_usage=RequestUsage(prompt_tokens=1, completion_tokens=1)
        ),
    ]
    assert await cond(under_cap) is None
    await cond.reset()

    # Hitting the cap exactly (5 prompt + 5 completion = 10 total) fires it.
    at_cap = [
        TextMessage(
            content="Hello", source="user", model_usage=RequestUsage(prompt_tokens=5, completion_tokens=0)
        ),
        TextMessage(
            content="stop", source="user", model_usage=RequestUsage(prompt_tokens=0, completion_tokens=5)
        ),
    ]
    assert await cond(at_cap) is not None


@pytest.mark.asyncio
async def test_and_termination() -> None:
termination = MaxMessageTermination(2) & TextMentionTermination("stop")
Expand Down
Loading