Fixed bug: semaphore not releasing
ABR177 committed Aug 27, 2023
1 parent 41d666b commit 6ce743e
Showing 5 changed files with 68 additions and 77 deletions.
54 changes: 18 additions & 36 deletions llama_api/modules/llama_cpp.py
@@ -57,6 +57,14 @@ def llm_model(self) -> "LlamaCppModel":
assert self._llm_model is not None
return self._llm_model

@property
def eos_token(self) -> int:
assert self.client is not None, "Llama is not initialized"
try:
return self.client.token_eos()
except Exception:
return llama_cpp.llama_token_eos() # type: ignore

@classmethod
def from_pretrained(
cls, llm_model: "LlamaCppModel"
@@ -116,44 +124,17 @@ def generate_text(
assert client is not None, "Llama is not initialized"
self.llm_model.max_total_tokens = client.n_ctx()
assert client.ctx is not None, "Llama context is not initialized"
n_ctx = client.n_ctx()
tokens = (llama_cpp.llama_token * n_ctx)()
n_tokens = llama_cpp.llama_tokenize(
client.ctx,
b" " + prompt.encode("utf-8"),
tokens,
llama_cpp.c_int(n_ctx),
llama_cpp.c_bool(True),
)
if n_tokens < 0:
n_tokens = abs(n_tokens)
tokens = (llama_cpp.llama_token * n_tokens)()
n_tokens = llama_cpp.llama_tokenize(
client.ctx,
b" " + prompt.encode("utf-8"),
tokens,
llama_cpp.c_int(n_tokens),
llama_cpp.c_bool(True),
)
if n_tokens < 0:
raise RuntimeError(
f'Failed to tokenize: text="{prompt}" n_tokens={n_tokens}'
)
input_ids = array("i", tokens[:n_tokens]) # type: array[int]

input_ids = array(
"i",
client.tokenize(prompt.encode("utf-8"))
if prompt != ""
else [client.token_bos()],
) # type: array[int]
self.accept_settings(
prompt=prompt, prompt_tokens=len(input_ids), settings=settings
)
yield from self._generate_text(client, input_ids, settings)

@property
def eos_token(self) -> int:
assert self.client is not None, "Llama is not initialized"
try:
return self.client.token_eos()
except Exception:
return llama_cpp.llama_token_eos() # type: ignore

def _generate_text(
self,
client: llama_cpp.Llama,
@@ -223,9 +204,10 @@ def _generate_text(
),
):
# Check if the token is a stop token
if self.check_interruption(completion_status):
break
if token_id == eos_token_id:
if (
self.check_interruption(completion_status)
or token_id == eos_token_id
):
break

# Update the generated id
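The hunk above drops the raw llama_cpp.llama_tokenize ctypes call (and its retry when the returned token count is negative) in favor of the high-level Llama.tokenize helper, with a single BOS token as the fallback for an empty prompt. A minimal sketch of the new path, not part of the diff, assuming a loaded llama_cpp.Llama instance; the model path below is a placeholder:

from array import array

from llama_cpp import Llama

client = Llama(model_path="/path/to/model.bin")  # placeholder path
prompt = "Hello, world"

# Empty prompts fall back to a single BOS token so generation still has input.
input_ids = array(
    "i",
    client.tokenize(prompt.encode("utf-8"))
    if prompt != ""
    else [client.token_bos()],
)
print(len(input_ids))  # prompt token count, as passed to accept_settings()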
1 change: 0 additions & 1 deletion llama_api/schemas/models.py
@@ -62,7 +62,6 @@ class LlamaCppModel(BaseLLMModel):
},
)
use_mmap: bool = True # Whether to use memory mapping for the model.
streaming: bool = True # Whether to stream the results, token by token.
cache: bool = (
False # The size of the cache in bytes. Only used if cache is True.
)
84 changes: 47 additions & 37 deletions llama_api/server/routers/v1.py
@@ -2,8 +2,7 @@
Use same format as OpenAI API"""


from asyncio import CancelledError, Task, create_task, gather, sleep
from contextlib import asynccontextmanager
from asyncio import Task, create_task, gather, sleep
from dataclasses import dataclass, field
from functools import partial
from queue import Queue
@@ -20,14 +19,16 @@
Union,
)

from anyio import Semaphore, create_memory_object_stream
from anyio import (
Semaphore,
create_memory_object_stream,
get_cancelled_exc_class,
)
from fastapi import APIRouter, Request
from fastapi.concurrency import iterate_in_threadpool, run_in_threadpool
from orjson import dumps
from sse_starlette.sse import EventSourceResponse

from llama_api.shared.config import MainCliArgs

from ...schemas.api import (
ChatCompletion,
Completion,
@@ -38,6 +39,7 @@
ModelData,
ModelList,
)
from ...shared.config import MainCliArgs
from ...utils.concurrency import (
get_queue_and_event,
run_in_processpool_with_wix,
@@ -51,7 +53,6 @@
get_model_names,
)


logger = ApiLogger(__name__)
router = APIRouter(prefix="/v1", route_class=RouteErrorHandler)
max_workers = int(MainCliArgs.max_workers.value or 1)
@@ -93,11 +94,10 @@ def get_worker_rank(meta: WixMetadata, request_key: Optional[str]) -> int:
) # return the number of slots in use


@asynccontextmanager
async def get_wix_with_semaphore(
request: Request,
request_key: Optional[str] = None,
) -> AsyncGenerator[int, None]:
) -> int:
"""Get the worker index (wix) for the key and acquire the semaphore"""
global wix_metas

@@ -118,13 +118,13 @@ async def get_wix_with_semaphore(
wix_meta = wix_metas[choice(candidates)]

# Acquire the semaphore for the worker index (wix)
async with wix_meta.semaphore:
# If client is already gone, then ignore the request
if await request.is_disconnected():
return
# Reserve the worker, it is now processing the request
wix_meta.processed_key = request_key
yield wix_meta.wix
await wix_meta.semaphore.acquire()
# If client is already gone, then ignore the request
if await request.is_disconnected():
raise get_cancelled_exc_class()()
# Reserve the worker, it is now processing the request
wix_meta.processed_key = request_key
return wix_meta.wix


def validate_item_type(item: Any, type: Type[T]) -> T:
@@ -162,7 +162,8 @@ async def create_chat_completion_or_completion(
If the body is a chat completion, then create a chat completion.
If the body is a completion, then create a completion.
If streaming is enabled, then return an EventSourceResponse."""
async with get_wix_with_semaphore(request, body.model) as wix:
wix: int = await get_wix_with_semaphore(request, body.model)
try:
queue, interrupt_signal = get_queue_and_event()
task: "Task[None]" = create_task(
run_in_processpool_with_wix(
@@ -213,7 +214,7 @@ async def check_disconnection():
if task.done():
break
if await request.is_disconnected():
raise CancelledError("Request disconnected.")
raise get_cancelled_exc_class()()

try:
result, _ = await gather(
@@ -223,6 +224,9 @@
finally:
interrupt_signal.set()
task.cancel()
finally:
# Release the semaphore for the worker index (wix)
wix_metas[wix].semaphore.release()


@router.post("/chat/completions")
@@ -248,31 +252,34 @@ async def create_embedding(
if MainCliArgs.no_embed.value:
raise PermissionError("Embeddings endpoint is disabled")
assert body.model is not None, "Model is required"
async with get_wix_with_semaphore(request, body.model) as wix:
queue, interrupt_signal = get_queue_and_event()
task: Task["None"] = create_task(
run_in_processpool_with_wix(
partial(
generate_embeddings,
body=body,
queue=queue,
),
wix=wix,
)
wix: int = await get_wix_with_semaphore(request, body.model)
queue, interrupt_signal = get_queue_and_event()
task: Task["None"] = create_task(
run_in_processpool_with_wix(
partial(
generate_embeddings,
body=body,
queue=queue,
),
wix=wix,
)
try:
return validate_item_type(
await run_in_threadpool(queue.get),
type=dict, # type: ignore
)
finally:
interrupt_signal.set()
task.cancel()
)
try:
return validate_item_type(
await run_in_threadpool(queue.get),
type=dict, # type: ignore
)
finally:
# Release the semaphore for the worker index (wix)
interrupt_signal.set()
wix_metas[wix].semaphore.release()
task.cancel()


@router.get("/models")
async def get_models(request: Request) -> ModelList:
async with get_wix_with_semaphore(request) as wix:
wix: int = await get_wix_with_semaphore(request)
try:
return ModelList(
object="list",
data=[
Expand All @@ -288,3 +295,6 @@ async def get_models(request: Request) -> ModelList:
)
],
)
finally:
# Release the semaphore for the worker index (wix)
wix_metas[wix].semaphore.release()
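The hunks above are the heart of the fix named in the commit message: get_wix_with_semaphore used to be an @asynccontextmanager that yielded the worker index from inside `async with wix_meta.semaphore`, so releasing the slot depended on the generator-based dependency being torn down; it now acquires explicitly and returns the wix, and every endpoint releases the semaphore in a `finally` block. A minimal sketch of the resulting pattern, not part of the diff, using a single anyio.Semaphore in place of a WixMetadata entry:

from anyio import Semaphore, run


async def handle_request(semaphore: Semaphore) -> str:
    # Acquire explicitly; the matching release sits in `finally`, so the slot
    # is freed even if the work raises or the request is cancelled.
    await semaphore.acquire()
    try:
        return "result"  # stand-in for the completion / embedding work
    finally:
        semaphore.release()


async def main() -> None:
    semaphore = Semaphore(1)  # one slot per worker, as in WixMetadata
    print(await handle_request(semaphore))
    print(await handle_request(semaphore))  # slot was released, so this runs too


run(main)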
6 changes: 3 additions & 3 deletions llama_api/utils/errors.py
@@ -1,8 +1,8 @@
from asyncio import CancelledError
from functools import cached_property
from pathlib import Path
from re import Match, Pattern, compile
from typing import Callable, Coroutine, Dict, Optional, Tuple, Union
from anyio import get_cancelled_exc_class

from fastapi import Request, Response
from fastapi.responses import JSONResponse
@@ -221,9 +221,9 @@ async def custom_route_handler(self, request: Request) -> Response:
status_code=401,
)
return await super().get_route_handler()(request)
except CancelledError:
except get_cancelled_exc_class():
# Client has disconnected
return Response(status_code=499)
raise
except Exception as error:
json_body = await request.json()
try:
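The errors.py change mirrors the router: the route handler now catches the backend's cancellation exception via anyio's get_cancelled_exc_class() and re-raises it, instead of converting asyncio.CancelledError into a 499 response. A small sketch of the pattern, not part of the diff; do_work is a placeholder coroutine:

from anyio import get_cancelled_exc_class


async def guarded(do_work) -> str:
    try:
        return await do_work()
    except get_cancelled_exc_class():
        # asyncio.CancelledError on the asyncio backend, trio.Cancelled on
        # trio; re-raise so cancellation keeps propagating.
        raise
    except Exception as error:
        return f"error: {error!r}"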
run_server.sh: mode changed 100644 → 100755 (empty file; no content changes)
