Skip to content

Commit

Permalink
feat(server): try to simplify response
Browse files Browse the repository at this point in the history
  • Loading branch information
phodal committed Dec 20, 2023
1 parent 3cf8175 commit 46f9863
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
<a href="https://central.sonatype.com/artifact/cc.unitmesh/unit-picker">
<img src="https://img.shields.io/maven-central/v/cc.unitmesh/unit-picker" alt="Maven"/>
</a>
<a href="https://openbayes.com/console/signup?r=phodal_uVxU">
<img src="https://openbayes.com/img/badge-open-in-openbayes.svg" alt="Open In OpenBayes" />
</a>
<a href="https://openbayes.com/console/signup?r=phodal_uVxU">
<img src="https://openbayes.com/img/badge-built-with-openbayes.svg" alt="Built with OpenBayes" />
</a>
</p>

> LLM benchmark/evaluation tools with fine-tuning data engineering, specifically tailored for Unit Mesh tools such as
Expand Down
34 changes: 21 additions & 13 deletions finetunes/deepseek/api-server-python38.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from starlette.responses import JSONResponse
import async_timeout
import asyncio
import time

MAX_MAX_NEW_TOKENS = 2048
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 1024
total_count = 0
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
Expand All @@ -29,28 +30,27 @@
tokenizer.use_default_system_prompt = False


class MessageInResponseChat(BaseModel):
class Message(BaseModel):
    """A single chat message in OpenAI chat-completions format."""
    # Speaker of the message — presumably "system", "user", or "assistant"
    # (OpenAI convention); not validated here — TODO confirm with callers.
    role: str
    # Plain-text body of the message.
    content: str


class MessageInResponseChat(BaseModel):
    """One choice entry of a chat response; wraps the generated message."""
    # The assistant message produced for this choice.
    message: Message


class ChatResponse(BaseModel):
    """Top-level chat-completion response body (OpenAI-like shape)."""
    # Generated alternatives; the visible streaming code emits exactly one.
    choices: List[MessageInResponseChat]
    # Model identifier reported to the client (e.g. "autodev-deepseek").
    model: str


class Message(BaseModel):
role: str
content: str


class SimpleOpenAIBody(BaseModel):
    """Minimal OpenAI-compatible chat-completion request body.

    Generalized: `temperature` and `stream` now carry the OpenAI API
    defaults (1.0 and False respectively), so clients may omit them.
    Existing callers that pass both fields are unaffected.
    """
    # Conversation history, oldest first.
    messages: List[Message]
    # Sampling temperature; 1.0 matches the OpenAI default.
    temperature: float = 1.0
    # Whether the client requested a streamed (SSE) response.
    stream: bool = False


GENERATION_TIMEOUT_SEC = 60
GENERATION_TIMEOUT_SEC = 480


async def stream_generate(
Expand All @@ -75,7 +75,7 @@ async def stream_generate(
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
input_ids = input_ids.to(model.device)

streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
{"input_ids": input_ids},
streamer=streamer,
Expand All @@ -91,12 +91,19 @@ async def stream_generate(
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()

result = ""
outputs = []
for text in streamer:
yield 'data: ' + ChatResponse(
choices=[MessageInResponseChat(role='assistant', content=text.replace("<|EOT|>", ""))],
model="deepseek").model_dump_json()
outputs.append(text)
result = "".join(outputs).replace("<|EOT|>", "")
yield 'data:' + ChatResponse(
choices=[MessageInResponseChat(message=Message(role='assistant', content=result))],
model="autodev-deepseek").model_dump_json()

yield 'data: DONE'
yield '\n\n'
time.sleep(0.2)
yield 'data:[DONE]'
print(result)

except asyncio.TimeoutError:
raise HTTPException(status_code=504, detail="Stream timed out")
Expand All @@ -118,6 +125,7 @@ async def root(body: SimpleOpenAIBody) -> StreamingResponse:
return StreamingResponse(stream_generate(body.messages, temperature=body.temperature),
media_type="text/event-stream")


if __name__ == "__main__":
try:
meta = requests.get('http://localhost:21999/gear-status', timeout=5).json()
Expand Down

0 comments on commit 46f9863

Please sign in to comment.