Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def _start_router(args):
args.sglang_router_port = find_available_port(random.randint(3000, 4000))

if args.use_slime_router:
from slime_plugins.slime_router.slime_router import run_slime_router as run_router
from slime.router.router import run_router

router_args = args

Expand Down
Empty file added slime/router/__init__.py
Empty file.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -532,23 +532,23 @@ def _print_node(self, node: StringTreeNode, depth: int):
for child in node.children:
self._print_node(child, depth + 1)

def retrieve_from_text(self, text: str, return_logp: bool = False):
def retrieve_from_text(self, text: str, return_logprob: bool = False):
"""
Get tokens from text by looking up in radix tree or using tokenizer.
Also fetches weight version from worker during this operation.
Args:
text: Input text to get tokens for
return_logp: If True, also return log probabilities
return_logprob: If True, also return log probabilities
Returns:
List of token IDs corresponding to the input text if return_logp is False.
Tuple of (token_ids, logp) if return_logp is True.
List of token IDs corresponding to the input text if return_logprob is False.
Tuple of (token_ids, logp) if return_logprob is True.
"""
# Call find_longest_prefix to get the match result
result = self.find_longest_prefix(text)

# If we have a match and it covers the entire text, return the tokens
if result.matched_prefix and result.token_ids:
if return_logp:
if return_logprob:
return (result.token_ids, result.logp)
else:
return result.token_ids
Expand All @@ -562,7 +562,7 @@ def retrieve_from_text(self, text: str, return_logp: bool = False):
# Insert the text and tokens into the tree
self.insert(text, tokens)
# Return the tokens
if return_logp:
if return_logprob:
# Return default logp values (0.0) when using tokenizer
return (tokens, [0.0] * len(tokens))
else:
Expand All @@ -577,7 +577,7 @@ def retrieve_from_text(self, text: str, return_logp: bool = False):
print("Tree structure after retrieve_from_text:")
self.pretty_print()

if return_logp:
if return_logprob:
return (result_tokens, result_logp)
else:
return result_tokens
Expand Down
61 changes: 61 additions & 0 deletions slime/router/middleware_hub/radix_tree_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from fastapi import BaseHTTPMiddleware, FastAPI
from transformers import AutoTokenizer

from .radix_tree import StringRadixTrie


class RadixTreeMiddleware(BaseHTTPMiddleware):
    """Middleware that serves and fills a text -> token radix-tree cache.

    Only ``/generate`` requests carrying a non-empty ``"text"`` field are
    handled; everything else is passed straight through. For handled
    requests the prompt is looked up in the radix tree before the request is
    forwarded, and the full trajectory found in the JSON response is inserted
    back into the tree afterwards.
    """

    def __init__(self, app: FastAPI, *, router):
        """Create the tokenizer and the radix tree.

        Args:
            app: The ASGI application being wrapped.
            router: The owning router; its ``args``, ``verbose`` and
                ``max_weight_version`` attributes are read by this middleware.
        """
        super().__init__(app)
        self.router = router
        self.args = router.args
        # The checkpoint path lives on the router's args; there is no
        # module-level ``args`` in this file.
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.hf_checkpoint, trust_remote_code=True)
        self.radix_tree = StringRadixTrie(max_cache_size=10000, tokenizer=self.tokenizer, verbose=False)

    async def dispatch(self, request, call_next):
        """Intercept ``/generate`` calls and cache their trajectories.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response. For cached ``/generate`` calls it is a
            buffered copy with the same status, headers and body.
        """
        if request.url.path != "/generate":
            return await call_next(request)

        # Pop "text" from the request json and resolve input tokens through the radix tree.
        request_json = await request.json()
        input_text = request_json.pop("text", "")
        if not input_text:
            return await call_next(request)

        input_tokens, input_logprobs = self.radix_tree.retrieve_from_text(input_text, return_logprob=True)
        request_json["input_tokens"] = input_tokens
        # NOTE(review): this only replaces the parsed-JSON cache on this Request
        # object; confirm the downstream handler really sees the edited payload.
        request._json = request_json
        response = await call_next(request)

        # The downstream result is a Response object, not a dict, so its body
        # has to be read and decoded before it can be inspected.
        response, payload = await self._buffer_json_response(response)
        if isinstance(payload, dict):
            self._cache_trajectory(input_text, input_logprobs, payload)
        return response

    @staticmethod
    async def _buffer_json_response(response):
        """Read the whole response body and try to decode it as JSON.

        Returns:
            Tuple ``(response, payload)``. ``response`` is safe to return to
            the client (rebuilt when the original body iterator had to be
            consumed). ``payload`` is the decoded JSON value, or ``None`` when
            the body is not valid JSON.
        """
        import json

        from starlette.responses import Response

        body_iterator = getattr(response, "body_iterator", None)
        if body_iterator is None:
            raw_body = getattr(response, "body", b"")
        else:
            raw_body = b"".join([chunk async for chunk in body_iterator])
            # The iterator is exhausted now, so hand back an equivalent
            # buffered response. NOTE(review): dict() keeps one value per
            # header name; repeated headers would collapse.
            response = Response(
                content=raw_body,
                status_code=response.status_code,
                headers=dict(response.headers),
            )

        try:
            payload = json.loads(raw_body)
        except ValueError:
            payload = None
        return response, payload

    def _cache_trajectory(self, input_text, input_logprobs, payload):
        """Insert prompt + generation from ``payload`` into the radix tree (best effort)."""
        if "text" not in payload or "output_ids" not in payload:
            return

        # Combine input text and generated text
        full_text = input_text + payload["text"]
        # sglang will return the input token ids as well
        full_token_ids = payload["output_ids"]
        if not (full_text and full_token_ids):
            return

        # Verbosity and weight version are tracked by the router, not by the middleware.
        verbose = self.router.verbose
        weight_version = self.router.max_weight_version

        try:
            meta_info = payload.get("meta_info", {})
            if "output_token_logprobs" in meta_info:
                generated_token_logprobs = [item[0] for item in meta_info["output_token_logprobs"]]
                full_logprobs = input_logprobs + generated_token_logprobs
                self.radix_tree.insert(full_text, full_token_ids, full_logprobs, weight_version=weight_version)
            else:
                # Use default log probabilities (0.0) if not provided
                self.radix_tree.insert(full_text, full_token_ids, weight_version=weight_version)

            if verbose:
                print(f"[slime-router] Successfully cached trajectory with {len(full_token_ids)} tokens")
        except Exception as e:
            # Don't fail the request if caching fails
            if verbose:
                print(f"[slime-router] Warning: Failed to cache trajectory: {e}")
