Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion trinity/explorer/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ async def chat_completions(request: Request):
content=f"Error forwarding request to model at {url}: {traceback.format_exc()}",
)
resp_data = resp.json()
await request.app.state.service.record_experience(resp_data)
await request.app.state.service.record_experience(
resp_data, session_id=body.get("session_id", None)
)
return JSONResponse(content=resp_data)


Expand All @@ -52,6 +54,26 @@ async def metrics(request: Request):
return JSONResponse(content=metrics)


@app.api_route("/allocate", methods=["GET", "POST"])
async def allocate(request: Request):
    """Allocate a new session and return its identifier.

    The route accepts both GET and POST. ``ExplorerClient.init_session`` issues
    ``requests.post(".../allocate")``, which a GET-only route rejects with
    405 Method Not Allowed; GET stays available for existing callers.

    Args:
        request: Incoming request; the service is read from
            ``request.app.state.service``.

    Returns:
        A JSON response of the form ``{"session_id": <id>}``.
    """
    session_id = request.app.state.service.allocate_session()
    return JSONResponse(content={"session_id": session_id})


@app.post("/feedback")
async def feedback(request: Request):
    """Receive a reward for a session and forward it to the service.

    Expects a JSON object body containing ``session_id`` and ``reward``.

    Args:
        request: Incoming request; the service is read from
            ``request.app.state.service``.

    Returns:
        ``{"status": "success"}`` once the feedback has been recorded, or a
        400 response carrying an ``error`` message when the body is not a
        JSON object or a required field is missing.
    """
    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError (and UnicodeDecodeError) subclass ValueError;
        # answer with a client error instead of an unhandled 500.
        return JSONResponse(
            status_code=400, content={"error": "request body must be valid JSON"}
        )
    if not isinstance(body, dict):
        # A JSON array/scalar has no .get(); reject it explicitly.
        return JSONResponse(
            status_code=400, content={"error": "request body must be a JSON object"}
        )

    session_id = body.get("session_id")
    reward = body.get("reward")
    # Compare against None so that a reward of 0 is still accepted.
    if session_id is None or reward is None:
        return JSONResponse(
            status_code=400, content={"error": "session_id and reward are required"}
        )

    # record_feedback is defined on the service (next to record_experience and
    # allocate_session), so call it there rather than on service.explorer.
    await request.app.state.service.record_feedback(session_id, reward)
    return JSONResponse(content={"status": "success"})


async def serve_http(app: FastAPI, host: str, port: int = None):
config = uvicorn.Config(app, host=host, port=port)
server = uvicorn.Server(config)
Expand Down
56 changes: 43 additions & 13 deletions trinity/explorer/api/service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import time
from collections import deque
from typing import Dict, List
from typing import Dict, List, Optional

import torch

Expand All @@ -13,6 +13,8 @@


class ExplorerService:
"""Manages the lifecycle and operations of the Explorer API service."""

def __init__(self, explorer: Explorer, listen_address: str = "localhost", port: int = 8010):
self.logger = get_logger(__name__)
self.explorer = explorer
Expand All @@ -27,10 +29,13 @@ def __init__(self, explorer: Explorer, listen_address: str = "localhost", port:
self.running_models: deque[int] = deque() # indices of running models
self.sync_task_map: Dict[asyncio.Future, int] = {} # sync task -> model index
self.latest_model_version = 0
self.experience_queue = asyncio.Queue()
self.experience_queue: deque[Experience] = deque()
self.session_level_experience_queue: Dict[int, deque[Experience]] = {}
self.queue_lock = asyncio.Lock()
self.experience_count = 0
self.session_count = 0

async def serve(self):
async def serve(self) -> None:
from trinity.explorer.api.api import run_app

if self.running:
Expand All @@ -48,7 +53,7 @@ async def serve(self):
)
self.sync_model_weights_task = asyncio.create_task(self.model_weights_sync_loop())

async def model_weights_sync_loop(self):
async def model_weights_sync_loop(self) -> None:
self.logger.info("Starting model weights synchronization loop.")
while self.running:
for idx in list(self.running_models):
Expand All @@ -71,7 +76,7 @@ def set_latest_model_version(self, version: int) -> None:
self.latest_model_version = version
self.logger.info(f"Updated latest model version to {version}.")

async def _wait_for_sync_start(self, index: int):
async def _wait_for_sync_start(self, index: int) -> None:
start_time = time.time()
while time.time() - start_time < self.max_timeout:
current_load = await self.models[index].get_current_load()
Expand All @@ -85,7 +90,7 @@ async def _wait_for_sync_start(self, index: int):
f"Timeout waiting for model {index} to be free for synchronization. Current load: {current_load}"
)

