Skip to content

Commit b2b31c4

Browse files
committed
feat: add Agent input validation
1 parent a4c3cca commit b2b31c4

7 files changed

+120
-9
lines changed

README.md

+81
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ The goal of AgentServe is to provide the easiest way to take an local agent to p
2626
- **Framework Agnostic:** AgentServe supports multiple agent frameworks (OpenAI, LangChain, LlamaIndex, and Blank).
2727
- **Dockerized:** The output is a single docker image that you can deploy anywhere.
2828
- **Easy to Use:** AgentServe provides a CLI tool to initialize and setup your AI agent projects.
29+
- **Schema Validation:** Define input schemas for your agents using AgentInput to ensure data consistency and validation.
2930

3031
## Requirements
3132

@@ -161,6 +162,86 @@ Get the result of a task.
161162

162163
- `result`: The result of the task.
163164

165+
## Defining Input Schemas
166+
167+
AgentServe uses AgentInput (an alias for Pydantic's BaseModel) to define and validate the input schemas for your agents. This ensures that the data received by your agents adheres to the expected structure, enhancing reliability and developer experience.
168+
### Subclassing AgentInput
169+
To define a custom input schema for your agent, subclass AgentInput and specify the required fields.
170+
171+
**Example:**
172+
173+
```python
174+
# agents/custom_agent.py
175+
from agentserve.agent import Agent, AgentInput
176+
from typing import Optional, Dict, Any
177+
178+
class CustomTaskSchema(AgentInput):
179+
input_text: str
180+
parameters: Optional[Dict[str, Any]] = None
181+
182+
class CustomAgent(Agent):
183+
input_schema = CustomTaskSchema
184+
185+
def process(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
186+
# Implement your processing logic here
187+
input_text = task_data["input_text"]
188+
parameters = task_data.get("parameters", {})
189+
# Example processing
190+
processed_text = input_text.upper() # Simple example
191+
return {"processed_text": processed_text, "parameters": parameters}
192+
```
193+
194+
### Updating Your Agent
195+
196+
When creating your agent, assign your custom schema to the input_schema attribute. This ensures that all incoming task_data is validated against your defined schema before processing.
197+
198+
**Steps:**
199+
200+
1. Define the Input Schema:
201+
202+
```python
203+
from agentserve.agent import Agent, AgentInput
204+
from typing import Optional, Dict, Any
205+
206+
class MyTaskSchema(AgentInput):
207+
prompt: str
208+
settings: Optional[Dict[str, Any]] = None
209+
```
210+
211+
2. Implement the Agent:
212+
213+
```python
214+
class MyAgent(Agent):
215+
input_schema = MyTaskSchema
216+
217+
def process(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
218+
prompt = task_data["prompt"]
219+
settings = task_data.get("settings", {})
220+
# Your processing logic here
221+
response = {"response": f"Echo: {prompt}", "settings": settings}
222+
return response
223+
```
224+
225+
### Handling Validation Errors
226+
227+
AgentServe will automatically validate incoming task_data against the defined input_schema. If the data does not conform to the schema, a 400 Bad Request error will be returned with details about the validation failure.
228+
229+
**Example Response:**
230+
231+
```json
232+
{
233+
"detail": [
234+
{
235+
"loc": ["body", "prompt"],
236+
"msg": "field required",
237+
"type": "value_error.missing"
238+
}
239+
]
240+
}
241+
```
242+
243+
Ensure that your clients provide data that matches the schema to avoid validation errors.
244+
164245
## CLI Usage
165246

166247
### Init Command (for new projects)

agentserve/agent.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
# agentserve/agent.py
2-
from typing import Dict, Any
2+
from typing import Dict, Any, Type
3+
from pydantic import BaseModel
4+
5+
AgentInput = BaseModel # Alias BaseModel to AgentInput
36

47
class Agent:
8+
input_schema: Type[AgentInput] = AgentInput
9+
10+
def _process(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
11+
# Validate task_data against input_schema
12+
validated_data = self.input_schema(**task_data).dict()
13+
return self._process(validated_data)
14+
515
def process(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
616
"""
717
User-defined method to process the incoming task data.

agentserve/agent_server.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
# agentserve/agent_server.py
22

33
from fastapi import FastAPI, HTTPException
4-
from typing import Dict, Any, AsyncGenerator
4+
from typing import Dict, Any
55
from rq import Queue
66
from redis import Redis
7-
from fastapi.responses import StreamingResponse
87
import uuid
98
import os
109

@@ -20,15 +19,17 @@ def _setup_routes(self):
2019
@self.app.post("/task/sync")
2120
async def sync_task(task_data: Dict[str, Any]):
2221
try:
23-
result = self.agent.process(task_data)
22+
result = self.agent._process(task_data)
2423
return {"result": result}
24+
except ValueError as ve:
25+
raise HTTPException(status_code=400, detail=str(ve))
2526
except Exception as e:
2627
raise HTTPException(status_code=500, detail=str(e))
2728

2829
@self.app.post("/task/async")
2930
async def async_task(task_data: Dict[str, Any]):
3031
task_id = str(uuid.uuid4())
31-
job = self.task_queue.enqueue(self.agent.process, task_data, job_id=task_id)
32+
job = self.task_queue.enqueue(self.agent._process, task_data, job_id=task_id)
3233
return {"task_id": task_id}
3334

3435
@self.app.get("/task/status/{task_id}")
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
from agentserve import Agent
1+
from agentserve import Agent, AgentInput
2+
3+
class ExampleInput(AgentInput):
4+
prompt: str
25

36
class ExampleAgent(Agent):
7+
input_schema = ExampleInput
48
def process(self, task_data):
59
return ""
610

agentserve/templates/agents/example_langchain_agent.py.tpl

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
from agentserve import Agent
1+
from agentserve import Agent, AgentInput
22
from langchain import OpenAI
33

4+
class ExampleInput(AgentInput):
5+
prompt: str
6+
47
class ExampleAgent(Agent):
8+
input_schema = ExampleInput
9+
510
def __init__(self):
611
self.client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
712

agentserve/templates/agents/example_llamaindex_agent.py.tpl

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
from agentserve import Agent
1+
from agentserve import Agent, AgentInput
22
from llama_index import GPTSimpleVectorIndex, SimpleDirectoryReader
33
import os
44

5+
class ExampleInput(AgentInput):
6+
query: str
7+
58
class ExampleAgent(Agent):
9+
input_schema = ExampleInput
10+
611
def process(self, task_data):
712
# Load documents from a directory
813
documents = SimpleDirectoryReader('data').load_data()

agentserve/templates/agents/example_openai_agent.py.tpl

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
from agentserve import Agent
1+
from agentserve import Agent, AgentInput
22
from openai import OpenAI
33

4+
class ExampleInput(AgentInput):
5+
prompt: str
6+
47
class ExampleAgent(Agent):
8+
input_schema = ExampleInput
9+
510
def process(self, task_data):
611
client = OpenAI()
712
response = client.chat.completions.create(

0 commit comments

Comments
 (0)