Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4a5854f
Adding OpenAI Compatible RESTful API
Nov 18, 2023
c0e85d0
Merge branch 'main' into main
PawanOsman Nov 22, 2023
05931b0
Merge branch 'main' into main
PawanOsman Dec 1, 2023
4b3399e
update api_server and fix openai text completion
PawanOsman Dec 11, 2023
4b5e3a8
Merge branch 'main' into main
PawanOsman Dec 11, 2023
b5710e6
Update default top-p and top-k values
PawanOsman Dec 15, 2023
6a8e799
Optimize and Fix minor issues
PawanOsman Dec 15, 2023
644c2ee
remove unused variable
PawanOsman Dec 15, 2023
d1e9f51
Add credits
PawanOsman Dec 15, 2023
cda6225
Merge branch 'main' into main
PawanOsman Dec 15, 2023
994ac31
Merge branch 'main' into main
PawanOsman Dec 16, 2023
b04408d
Adding more args
PawanOsman Dec 16, 2023
0938963
set the deployment-name default
PawanOsman Dec 16, 2023
4a0ee7c
Update __init__.py
mrwyattii Dec 19, 2023
9adf7fe
Update mii/entrypoints/api_server.py
PawanOsman Dec 20, 2023
8c66076
Update mii/entrypoints/data_models.py
PawanOsman Dec 20, 2023
3963177
Update mii/entrypoints/openai_api_server.py
PawanOsman Dec 20, 2023
420b1ba
Update mii/entrypoints/openai_api_server.py
PawanOsman Dec 20, 2023
6cd410d
Merge branch 'main' into main
PawanOsman Dec 26, 2023
72a8c3c
Update api_server.py
PawanOsman Dec 26, 2023
73f82fa
Update openai_api_server.py
PawanOsman Dec 26, 2023
6be85e9
Merge branch 'main' into main
PawanOsman Jan 6, 2024
e35f457
Merge branch 'main' into main
mrwyattii Feb 1, 2024
611e3f9
Merge branch 'main' into main
mrwyattii Feb 1, 2024
b0070d7
Merge branch 'main' into main
PawanOsman Feb 2, 2024
28e0254
Fix linting and formatting issues identified by pre-commit
PawanOsman Feb 2, 2024
c14232e
Merge branch 'main' into main
mrwyattii Feb 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions examples/chat_templates/template_alpaca.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{#- Alpaca-style chat template.
    Emits the trimmed content of the last 'system' message (or an empty string when there is none),
    then walks the messages: 'user' -> "### Instruction:", 'assistant' -> "### Response:",
    'user_context' -> "### Input:", each followed by the trimmed message content. Any other role
    gets no header. Blank separator lines are emitted after every message except the last one.
    When add_generation_prompt is set and the final message is not from the assistant, a trailing
    "### Response:" header is appended.
    The dashes on this comment strip the surrounding whitespace so it adds nothing to the output. -#}
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}