async def _sync_model_weights(self, task: asyncio.Future):
async def _sync_model_weights(self, task: asyncio.Future) -> None:
index = self.sync_task_map.pop(task)
latest_version = self.latest_model_version # capture the latest version
if task.cancelled():
Expand Down Expand Up @@ -121,7 +126,7 @@ async def check_requiring_sync_models(self):
*[self._sync_model_weights(idx) for idx in list(self.requiring_sync_models)]
)

async def record_experience(self, response):
async def record_experience(self, response, session_id: Optional[int] = None):
experiences = []
for choice in response["choices"]:
exp = Experience(
Expand All @@ -137,14 +142,39 @@ async def record_experience(self, response):
)
experiences.append(exp)
self.experience_count += len(experiences)
for exp in experiences:
await self.experience_queue.put(exp)

# Store experiences in session-level queue if session_id is provided
if session_id is not None:
async with self.queue_lock:
if session_id not in self.session_level_experience_queue:
self.session_level_experience_queue[session_id] = deque()
self.session_level_experience_queue[session_id].extend(experiences)
else:
async with self.queue_lock:
self.experience_queue.extend(experiences)

async def get_all_experiences(self) -> List:
experiences = []
while not self.experience_queue.empty():
experiences.append(await self.experience_queue.get())
return experiences
async with self.queue_lock:
experiences = list(self.experience_queue)
self.experience_queue.clear()
return experiences

def allocate_session(self) -> int:
    """Reserve the next session id and return it.

    The id comes from ``self.session_count``, which is advanced by one on
    every call; the value returned is the counter after the increment.
    """
    next_id = self.session_count + 1
    self.session_count = next_id
    return next_id

async def record_feedback(self, session_id: int, reward: float):
    """Assign ``reward`` to a session's buffered experiences and enqueue them.

    The session's entry is removed from ``session_level_experience_queue``;
    each experience in it gets ``reward`` set and is then appended to
    ``experience_queue``. If the session has no buffered experiences, a
    warning is logged and nothing is enqueued.

    Args:
        session_id: Key of the session whose experiences receive the reward.
        reward: Value written to the ``reward`` attribute of each experience.
    """
    async with self.queue_lock:
        # Popping with a default covers the "unknown session" case in one step.
        session_exps = list(self.session_level_experience_queue.pop(session_id, ()))

    if len(session_exps) == 0:
        self.logger.warning(f"No experiences found for session_id {session_id}.")
        return

    for item in session_exps:
        item.reward = reward

    async with self.queue_lock:
        self.experience_queue.extend(session_exps)

async def shutdown(self):
if not self.running:
Expand Down
20 changes: 11 additions & 9 deletions trinity/explorer/explorer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@


class ExplorerClient:
def __init__(self, base_url: str):
self.base_url = base_url
def __init__(self, explorer_api_url: str):
self.explorer_api_url = explorer_api_url
self.openai_base_url = f"{self.explorer_api_url}/v1"
self.feedback_url = f"{self.explorer_api_url}/feedback"
self.session_id = self.init_session()

def init_session(self) -> str:
response = requests.post(f"{self.base_url}/allocate")
response = requests.post(f"{self.explorer_api_url}/allocate")
data = response.json()
return data["session_id"]

def get_openai_client(self) -> openai.OpenAI:
client = openai.OpenAI(
base_url=self.base_url + "/v1",
base_url=self.openai_base_url,
api_key="EMPTY",
)
client.chat.completions.create = partial(
Expand All @@ -27,23 +29,23 @@ def get_openai_client(self) -> openai.OpenAI:

def get_openai_async_client(self) -> openai.AsyncOpenAI:
client = openai.AsyncOpenAI(
base_url=self.base_url + "/v1",
base_url=self.openai_base_url,
api_key="EMPTY",
)
client.chat.completions.create = partial(
client.chat.completions.create, extra_body={"session_id": self.session_id}
)
return client

def feedback(self, reward: float):
def feedback(self, reward: float) -> dict:
response = requests.post(
f"{self.base_url}/feedback", json={"session_id": self.session_id, "reward": reward}
self.feedback_url, json={"session_id": self.session_id, "reward": reward}
)
return response.json()

async def feedback_async(self, reward: float):
async def feedback_async(self, reward: float) -> dict:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/feedback", json={"session_id": self.session_id, "reward": reward}
self.feedback_url, json={"session_id": self.session_id, "reward": reward}
)
return response.json()