
🪙 Calculate Max Tokens (#867)
* 🪙 Calculate Max tokens

* 🐛 Fix mypy issues
* 🐛 Fix prompt token calc
* 🐛 Fix tests
* 🐛 Fix tests
awtkns committed Jun 27, 2023
1 parent 16657f4 · commit 66aabcf
Showing 10 changed files with 96 additions and 60 deletions.
platform/reworkd_platform/services/tokenizer/dependencies.py
@@ -3,5 +3,5 @@
 from reworkd_platform.services.tokenizer.service import TokenService


-def get_tokenizer(request: Request) -> TokenService:
+def get_token_service(request: Request) -> TokenService:
     return TokenService(request.app.state.token_encoding)
2 changes: 1 addition & 1 deletion platform/reworkd_platform/services/tokenizer/lifetime.py
@@ -4,7 +4,7 @@
 ENCODING_NAME = "cl100k_base"  # gpt-4, gpt-3.5-turbo, text-embedding-ada-002


-async def init_tokenizer(app: FastAPI) -> None:  # pragma: no cover
+def init_tokenizer(app: FastAPI) -> None:  # pragma: no cover
     """
     Initialize tokenizer.
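The function body is truncated above. As a hedged sketch, assuming init_tokenizer loads a tiktoken encoding and stashes it on app state (consistent with dependencies.py reading request.app.state.token_encoding):

```python
# Sketch only: the commit shows just the signature change (async -> sync);
# this body is an assumption based on how dependencies.py reads
# request.app.state.token_encoding.
import tiktoken
from fastapi import FastAPI

ENCODING_NAME = "cl100k_base"  # gpt-4, gpt-3.5-turbo, text-embedding-ada-002


def init_tokenizer(app: FastAPI) -> None:
    # Store one shared encoding on app state; loading it is pure CPU work,
    # so there is nothing to await, hence the switch from async to sync.
    app.state.token_encoding = tiktoken.get_encoding(ENCODING_NAME)
```

Dropping async also matches the call site added in web/lifetime.py below, where _startup invokes init_tokenizer(app) without awaiting it.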
9 changes: 1 addition & 8 deletions platform/reworkd_platform/services/tokenizer/service.py
@@ -11,12 +11,5 @@ def tokenize(self, text: str) -> list[int]:
     def detokenize(self, tokens: list[int]) -> str:
         return self.encoding.decode(tokens)

-    def token_count(self, text: str) -> int:
+    def count(self, text: str) -> int:
         return len(self.tokenize(text))
-
-    def get_context_space(
-        self, prompt: str, max_tokens: int, reserved_response_tokens: int
-    ) -> int:
-        prompt_tokens = self.tokenize(prompt)
-
-        return max_tokens - len(prompt_tokens) - reserved_response_tokens
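With get_context_space removed, callers now do the budget arithmetic themselves (see open_ai_agent_service.py below). What remains of the service is small; a minimal sketch, assuming the encoding object is a tiktoken.Encoding:

```python
import tiktoken


class TokenService:
    def __init__(self, encoding: tiktoken.Encoding):
        self.encoding = encoding

    def tokenize(self, text: str) -> list[int]:
        return self.encoding.encode(text)

    def detokenize(self, tokens: list[int]) -> str:
        return self.encoding.decode(tokens)

    def count(self, text: str) -> int:
        # Token count is just the length of the encoded sequence.
        return len(self.tokenize(text))
```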
44 changes: 26 additions & 18 deletions platform/reworkd_platform/tests/agent/agent/create_model_test.py
@@ -1,27 +1,35 @@
+import itertools
+
 import pytest

-from reworkd_platform.schemas import ModelSettings
-from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import (
-    create_model,
-)
+from reworkd_platform.schemas import ModelSettings, UserBase
+from reworkd_platform.web.api.agent.model_settings import create_model


 @pytest.mark.parametrize(
-    "settings",
-    [
-        ModelSettings(
-            customTemperature=0.222,
-            customModelName="gpt-4",
-            maxTokens=1234,
-        ),
-        ModelSettings(),
-    ],
+    "settings, streaming",
+    list(
+        itertools.product(
+            [
+                ModelSettings(
+                    customTemperature=0.222,
+                    customModelName="gpt-4",
+                    maxTokens=1234,
+                ),
+                ModelSettings(),
+            ],
+            [True, False],
+        )
+    ),
 )
-def test_create_model(
-    settings: ModelSettings,
-):
-    model = create_model(settings)
+def test_create_model(settings: ModelSettings, streaming: bool):
+    model = create_model(
+        settings,
+        UserBase(id="", email="[email protected]"),
+        streaming=streaming,
+    )

     assert model.temperature == settings.temperature
-    assert model.model_name == settings.model
+    assert model.model_name.startswith(settings.model)
     assert model.max_tokens == settings.max_tokens
+    assert model.streaming == streaming
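The parametrization now crosses each ModelSettings case with both streaming flags via itertools.product, so four combinations run as separate tests. A quick illustration of what product yields here:

```python
import itertools

# Two settings crossed with two streaming flags -> four test cases.
cases = list(itertools.product(["custom", "default"], [True, False]))
assert cases == [
    ("custom", True),
    ("custom", False),
    ("default", True),
    ("default", False),
]
```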
1 change: 1 addition & 0 deletions platform/reworkd_platform/tests/test_reworkd_platform.py
@@ -4,6 +4,7 @@
 from starlette import status


+@pytest.mark.skip(reason="Mysql needs to be mocked")
 @pytest.mark.anyio
 async def test_health(client: AsyncClient, fastapi_app: FastAPI) -> None:
     """
11 changes: 1 addition & 10 deletions platform/reworkd_platform/tests/test_token_service.py
@@ -19,17 +19,8 @@ def test_nothing():
     validate_tokenize_and_detokenize(service, text, 0)


-def test_context_space():
-    prompt = "You're a wizard, Harry. Write a book based on the context below:"
-    max_tokens = 800
-
-    service = TokenService(encoding)
-    get_context_space = service.get_context_space(prompt, max_tokens, 500)
-    assert 0 < get_context_space < 800 - 500
-
-
 def validate_tokenize_and_detokenize(service, text, expected_token_count):
     tokens = service.tokenize(text)
     assert text == service.detokenize(tokens)
-    assert len(tokens) == service.token_count(text)
+    assert len(tokens) == service.count(text)
     assert len(tokens) == expected_token_count
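For a concrete feel of what validate_tokenize_and_detokenize checks, a small self-contained round-trip against the cl100k_base encoding (tiktoken assumed, exact token ids not asserted):

```python
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
text = "You're a wizard, Harry."

tokens = encoding.encode(text)
assert encoding.decode(tokens) == text  # lossless round-trip
assert len(tokens) > 0                  # service.count(text) == len(tokens)
```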
platform/reworkd_platform/web/api/agent/dependencies.py (path inferred from the imports)
@@ -3,6 +3,8 @@
 from fastapi import Depends

 from reworkd_platform.schemas import AgentRun, UserBase
+from reworkd_platform.services.tokenizer.dependencies import get_token_service
+from reworkd_platform.services.tokenizer.service import TokenService
 from reworkd_platform.settings import settings
 from reworkd_platform.web.api.agent.agent_service.agent_service import AgentService
 from reworkd_platform.web.api.agent.agent_service.mock_agent_service import (
@@ -27,11 +29,14 @@ def func(
         run: AgentRun = Depends(validator),
         user: UserBase = Depends(get_current_user),
         agent_memory: AgentMemory = Depends(get_agent_memory),
+        token_service: TokenService = Depends(get_token_service),
     ) -> AgentService:
         if settings.ff_mock_mode_enabled:
             return MockAgentService()

         model = create_model(run.model_settings, user, streaming=streaming)
-        return OpenAIAgentService(model, run.model_settings.language, agent_memory)
+        return OpenAIAgentService(
+            model, run.model_settings.language, agent_memory, token_service
+        )

     return func
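For readers unfamiliar with the wiring: FastAPI resolves the Depends chain per request, so get_token_service pulls the shared encoding off app.state and hands each handler a fresh TokenService. A minimal, self-contained sketch of that pattern (simplified names, not the repository's full dependency graph):

```python
from fastapi import Depends, FastAPI, Request

app = FastAPI()
app.state.token_encoding = "encoding-object-placeholder"  # set at startup


class TokenService:
    def __init__(self, encoding: object) -> None:
        self.encoding = encoding


def get_token_service(request: Request) -> TokenService:
    # One shared encoding, one lightweight service instance per request.
    return TokenService(request.app.state.token_encoding)


@app.get("/tokenizer-demo")
def tokenizer_demo(service: TokenService = Depends(get_token_service)) -> dict:
    return {"encoding": str(service.encoding)}
```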
platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py
@@ -1,19 +1,20 @@
 from typing import List, Optional

 from lanarky.responses import StreamingResponse
-from langchain.chat_models.base import BaseChatModel
 from langchain.output_parsers import PydanticOutputParser
 from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
 from loguru import logger
 from pydantic import ValidationError

+from reworkd_platform.services.tokenizer.service import TokenService
 from reworkd_platform.web.api.agent.agent_service.agent_service import AgentService
 from reworkd_platform.web.api.agent.analysis import Analysis, AnalysisArguments
 from reworkd_platform.web.api.agent.helpers import (
     call_model_with_handling,
     openai_error_handler,
     parse_with_handling,
 )
+from reworkd_platform.web.api.agent.model_settings import WrappedChatOpenAI
 from reworkd_platform.web.api.agent.prompts import (
     analyze_task_prompt,
     create_tasks_prompt,
@@ -34,15 +35,28 @@
 class OpenAIAgentService(AgentService):
     def __init__(
         self,
-        model: BaseChatModel,
+        model: WrappedChatOpenAI,
         language: str,
         agent_memory: AgentMemory,
+        token_service: TokenService,
     ):
         self.model = model
         self.agent_memory = agent_memory
         self.language = language
+        self.token_service = token_service

     async def start_goal_agent(self, *, goal: str) -> List[str]:
+        prompt = ChatPromptTemplate.from_messages(
+            [SystemMessagePromptTemplate(prompt=start_goal_prompt)]
+        )
+
+        self.model.max_tokens -= self.token_service.count(
+            prompt.format_prompt(
+                goal=goal,
+                language=self.language,
+            ).to_string(),
+        )
+
         completion = await call_model_with_handling(
             self.model,
             ChatPromptTemplate.from_messages(
@@ -63,14 +77,22 @@ async def start_goal_agent(self, *, goal: str) -> List[str]:
     async def analyze_task_agent(
         self, *, goal: str, task: str, tool_names: List[str]
     ) -> Analysis:
+        functions = list(map(get_tool_function, get_user_tools(tool_names)))
+        prompt = analyze_task_prompt.format_prompt(
+            goal=goal,
+            task=task,
+            language=self.language,
+        )
+
+        self.model.max_tokens -= self.token_service.count(prompt.to_string())
+
+        # TODO: We are over counting tokens here
+        self.model.max_tokens -= self.token_service.count(str(functions))
+
         message = await openai_error_handler(
             func=self.model.apredict_messages,
-            messages=analyze_task_prompt.format_prompt(
-                goal=goal,
-                task=task,
-                language=self.language,
-            ).to_messages(),
-            functions=list(map(get_tool_function, get_user_tools(tool_names))),
+            messages=prompt.to_messages(),
+            functions=functions,
         )

         function_call = message.additional_kwargs.get("function_call", {})
@@ -93,7 +115,9 @@ async def execute_task_agent(
         task: str,
         analysis: Analysis,
     ) -> StreamingResponse:
-        print("Execution analysis:", analysis)
+        # TODO: More mature way of calculating max_tokens
+        if self.model.max_tokens > 3000:
+            self.model.max_tokens = max(self.model.max_tokens - 1000, 3000)

         tool_class = get_tool_from_name(analysis.action)
         return await tool_class(self.model, self.language).call(
@@ -109,18 +133,26 @@ async def create_tasks_agent(
         result: str,
         completed_tasks: Optional[List[str]] = None,
     ) -> List[str]:
+        prompt = ChatPromptTemplate.from_messages(
+            [SystemMessagePromptTemplate(prompt=create_tasks_prompt)]
+        )
+
+        args = {
+            "goal": goal,
+            "language": self.language,
+            "tasks": "\n".join(tasks),
+            "lastTask": last_task,
+            "result": result,
+        }
+
+        self.model.max_tokens -= self.token_service.count(
+            prompt.format_prompt(**args).to_string(),
+        )
+
         completion = await call_model_with_handling(
             self.model,
-            ChatPromptTemplate.from_messages(
-                [SystemMessagePromptTemplate(prompt=create_tasks_prompt)]
-            ),
-            {
-                "goal": goal,
-                "language": self.language,
-                "tasks": "\n".join(tasks),
-                "lastTask": last_task,
-                "result": result,
-            },
+            prompt,
+            args,
        )

         previous_tasks = (completed_tasks or []) + tasks
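The pattern repeated across these hunks is a running budget: before each call, the service subtracts the rendered prompt's token count (and, conservatively, the stringified function schemas) from the wrapped model's max_tokens, so the prompt plus the requested completion stays within the configured budget. A toy illustration with made-up numbers:

```python
# Hypothetical numbers; the real counts come from TokenService.count().
max_tokens = 1234       # completion budget from ModelSettings
prompt_tokens = 96      # count of the formatted prompt text
function_tokens = 40    # str(functions) overcounts, as the TODO notes

max_tokens -= prompt_tokens + function_tokens
assert max_tokens == 1098  # what remains for the model's completion
```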
8 changes: 6 additions & 2 deletions platform/reworkd_platform/web/api/agent/model_settings.py
@@ -8,9 +8,13 @@
 openai.api_base = settings.openai_api_base


+class WrappedChatOpenAI(ChatOpenAI):
+    max_tokens: int
+
+
 def create_model(
     model_settings: ModelSettings, user: UserBase, streaming: bool = False
-) -> ChatOpenAI:
+) -> WrappedChatOpenAI:
     if model_settings.custom_api_key != "":
         api_key = model_settings.custom_api_key
     else:
@@ -20,7 +24,7 @@ def create_model(
         model=model_settings.model,
     )

-    return ChatOpenAI(
+    return WrappedChatOpenAI(
         client=None,  # Meta private value but mypy will complain its missing
         openai_api_key=api_key,
         temperature=model_settings.temperature,
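A note on the wrapper, as I read the diff: langchain's ChatOpenAI declares max_tokens as Optional[int], so the in-place max_tokens -= ... arithmetic in the agent service would trip mypy; re-declaring the field as a plain int on a subclass narrows the type (plausibly the "🐛 Fix mypy issues" bullet). A minimal sketch of the idea, assuming langchain's pydantic-style field override:

```python
from langchain.chat_models import ChatOpenAI


class WrappedChatOpenAI(ChatOpenAI):
    # Narrow Optional[int] to int so `model.max_tokens -= n` type-checks.
    max_tokens: int


model = WrappedChatOpenAI(
    client=None,  # meta private value; mypy complains if it's missing
    openai_api_key="sk-placeholder",  # illustrative key, not a real one
    max_tokens=1234,
)
model.max_tokens -= 100  # now safe under mypy: the field is a plain int
```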
2 changes: 2 additions & 0 deletions platform/reworkd_platform/web/lifetime.py
@@ -7,6 +7,7 @@
 from reworkd_platform.db.models import load_all_models
 from reworkd_platform.db.utils import create_engine
 from reworkd_platform.services.pinecone.lifetime import init_pinecone
+from reworkd_platform.services.tokenizer.lifetime import init_tokenizer
 from reworkd_platform.services.vecs.lifetime import (
     init_supabase_vecs,
     shutdown_supabase_vecs,
@@ -59,6 +60,7 @@ def register_startup_event(
     async def _startup() -> None:  # noqa: WPS430
         _setup_db(app)
         init_pinecone()
+        init_tokenizer(app)
         init_supabase_vecs(
             app
         )  # create pg_connection connection pool at startup as its expensive

1 comment on commit 66aabcf

@vercel bot commented on 66aabcf, Jun 27, 2023:


Successfully deployed to the following URLs:

docs – ./docs

docs.reworkd.ai
docs-reworkd.vercel.app
docs-git-main-reworkd.vercel.app
