Skip to content

Commit

Permalink
smart prompt longest prefix matching to avoid sending the same text t…
Browse files Browse the repository at this point in the history
…hrough the NN again. speeds up prefill significantly
  • Loading branch information
AlexCheema committed Jul 31, 2024
1 parent 94ac946 commit 5c67e24
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 5 deletions.
23 changes: 21 additions & 2 deletions exo/api/chatgpt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from aiohttp import web
import aiohttp_cors
from exo import DEBUG, VERSION
from exo.helpers import terminal_link
from exo.helpers import terminal_link, PrefixDict
from exo.inference.shard import Shard
from exo.orchestration import Node

Expand Down Expand Up @@ -49,6 +49,7 @@
}



class Message:
def __init__(self, role: str, content: Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]):
self.role = role
Expand Down Expand Up @@ -234,13 +235,19 @@ def parse_chat_request(data: dict):
data.get("temperature", 0.0),
)

class PromptSession:
  """Record of a previously-seen prompt, kept so later requests can reuse its prefix."""

  def __init__(self, request_id: str, timestamp: int, prompt: str):
    # Full prompt text associated with this session.
    self.prompt = prompt
    # Seconds-since-epoch when the session was recorded (callers pass int(time.time())).
    self.timestamp = timestamp
    # Identifier of the request this prompt belongs to.
    self.request_id = request_id

class ChatGPTAPI:
def __init__(self, node: Node, inference_engine_classname: str, response_timeout_secs: int = 90):
self.node = node
self.inference_engine_classname = inference_engine_classname
self.response_timeout_secs = response_timeout_secs
self.app = web.Application(client_max_size=100 * 1024 * 1024) # 100MB to support image upload
self.prompts: PrefixDict[str, PromptSession] = PrefixDict()
self.prev_token_lens: Dict[str, int] = {}
self.stream_tasks: Dict[str, asyncio.Task] = {}
cors = aiohttp_cors.setup(self.app)
Expand Down Expand Up @@ -293,12 +300,24 @@ async def handle_post_chat_completions(self, request):
{"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
status=400,
)
request_id = str(uuid.uuid4())

tokenizer = await resolve_tokenizer(shard.model_id)
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

prompt, image_str = build_prompt(tokenizer, chat_request.messages)
request_id = None
match = self.prompts.find_longest_prefix(prompt)
if match:
if DEBUG >= 2:
print(f"Prompt for request starts with previous prompt {len(match[1].prompt)} of {len(prompt)}: {match[1].prompt}")
request_id = match[1].request_id
self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))
# remove the matching prefix from the prompt
prompt = prompt[len(match[1].prompt):]
else:
request_id = str(uuid.uuid4())
self.prompts.add(prompt, PromptSession(request_id=request_id, timestamp=int(time.time()), prompt=prompt))

callback_id = f"chatgpt-api-wait-response-{request_id}"
callback = self.node.on_token.register(callback_id)

Expand Down
25 changes: 22 additions & 3 deletions exo/helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import asyncio
from typing import Any, Callable, Coroutine, TypeVar, Optional, Dict, Generic, Tuple
from typing import Any, Callable, TypeVar, Optional, Dict, Generic, Tuple, List
from collections import defaultdict
import socket
import random
import platform
Expand Down Expand Up @@ -97,8 +98,6 @@ def terminal_link(uri, label=None):

T = TypeVar("T")
K = TypeVar("K")


class AsyncCallback(Generic[T]):
def __init__(self) -> None:
self.condition: asyncio.Condition = asyncio.Condition()
Expand Down Expand Up @@ -147,3 +146,23 @@ def trigger(self, name: K, *args: T) -> None:
def trigger_all(self, *args: T) -> None:
for callback in self.callbacks.values():
callback.set(*args)


K = TypeVar('K', bound=str)
V = TypeVar('V')
class PrefixDict(Generic[K, V]):
  """Dict-like store whose lookups find stored keys that are string prefixes of a query."""

  def __init__(self):
    # Plain key -> value storage; prefix matching is done by scanning keys on lookup.
    self.items: Dict[K, V] = {}

  def add(self, key: K, value: V) -> None:
    """Insert (or overwrite) the value stored under *key*."""
    self.items[key] = value

  def find_prefix(self, argument: str) -> List[Tuple[K, V]]:
    """Return all (key, value) pairs whose key is a prefix of *argument*."""
    matches: List[Tuple[K, V]] = []
    for key, value in self.items.items():
      if argument.startswith(key):
        matches.append((key, value))
    return matches

  def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
    """Return the matching (key, value) with the longest key, or None when nothing matches."""
    return max(self.find_prefix(argument), key=lambda kv: len(kv[0]), default=None)

0 comments on commit 5c67e24

Please sign in to comment.