{% for message in messages %}
{% if message['role'] == 'user' %}
### Instruction:
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% elif message['role'] == 'assistant' %}
### Response:
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% elif message['role'] == 'user_context' %}
### Input:
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% endif %}
{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
### Response:
{% endif %}
4 changes: 4 additions & 0 deletions mii/entrypoints/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
217 changes: 217 additions & 0 deletions mii/entrypoints/api_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
# Standard library imports
Comment thread
PawanOsman marked this conversation as resolved.
import json
import time
import grpc
import asyncio
import argparse
import threading
from queue import Queue
from typing import AsyncGenerator
from concurrent.futures import ThreadPoolExecutor

# Third-party imports
import fastapi
import uvicorn
import mii
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse, Response
from mii.grpc_related.proto.modelresponse_pb2_grpc import ModelResponseStub
from mii.grpc_related.proto import modelresponse_pb2
from mii.utils import kwarg_dict_to_proto

# Local module imports
from .data_models import CompletionRequest

# FastAPI application on which the routes below are registered.
app = FastAPI()
# Default gRPC target that /generate connects to; replaced by the value of
# --load-balancer when this module is run as a script with that flag set.
load_balancer = "localhost:50050"

@app.post("/generate")
async def generate(request: CompletionRequest) -> Response:
    """Generate text for one or more prompts via the gRPC load balancer.

    Args:
        request: Completion request (see ``.data_models``). The fields read
            here are ``prompt``, ``stop``, ``max_tokens``, ``min_tokens``,
            ``max_length``, ``top_p``, ``top_k``, ``temperature`` and
            ``stream``.

    Returns:
        * A 400 ``JSONResponse`` with an ``error`` message when no prompt is
          supplied.
        * A ``StreamingResponse`` of server-sent events
          (``data: {"text": [...]}`` chunks terminated by ``data: [DONE]``)
          when ``request.stream`` is truthy.
        * Otherwise a ``JSONResponse`` of the form ``{"text": [...]}``.
    """
    if request.prompt is None:
        return JSONResponse({"error": "Prompt is required."}, status_code=400)

    # A single prompt string is sent as a one-element batch.
    prompts = [request.prompt] if isinstance(request.prompt, str) else request.prompt

    # TODO: Add support for multiple stop tokens, as for now only one is supported.
    # Only the first entry of a list is used; an empty list means "no stop
    # token" instead of raising IndexError.
    stop = request.stop
    if isinstance(stop, list):
        stop = stop[0] if stop else None

    max_tokens = 128 if request.max_tokens is None else request.max_tokens
    stream = bool(request.stream)

    # Fixed generation arguments, followed by the optional ones that were set.
    generate_args = {
        "ignore_eos": False,
        "do_sample": True,
        "return_full_text": False,
    }
    optional_args = {
        "max_length": request.max_length,
        "min_new_tokens": request.min_tokens,
        "max_new_tokens": max_tokens,
        "top_p": request.top_p,
        "top_k": request.top_k,
        "temperature": request.temperature,
        "stop": stop,
    }
    generate_args.update(
        {key: value for key, value in optional_args.items() if value is not None})
    if stream:
        generate_args["stream"] = True

    # Build the request before opening the channel so that a conversion
    # failure cannot leave an open channel behind.
    request_data = modelresponse_pb2.MultiStringRequest(
        request=prompts,
        query_kwargs=kwarg_dict_to_proto(generate_args),
    )

    channel = grpc.aio.insecure_channel(load_balancer)
    stub = ModelResponseStub(channel)

    # Streaming case
    if stream:

        async def stream_results() -> AsyncGenerator[str, None]:
            try:
                # Send an empty chunk to start the stream and prevent timeout
                yield ""
                async for response_chunk in stub.GeneratorReplyStream(request_data):
                    texts = [obj.response for obj in response_chunk.response]
                    yield f"data: {json.dumps({'text': texts})}\n\n"
                yield "data: [DONE]\n\n"
            finally:
                # Runs on normal completion, on error, and when the response
                # closes the generator early (e.g. client disconnect).
                await channel.close()

        return StreamingResponse(stream_results(), media_type="text/event-stream")

    # Non-streaming case
    try:
        response_data = await stub.GeneratorReply(request_data)
    finally:
        await channel.close()
    texts = [obj.response for obj in response_data.response]
    return JSONResponse({"text": texts})

@app.get("/health")
async def health() -> Response:
    """Liveness probe: always answers HTTP 200 with ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return JSONResponse(content=payload, status_code=200)

def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the standalone API server.

    Returns:
        An ``argparse.ArgumentParser`` exposing the model/deployment options,
        the HTTP host/port and the CORS settings. The CORS list options are
        parsed as JSON (e.g. ``--allowed-origins '["https://example.com"]'``).
    """
    # Pass the title as ``description``; the first positional parameter of
    # ArgumentParser is ``prog`` and would replace the program name in usage.
    parser = argparse.ArgumentParser(
        description="DeepSpeed-MII Simple Text Generation RESTful API Server")
    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mistral-7B-Instruct-v0.1",
        help="model name or path to model directory "
        "(defaults to mistralai/Mistral-7B-Instruct-v0.1)",
    )
    parser.add_argument(
        "--deployment-name",
        type=str,
        default="deepspeed-mii",
        help='A unique identifying string for the persistent model '
        '(defaults to "deepspeed-mii")',
    )
    parser.add_argument(
        "--load-balancer",
        type=str,
        default=None,
        help="load balancer address (defaults to None)",
    )
    # Both spellings used to be registered as two separate arguments sharing
    # the ``max_length`` destination; argparse applied only the first default
    # (32768), so the second one's ``None`` default never took effect. They
    # are merged into one argument that keeps both spellings and the
    # effective default.
    parser.add_argument(
        "--max-length",
        "--max_length",
        type=int,
        default=32768,
        help="Sets the default maximum token length for the prompt + response "
        "(defaults to 32768)",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="host address (defaults to 0.0.0.0)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="port (defaults to 8000)",
    )
    parser.add_argument(
        "--allow-credentials",
        action="store_true",
        help="allow credentials",
    )
    parser.add_argument(
        "--allowed-origins",
        type=json.loads,
        default=["*"],
        help="allowed origins",
    )
    parser.add_argument(
        "--allowed-methods",
        type=json.loads,
        default=["*"],
        help="allowed methods",
    )
    parser.add_argument(
        "--allowed-headers",
        type=json.loads,
        default=["*"],
        help="allowed headers",
    )
    parser.add_argument(
        "--tensor-parallel",
        type=int,
        default=1,
        help="Number of GPUs to split the model across (defaults to 1)",
    )
    parser.add_argument(
        "--replica-num",
        type=int,
        default=1,
        help="The number of model replicas to stand up (defaults to 1)",
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    # Add CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    # Check if a load balancer is specified else start the DeepSpeed-MII instance
    if args.load_balancer is not None:
        # Point /generate at the externally managed load balancer.
        load_balancer = args.load_balancer
    else:
        # Initialize the DeepSpeed-MII instance
        mii.serve(
            args.model,
            deployment_name=args.deployment_name,
            tensor_parallel=args.tensor_parallel,
            replica_num=args.replica_num,
            max_length=args.max_length,
        )

    # Start the server
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="info",
                timeout_keep_alive=300)
Loading