167 changes: 167 additions & 0 deletions slime/router/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import argparse
import json

import httpx
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.responses import StreamingResponse

from slime.utils.misc import load_function


def run_router(args):
    """Build a SlimeRouter from ``args`` and serve its FastAPI app with uvicorn.

    Args:
        args: Configuration object; ``sglang_router_ip`` and
            ``sglang_router_port`` give the address to listen on.
    """
    router = SlimeRouter(args, verbose=False)

    uvicorn.run(
        router.app,
        host=args.sglang_router_ip,
        port=args.sglang_router_port,
        log_level="info",
    )


class SlimeRouter:
    """Small HTTP router that spreads requests across registered workers.

    Workers are registered through ``POST /add_worker``; every request that
    does not match a router endpoint is forwarded to the worker that currently
    has the fewest in-flight requests.
    """

    def __init__(self, args, verbose=False):
        """Create the FastAPI app, the shared HTTP client, routes and middleware.

        Args:
            args: Configuration object. ``slime_router_middleware_paths``
                (optional, may be ``None``) lists import paths of middleware
                classes to install; each is constructed with ``router=self``.
            verbose: If True, print diagnostic messages.
        """
        self.args = args
        self.verbose = verbose

        self.app = FastAPI()

        # Worker information: worker URL -> number of in-flight requests.
        self.worker_urls: dict[str, int] = {}
        self.max_weight_version = None

        # TODO: remove this hardcode
        self.client = httpx.AsyncClient(
            limits=httpx.Limits(max_connections=16384),
            timeout=httpx.Timeout(None, connect=5.0),
        )

        self._setup_routes()

        # getattr: argument objects that never defined the option (for example
        # this file's own __main__ parser) must not break construction.
        for middleware_path in getattr(args, "slime_router_middleware_paths", None) or []:
            if self.verbose:
                print(f"[slime-router] Loading middleware from: {middleware_path}")
            middleware = load_function(middleware_path)
            self.app.add_middleware(middleware, router=self)

    def _update_weight_version_from_response(self, output):
        """Raise ``max_weight_version`` to the version reported in ``output``.

        Args:
            output: Decoded response dict; ignored unless it contains
                ``output["meta_info"]["weight_version"]``.
        """
        if "meta_info" not in output or "weight_version" not in output["meta_info"]:
            return

        current_weight_version = output["meta_info"]["weight_version"]

        # Update max_weight_version
        if self.max_weight_version is None or current_weight_version > self.max_weight_version:
            self.max_weight_version = current_weight_version
            if self.verbose:
                print(f"[slime-router] Updated max weight version to: {self.max_weight_version}")
        elif self.verbose:
            print(f"[slime-router] Current weight version {current_weight_version} <= max {self.max_weight_version}")

    def _setup_routes(self):
        """Register the router endpoints and the catch-all proxy route."""
        # sglang-router api
        self.app.post("/add_worker")(self.add_worker)
        self.app.get("/list_workers")(self.list_workers)
        # Catch-all route for proxying to SGLang - must be registered LAST
        self.app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])(self.proxy)

    async def health_check(self, request: Request):
        """Placeholder; not registered as a route and currently does nothing."""
        # TODO: do health check in background
        pass

    async def proxy(self, request: Request, path: str):
        """Forward the request to the least-loaded worker and relay its response.

        Returns:
            The worker's response, or a 503 JSON error when no worker has
            been registered yet.
        """
        # Answer with a clear error instead of tripping the assertion in
        # _use_url() when nothing is registered.
        if not self.worker_urls:
            return JSONResponse(status_code=503, content={"error": "No workers available"})

        worker_url = self._use_url()
        url = f"{worker_url}/{path}"

        # Get request body and headers
        body = await request.body()
        headers = dict(request.headers)

        try:
            response = await self.client.request(request.method, url, content=body, headers=headers)
            return StreamingResponse(
                response.aiter_bytes(),
                status_code=response.status_code,
                headers=response.headers,
                media_type=response.headers.get("content-type"),
            )

        finally:
            # Always release the slot taken by _use_url(), even on failure.
            self._finish_url(worker_url)

    async def add_worker(self, request: Request):
        """Add a new worker to the router.

        Supports providing the URL via query string or JSON body.
        Examples:
        - POST /add_worker?url=http://127.0.0.1:10090
        - POST /add_worker with body {"url": "http://127.0.0.1:10090"}

        Returns:
            A success dict with the current worker table, or a 400 JSON error
            when the body is not valid JSON or no URL was supplied.
        """
        # 1) Prefer query param
        worker_url = request.query_params.get("url") or request.query_params.get("worker_url")

        # 2) Fallback to JSON body
        if not worker_url:
            body = await request.body()
            try:
                payload = json.loads(body) if body else {}
            except ValueError:
                # Covers malformed JSON and undecodable bytes alike.
                return JSONResponse(status_code=400, content={"error": "request body is not valid JSON"})
            # Only a JSON object can carry the URL; other JSON values fall
            # through to the "worker_url is required" error below.
            if isinstance(payload, dict):
                worker_url = payload.get("url") or payload.get("worker_url")

        if not worker_url:
            return JSONResponse(
                status_code=400, content={"error": "worker_url is required (use query ?url=... or JSON body)"}
            )

        # Add if new, keep a simple request count per worker
        if worker_url not in self.worker_urls:
            self.worker_urls[worker_url] = 0
            if self.verbose:
                print(f"[slime-router] Added new worker: {worker_url}")

        return {"status": "success", "worker_urls": self.worker_urls}

    async def list_workers(self, request: Request):
        """List all registered workers"""
        return {"urls": list(self.worker_urls.keys())}

    def _use_url(self):
        """Pick the worker with the fewest in-flight requests and count one more against it."""
        assert len(self.worker_urls) > 0, "No workers available"

        # get the url with minimal count
        url = min(self.worker_urls, key=self.worker_urls.get)
        self.worker_urls[url] += 1
        return url

    def _finish_url(self, url):
        """Mark the request to the given URL as finished"""
        assert url in self.worker_urls, f"URL {url} not recognized"
        self.worker_urls[url] -= 1
        assert self.worker_urls[url] >= 0, f"URL {url} count went negative"


if __name__ == "__main__":
    # Standalone entry point. argparse is already imported at the top of this
    # file, so the redundant local re-imports were dropped.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=30000)
    parser.add_argument("--sglang-host", type=str, required=True)
    parser.add_argument("--sglang-port", type=int, required=True)
    parser.add_argument("--tokenizer-name", type=str, help="Name of the tokenizer to use for tokenization")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
    parser.add_argument(
        "--slime-router-middleware-paths",
        type=str,
        nargs="+",
        default=None,
        help="Import paths of middleware classes to install on the router",
    )

    args = parser.parse_args()

    # run_router() reads the listen address from these attribute names, so
    # mirror --host/--port onto them; otherwise it fails with AttributeError.
    args.sglang_router_ip = args.host
    args.sglang_router_port = args.port

    # Run the router
    run_router(args)
6 changes: 6 additions & 0 deletions slime/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,12 @@ def add_router_arguments(parser):
default=False,
help="Whether to use SlimeRouter for text-based routing instead of SGLang token-based routing",
)
parser.add_argument(
"--slime-router-middleware-paths",
type=str,
nargs="+",
default=None,
)
return parser

# wandb
Expand Down
Loading
Loading