Skip to content

Commit

Permalink
feat(server): add support for generation timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
phodal committed Dec 20, 2023
1 parent dbdbc49 commit 1f9e53c
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 45 deletions.
100 changes: 56 additions & 44 deletions finetunes/deepseek/api-server-python38.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from pydantic import BaseModel
from starlette import status
from starlette.responses import JSONResponse
import async_timeout
import asyncio

# Hard upper bound on how many new tokens a single request may generate.
MAX_MAX_NEW_TOKENS = 2048
# Generation budget used when the caller does not specify max_new_tokens.
DEFAULT_MAX_NEW_TOKENS = 1024
Expand Down Expand Up @@ -42,52 +44,62 @@ class Message(BaseModel):
content: str


class SimpleOpenAIBody(BaseModel):
    """Minimal OpenAI-chat-compatible request body for /api/chat.

    - messages: ordered conversation turns (role + content).
    - temperature: sampling temperature requested by the client;
      defaults to 0.1 to match stream_generate's own default, so a bare
      {"messages": [...]} request validates instead of 422-ing.
    - stream: whether the client expects an SSE stream; defaults to True,
      which is the only mode this server implements.
    """
    messages: List[Message]
    temperature: float = 0.1
    stream: bool = True


# Hard cap (seconds) on how long one streaming generation may run before
# async_timeout cancels it.
GENERATION_TIMEOUT_SEC = 60


async def stream_generate(
        chat_history: List[Message],
        max_new_tokens: int = 512,
        temperature: float = 0.1,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1,
):
    """Stream a chat completion as server-sent events (SSE).

    Runs ``model.generate`` in a background thread and yields one
    ``data: <ChatResponse json>`` event per decoded chunk produced by
    TextIteratorStreamer, followed by a terminal ``data: DONE`` event.

    Args:
        chat_history: ordered conversation messages (role + content).
        max_new_tokens: generation budget for this request.
        temperature: sampling temperature. NOTE: with do_sample=False the
            HF generate call is greedy, so temperature/top_p/top_k have no
            effect as written — kept for interface compatibility.
        top_p, top_k, repetition_penalty: HF generation parameters.

    Raises:
        HTTPException: 504 when generation exceeds GENERATION_TIMEOUT_SEC
            (only effective before the response stream has started).
    """
    async with async_timeout.timeout(GENERATION_TIMEOUT_SEC):
        try:
            global total_count
            total_count += 1
            # Periodically dump GPU stats for ad-hoc monitoring.
            if total_count % 50 == 0:
                os.system("nvidia-smi")

            # BUG FIX: apply_chat_template expects a list of
            # {"role": ..., "content": ...} dicts, not pydantic Message
            # objects — convert explicitly (as the pre-refactor code did).
            conversation = [
                {"role": message.role, "content": message.content}
                for message in chat_history
            ]

            input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
            # Keep only the most recent tokens when the prompt overflows.
            if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
                input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
            input_ids = input_ids.to(model.device)

            streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
            generate_kwargs = dict(
                {"input_ids": input_ids},
                streamer=streamer,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decode; sampling params below are inert while this is False
                top_p=top_p,
                top_k=top_k,
                num_beams=1,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                eos_token_id=32021
            )
            # BUG FIX: transformers models expose .generate, not
            # .stream_generate — streaming comes from TextIteratorStreamer,
            # not from the method name.
            t = Thread(target=model.generate, kwargs=generate_kwargs)
            t.start()

            # Pull chunks off the blocking streamer via an executor so the
            # event loop keeps running; a plain `for text in streamer:` has
            # no await points, which would make async_timeout dead code.
            loop = asyncio.get_event_loop()
            chunk_iter = iter(streamer)
            sentinel = object()
            while True:
                text = await loop.run_in_executor(None, next, chunk_iter, sentinel)
                if text is sentinel:
                    break
                # SSE events must be terminated by a blank line ("\n\n").
                yield 'data: ' + ChatResponse(
                    choices=[MessageInResponseChat(role='assistant', content=text.replace("<|EOT|>", ""))],
                    model="deepseek").model_dump_json() + '\n\n'

            yield 'data: DONE\n\n'

        except asyncio.TimeoutError:
            # NOTE(review): once streaming has begun, FastAPI cannot turn
            # this into a real 504 — the client just sees the stream end.
            # The status code only reaches the client pre-stream.
            raise HTTPException(status_code=504, detail="Stream timed out")


app = FastAPI()
Expand All @@ -102,9 +114,9 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE


@app.post("/api/chat", response_class=Response)
async def root(msgs: List[Message]) -> StreamingResponse:
return StreamingResponse(generate(msgs), media_type="text/event-stream")

async def root(body: SimpleOpenAIBody) -> StreamingResponse:
return StreamingResponse(stream_generate(body.messages, temperature=body.temperature),
media_type="text/event-stream")

if __name__ == "__main__":
try:
Expand Down
4 changes: 3 additions & 1 deletion finetunes/deepseek/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ accelerate==0.23.0
bitsandbytes==0.41.1
gradio==3.48.0
protobuf==3.20.3
scipy==1.11.2
# scipy==1.11.2
sentencepiece==0.1.99
spaces==0.16.1
torch==2.0.0
transformers==4.34.0
fastapi
uvicorn
# asyncio is part of the Python standard library — do not pip-install the obsolete PyPI "asyncio" package
async_timeout

0 comments on commit 1f9e53c

Please sign in to comment.