Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"assisted-service-client>=2.41.0.post3",
"fastmcp>=2.8.0",
"fastmcp>=2.12.3",
"netaddr>=1.3.0",
"requests>=2.32.3",
"retry>=0.9.2",
Expand Down
66 changes: 40 additions & 26 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
import uvicorn
from pydantic import Field
from assisted_service_client import models
from mcp.server.fastmcp import FastMCP

from fastmcp import FastMCP
from fastmcp.server.dependencies import get_http_headers
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
from fastmcp.tools.tool import ToolResult

from metrics import metrics, track_tool_usage, initiate_metrics
from service_client import InventoryClient
Expand All @@ -29,10 +33,7 @@
)


transport_type = os.environ.get("TRANSPORT", "sse").lower()
use_stateless_http = transport_type == "streamable-http"

mcp = FastMCP("AssistedService", host="0.0.0.0", stateless_http=use_stateless_http)
mcp = FastMCP("AssistedService")


def format_presigned_url(presigned_url: models.PresignedUrl) -> dict[str, Any]:
Expand Down Expand Up @@ -87,12 +88,11 @@ def get_offline_token() -> str:
log.debug("Found offline token in environment variables")
return token

request = mcp.get_context().request_context.request
if request is not None:
token = request.headers.get("OCM-Offline-Token")
if token:
log.debug("Found offline token in request headers")
return token
headers = get_http_headers(include_all=True)
token = headers.get("ocm-offline-token")
if token:
log.debug("Found offline token in request headers")
return token

log.error("No offline token found in environment or request headers")
raise RuntimeError("No offline token found in environment or request headers")
Expand All @@ -114,14 +114,13 @@ def get_access_token() -> str:
"""
log.debug("Attempting to retrieve access token")
# First try to get the token from the authorization header:
request = mcp.get_context().request_context.request
if request is not None:
header = request.headers.get("Authorization")
if header is not None:
parts = header.split()
if len(parts) == 2 and parts[0].lower() == "bearer":
log.debug("Found access token in authorization header")
return parts[1]
headers = get_http_headers(include_all=True)
header = headers.get("authorization")
if header is not None:
parts = header.split()
if len(parts) == 2 and parts[0].lower() == "bearer":
log.debug("Found access token in authorization header")
return parts[1]

# Now try to get the offline token, and generate a new access token from it:
log.debug("Generating new access token from offline token")
Expand Down Expand Up @@ -939,18 +938,33 @@ def list_tools() -> list[str]:
"""List all MCP tools."""

async def mcp_list_tools() -> list[str]:
return [t.name for t in await mcp.list_tools()]
return list((await mcp.get_tools()).keys())

return asyncio.run(mcp_list_tools())


class StripSessionIDMiddleware(Middleware):
"""Middleware to fix llama-stack behavior

For some reason it injects session_id as an arg to every tool call
and it isn't configurable to disable.
"""

async def on_call_tool(
self,
context: MiddlewareContext,
call_next: CallNext,
) -> ToolResult:
"""Strip session_id from tool calls"""
if "session_id" in context.message.arguments:
del context.message.arguments["session_id"]

return await call_next(context)


if __name__ == "__main__":
if transport_type == "streamable-http":
app = mcp.streamable_http_app()
log.info("Using StreamableHTTP transport (stateless)")
else:
app = mcp.sse_app()
log.info("Using SSE transport (stateful)")
mcp.add_middleware(StripSessionIDMiddleware())
app = mcp.http_app(transport="streamable-http")

initiate_metrics(list_tools())
app.add_route("/metrics", metrics)
Expand Down
Loading