189 changes: 189 additions & 0 deletions examples/runtime/engine/fastapi_engine_inference.py
@@ -0,0 +1,189 @@
"""
FastAPI server example for text generation using SGLang Engine and demonstrating client usage.

Starts the server, sends requests to it, and prints responses.

Usage:
python fastapi_engine_inference.py --model-path Qwen/Qwen2.5-0.5B-Instruct --tp_size 1 --host 127.0.0.1 --port 8000
"""

import os
import subprocess
import time
from contextlib import asynccontextmanager

import requests
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

import sglang as sgl
from sglang.utils import terminate_process

engine = None


# Use FastAPI's lifespan manager to initialize/shutdown the engine
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manages SGLang engine initialization during server startup."""
    global engine
    # Initialize the SGLang engine when the server starts.
    # Defaults mirror the argparse defaults in __main__ in case the env vars are unset.
    print("Loading SGLang engine...")
    engine = sgl.Engine(
        model_path=os.getenv("MODEL_PATH", "Qwen/Qwen2.5-0.5B-Instruct"),
        tp_size=int(os.getenv("TP_SIZE", "1")),
    )
    print("SGLang engine loaded.")
    yield
    # Clean up engine resources when the server stops (optional, depends on engine needs)
    print("Shutting down SGLang engine...")
    # engine.shutdown() # Or other cleanup if available/necessary
    print("SGLang engine shutdown.")


app = FastAPI(lifespan=lifespan)
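
# For reference, the /generate endpoint below accepts a JSON body such as
#   {"prompt": "Hello", "max_new_tokens": 64, "temperature": 0.7},
# where max_new_tokens and temperature are optional and fall back to the
# defaults read in the handler.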


@app.post("/generate")
async def generate_text(request: Request):
    """FastAPI endpoint to handle text generation requests."""
    global engine
    if not engine:
        # Returning a (dict, status) tuple would not set the HTTP status code;
        # use JSONResponse so the client actually sees a 503.
        return JSONResponse(status_code=503, content={"error": "Engine not initialized"})

    try:
        data = await request.json()
        prompt = data.get("prompt")
        max_new_tokens = data.get("max_new_tokens", 128)
        temperature = data.get("temperature", 0.7)

        if not prompt:
            return JSONResponse(status_code=400, content={"error": "Prompt is required"})

        # Use async_generate for non-blocking generation
        state = await engine.async_generate(
            prompt,
            sampling_params={
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
            },
            # Add other parameters like stop, top_p etc. as needed
        )

        return {"generated_text": state["text"]}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})


# Helper function to start the server
def start_server(args, timeout=60):
    """Starts the Uvicorn server as a subprocess and waits for it to be ready."""
    base_url = f"http://{args.host}:{args.port}"
    command = [
        "python",
        "-m",
        "uvicorn",
        "fastapi_engine_inference:app",
        f"--host={args.host}",
        f"--port={args.port}",
    ]

    process = subprocess.Popen(command, stdout=None, stderr=None)

    start_time = time.time()
    with requests.Session() as session:
        while time.time() - start_time < timeout:
            try:
                # Check the /docs endpoint which FastAPI provides by default
                response = session.get(f"{base_url}/docs", timeout=5)
                if response.status_code == 200:
                    print(f"Server {base_url} is ready (responded on /docs)")
                    return process
            except requests.ConnectionError:
                # Connection refused, DNS error, etc. -- server not up yet
                pass
            except requests.Timeout:
                print(f"Health check to {base_url}/docs timed out, retrying...")
            except requests.RequestException as e:
                # Catch other request exceptions
                print(f"Health check request error: {e}, retrying...")
            # Use a short sleep interval for faster startup detection
            time.sleep(1)

    # If the loop finishes, terminate the failed process before raising
    print("Server failed to start within timeout, attempting to terminate process...")
    terminate_process(process)
    raise TimeoutError(f"Server failed to start at {base_url} within the timeout period.")


def send_requests(server_url, prompts, max_new_tokens, temperature):
    """Sends generation requests to the running server for a list of prompts."""
    for i, prompt in enumerate(prompts):
        print(f"\n[{i+1}/{len(prompts)}] Sending prompt: '{prompt}'")
        payload = {
            "prompt": prompt,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        }

        try:
            response = requests.post(f"{server_url}/generate", json=payload, timeout=60)
            # Surface HTTP errors instead of reading a missing field from an error body
            response.raise_for_status()
            result = response.json()
            print(f"Response: {result['generated_text']}")
        except requests.exceptions.Timeout:
            print(f"  Error: Request timed out for prompt '{prompt}'")
        except requests.exceptions.RequestException as e:
            print(f"  Error sending request for prompt '{prompt}': {e}")


if __name__ == "__main__":
    # Main entry point for the script.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model-path", type=str, default="Qwen/Qwen2.5-0.5B-Instruct")
    parser.add_argument("--tp_size", type=int, default=1)
    args = parser.parse_args()

    # Pass the model to the child uvicorn process via env vars
    os.environ["MODEL_PATH"] = args.model_path
    os.environ["TP_SIZE"] = str(args.tp_size)

    # Start the server
    process = start_server(args)

    # Define the prompts and sampling parameters
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    max_new_tokens = 64
    temperature = 0.1

    # Define the server URL
    server_url = f"http://{args.host}:{args.port}"

    # Send requests to the server
    send_requests(server_url, prompts, max_new_tokens, temperature)

    # Terminate the server process
    terminate_process(process)
5 changes: 5 additions & 0 deletions examples/runtime/engine/readme.md
@@ -6,6 +6,7 @@ SGLang provides a direct inference engine without the need for an HTTP server. T
1. **Offline Batch Inference**
2. **Embedding Generation**
3. **Custom Server on Top of the Engine**
4. **Inference Using FastAPI**

## Examples

@@ -47,3 +48,7 @@ This will send both non-streaming and streaming requests to the server.
### [Token-In-Token-Out for RLHF](../token_in_token_out)

In this example, we launch an SGLang engine, feed tokens as input and generate tokens as output.

### [Inference Using FastAPI](fastapi_engine_inference.py)

This example demonstrates how to create a FastAPI server that uses the SGLang engine for text generation, together with a simple client that sends requests to it.
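
As a quick sketch (assuming the script's default host and port, `127.0.0.1:8000`), you can run the example and also query the endpoint by hand:

```bash
# Launch the example; it starts uvicorn as a subprocess and runs the built-in client
python fastapi_engine_inference.py --model-path Qwen/Qwen2.5-0.5B-Instruct --tp_size 1

# While the server is up, the /generate endpoint can be queried directly
curl -X POST http://127.0.0.1:8000/generate \
  -H "Content-Type: application/json" \
  -d '{"prompt": "The capital of France is", "max_new_tokens": 64, "temperature": 0.1}'
```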