diff --git a/README.md b/README.md
index 7466706..db04556 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@
**Fara-7B** is Microsoft's first **agentic small language model (SLM)** designed specifically for computer use. With only 7 billion parameters, Fara-7B is an ultra-compact Computer Use Agent (CUA) that achieves state-of-the-art performance within its size class and is competitive with larger, more resource-intensive agentic systems.
-Try Fara-7B locally as follows (see [Installation](##Installation) for detailed instructions) or via Magentic-UI:
+Try Fara-7B locally as follows (see [Installation](#Installation) for detailed instructions on Windows ) or via Magentic-UI:
```bash
# 1. Clone repository
@@ -44,7 +44,7 @@ To try Fara-7B inside Magentic-UI, please follow the instructions here [Magentic
Notes:
-- If you're using Windows, we highly recommend using WSL2 (Windows Subsystem for Linux).
+- If you're using Windows, we highly recommend using WSL2 (Windows Subsystem for Linux). Please the Windows instructions in the [Installation](#Installation) section.
- You might need to do `--tensor-parallel-size 2` with vllm command if you run out of memory
@@ -156,27 +156,45 @@ Our evaluation setup leverages:
---
-## Installation
+# Installation
-Install the package using either UV or pip:
-```bash
-uv sync --all-extras
-```
+## Linux
+
+The following instructions are for Linux systems, see the Windows section below for Windows instructions.
-or
+Install the package using pip and set up the environment with Playwright:
```bash
-pip install -e .
+# 1. Clone repository
+git clone https://github.com/microsoft/fara.git
+cd fara
+
+# 2. Setup environment
+python3 -m venv .venv
+source .venv/bin/activate
+pip install -e .[vllm]
+playwright install
```
-Then install Playwright browsers:
+Note: If you plan on hosting with Azure Foundry only, you can skip the `[vllm]` and just do `pip install -e .`
+
+
+## Windows
+
+For Windows, we highly recommend using WSL2 (Windows Subsystem for Linux) to provide a Linux-like environment. However, if you prefer to run natively on Windows, follow these steps:
```bash
-playwright install
-```
+# 1. Clone repository
+git clone https://github.com/microsoft/fara.git
+cd fara
----
+# 2. Setup environment
+python3 -m venv .venv
+.venv\Scripts\activate
+pip install -e .
+python3 -m playwright install
+```
## Hosting the Model
@@ -189,11 +207,10 @@ Deploy Fara-7B on [Azure Foundry](https://ai.azure.com/explore/models/Fara-7B/ve
**Setup:**
1. Deploy the Fara-7B model on Azure Foundry and obtain your endpoint URL and API key
-2. Add your endpoint details to the existing `endpoint_configs/` directory (example configs are already provided):
-```bash
-# Edit one of the existing config files or create a new one
-# endpoint_configs/fara-7b-hosting-ansrz.json (example format):
+Then create a endpoint configuration JSON file (e.g., `azure_foundry_config.json`):
+
+```json
{
"model": "Fara-7B",
"base_url": "https://your-endpoint.inference.ml.azure.com/",
@@ -201,62 +218,55 @@ Deploy Fara-7B on [Azure Foundry](https://ai.azure.com/explore/models/Fara-7B/ve
}
```
-3. Run the Fara agent:
+Then you can run Fara-7B using this endpoint configuration.
+
+2. Run the Fara agent:
```bash
-fara-cli --task "how many pages does wikipedia have" --start_page "https://www.bing.com"
+fara-cli --task "how many pages does wikipedia have" --endpoint_config azure_foundry_config.json [--headful]
```
-That's it! No GPU or model downloads required.
+Note: you can also specify the endpoint config with the args `--base_url [your_base_url] --api_key [your_api_key] --model [your_model_name]` instead of using a config JSON file.
-### Self-hosting with VLLM
-
-If you have access to GPU resources, you can self-host Fara-7B using VLLM. This requires a GPU machine with sufficient VRAM.
-
-All that is required is to run the following command to start the VLLM server:
+Note: If you see an error that the `fara-cli` command is not found, then try:
```bash
-vllm serve "microsoft/Fara-7B" --port 5000 --dtype auto
+python -m fara.run_fara --task "what is the weather in new york now"
```
-### Testing the Fara Agent
+That's it! No GPU or model downloads required.
-Run the test script to see Fara in action:
+### Self-hosting with vLLM or LM Studio / Ollama
+
+**If you have access to GPU resources, you can self-host Fara-7B using vLLM. This requires a GPU machine with sufficient VRAM (e.g., 24GB or more).**
+
+Only on Linux: all that is required is to run the following command to start the VLLM server:
```bash
-fara-cli --task "how many pages does wikipedia have" --start_page "https://www.bing.com" --endpoint_config endpoint_configs/azure_foundry_config.json [--headful] [--downloads_folder "/path/to/downloads"] [--save_screenshots] [--max_rounds 100] [--browserbase]
+vllm serve "microsoft/Fara-7B" --port 5000 --dtype auto
```
+For quantized models or lower VRAM GPUs, please see [Fara-7B GGUF on HuggingFace](https://huggingface.co/bartowski/microsoft_Fara-7B-GGUF).
-In self-hosting scenario the `endpoint_config` points to `endpoint_configs/vllm_config.json` from the VLLM server above.
+** For Windows/Mac, vLLM is not natively supported. You can use WSL2 on Windows to run the above command or LM Studio / Ollama as described below. **
-If you set `--browserbase`, export environment variables for the API key and project ID.
+Otherwise, you can use [LM Studio](https://lmstudio.ai/) or [Ollama](https://ollama.com/) to host the model locally. We currently recommend the following GGUF versions of our models [Fara-7B GGUF on HuggingFace](https://huggingface.co/bartowski/microsoft_Fara-7B-GGUF) for use with LM Studio or Ollama. Select the largest model that fits your GPU. Please ensure that context length is set to at least 15000 tokens and temperature to 0 for best results.
-#### Expected Output
-
-```
-Initializing Browser...
-Browser Running... Starting Fara Agent...
-##########################################
-Task: how many pages does wikipedia have
-##########################################
-Running Fara...
+Then you can run Fara-7B pointing to your local server:
+Run the test script to see Fara in action:
-Thought #1: To find the current number of Wikipedia pages, I'll search for the latest Wikipedia page count statistics.
-Action #1: executing tool 'web_search' with arguments {"action": "web_search", "query": "Wikipedia total number of articles"}
-Observation#1: I typed 'Wikipedia total number of articles' into the browser search bar.
+```bash
+fara-cli --task "what is the weather in new york now"
+```
-Thought #2: Wikipedia currently has 7,095,446 articles.
-Action #2: executing tool 'terminate' with arguments {"action": "terminate", "status": "success"}
-Observation#2: Wikipedia currently has 7,095,446 articles.
+If you didn't use vLLM to host, please specify the correct `--base_url [your_base_url] --api_key [your_api_key] --model [your_model_name]`
-Final Answer: Wikipedia currently has 7,095,446 articles.
+If you see an error that the `fara-cli` command is not found, then try:
-Enter another task (or press Enter to exit):
+```bash
+python -m fara.run_fara --task "what is the weather in new york now"
```
----
-
# Reproducibility
We provide a framework in `webeval/` to reproduce our results on WebVoyager and OnlineMind2Web.
diff --git a/pyproject.toml b/pyproject.toml
index 8e858d9..6566c44 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,14 +32,20 @@ dependencies = [
"pyyaml",
"jsonschema",
"browserbase",
- "vllm>=0.10.0"
]
+
+
+
[project.urls]
Homepage = "https://github.com/microsoft/fara"
Repository = "https://github.com/microsoft/fara"
Issues = "https://github.com/microsoft/fara/issues"
+[project.optional-dependencies]
+vllm = ["vllm>=0.10.0"]
+lmstudio = ["lmstudio"]
+ollama = ["ollama"]
[project.scripts]
fara-cli = "fara.run_fara:main"
diff --git a/src/fara/browser/browser_bb.py b/src/fara/browser/browser_bb.py
index d0303ca..d4a2e6f 100644
--- a/src/fara/browser/browser_bb.py
+++ b/src/fara/browser/browser_bb.py
@@ -6,7 +6,7 @@
import subprocess
import time
from typing import Any, Dict, Optional, Callable
-
+import platform
import browserbase
from browserbase import Browserbase
from playwright.async_api import (
@@ -48,7 +48,7 @@ def __init__(
self.single_tab_mode = single_tab_mode
self.use_browser_base = use_browser_base
self.logger = logger or logging.getLogger("browser_manager")
-
+ self.is_linux = platform.system() == "Linux"
self._viewport_height = viewport_height
self._viewport_width = viewport_width
@@ -194,7 +194,8 @@ async def delayed_resume():
async def _init_regular_browser(self, channel: str = "chromium") -> None:
"""Initialize regular browser according to the specified channel."""
- if not self.headless:
+ if not self.headless and self.is_linux:
+ print("STARTING XVFB")
self.start_xvfb()
launch_args: Dict[str, Any] = {"headless": self.headless}
@@ -218,7 +219,7 @@ async def _init_regular_browser(self, channel: str = "chromium") -> None:
async def _init_persistent_browser(self) -> None:
"""Initialize persistent browser with data directory."""
- if not self.headless:
+ if not self.headless and self.is_linux:
self.start_xvfb()
launch_args: Dict[str, Any] = {"headless": self.headless}
diff --git a/src/fara/fara_agent.py b/src/fara/fara_agent.py
index 09d52b5..ab6b953 100644
--- a/src/fara/fara_agent.py
+++ b/src/fara/fara_agent.py
@@ -15,7 +15,7 @@
import asyncio
from .browser.playwright_controller import PlaywrightController
from ._prompts import get_computer_use_system_prompt
-from .types import (
+from .fara_types import (
LLMMessage,
SystemMessage,
UserMessage,
@@ -379,15 +379,20 @@ async def run(self, user_message: str) -> Tuple:
thoughts, action_dict = self._parse_thoughts_and_action(raw_response)
action_args = action_dict.get("arguments", {})
action = action_args["action"]
- self.logger.info(f"\nThought #{i+1}: {thoughts}\nAction #{i+1}: executing tool '{action}' with arguments {json.dumps(action_args)}")
-
+ self.logger.debug(
+ f"\nThought #{i+1}: {thoughts}\nAction #{i+1}: executing tool '{action}' with arguments {json.dumps(action_args)}"
+ )
+ print(
+ f"\nThought #{i+1}: {thoughts}\nAction #{i+1}: executing tool '{action}' with arguments {json.dumps(action_args)}"
+ )
(
is_stop_action,
new_screenshot,
action_description,
) = await self.execute_action(function_call)
all_observations.append(action_description)
- self.logger.info(f"Observation#{i+1}: {action_description}")
+ self.logger.debug(f"Observation#{i+1}: {action_description}")
+ print(f"Observation#{i+1}: {action_description}")
if is_stop_action:
final_answer = thoughts
break
@@ -564,7 +569,7 @@ async def execute_action(
elif args["action"] == "pause_and_memorize_fact":
fact = str(args.get("fact"))
self._facts.append(fact)
- action_description= f"I memorized the following fact: {fact}"
+ action_description = f"I memorized the following fact: {fact}"
elif args["action"] == "stop" or args["action"] == "terminate":
action_description = args.get("thoughts")
is_stop_action = True
diff --git a/src/fara/types.py b/src/fara/fara_types.py
similarity index 100%
rename from src/fara/types.py
rename to src/fara/fara_types.py
diff --git a/src/fara/run_fara.py b/src/fara/run_fara.py
index 3e1e207..9525828 100755
--- a/src/fara/run_fara.py
+++ b/src/fara/run_fara.py
@@ -1,8 +1,8 @@
import asyncio
import argparse
import os
-from fara import FaraAgent
-from fara.browser.browser_bb import BrowserBB
+from .fara_agent import FaraAgent
+from .browser.browser_bb import BrowserBB
import logging
from typing import Dict
from pathlib import Path
@@ -11,8 +11,8 @@
# Configure logging to only show logs from fara.fara_agent
logging.basicConfig(
- level=logging.CRITICAL,
- format="%(message)s",
+ level=logging.CRITICAL,
+ format="%(message)s",
)
# Enable INFO level only for fara.fara_agent
@@ -159,21 +159,51 @@ def main():
default=None,
help="Path to the endpoint configuration JSON file. By default, tries local vllm on 5000 port",
)
+ parser.add_argument(
+ "--api_key",
+ type=str,
+ default=None,
+ help="API key for the model endpoint (overrides endpoint_config)",
+ )
+ parser.add_argument(
+ "--base_url",
+ type=str,
+ default=None,
+ help="Base URL for the model endpoint (overrides endpoint_config)",
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default=None,
+ help="Model name to use (overrides endpoint_config)",
+ )
args = parser.parse_args()
if args.browserbase:
- assert os.environ.get("BROWSERBASE_API_KEY"), (
- "BROWSERBASE_API_KEY environment variable must be set to use browserbase"
- )
- assert os.environ.get("BROWSERBASE_PROJECT_ID"), (
- "BROWSERBASE_API_KEY and BROWSERBASE_PROJECT_ID environment variables must be set to use browserbase"
- )
+ assert os.environ.get(
+ "BROWSERBASE_API_KEY"
+ ), "BROWSERBASE_API_KEY environment variable must be set to use browserbase"
+ assert os.environ.get(
+ "BROWSERBASE_PROJECT_ID"
+ ), "BROWSERBASE_API_KEY and BROWSERBASE_PROJECT_ID environment variables must be set to use browserbase"
endpoint_config = DEFAULT_ENDPOINT_CONFIG
if args.endpoint_config:
with open(args.endpoint_config, "r") as f:
endpoint_config = json.load(f)
+ assert (
+ "api_key" in endpoint_config
+ and "base_url" in endpoint_config
+ and "model" in endpoint_config
+ ), "endpoint_config file must contain api_key, base_url, and model fields"
+ # Override with command-line arguments if provided
+ if args.api_key:
+ endpoint_config["api_key"] = args.api_key
+ if args.base_url:
+ endpoint_config["base_url"] = args.base_url
+ if args.model:
+ endpoint_config["model"] = args.model
asyncio.run(
run_fara_agent(
diff --git a/src/fara/vllm/az_vllm.py b/src/fara/vllm/az_vllm.py
index d34b9ac..01ceada 100644
--- a/src/fara/vllm/az_vllm.py
+++ b/src/fara/vllm/az_vllm.py
@@ -28,10 +28,15 @@
def _is_azure_blob_url(model_path: str) -> bool:
- return model_path.startswith(("https://", "http://")) and "blob.core.windows.net" in model_path
+ return (
+ model_path.startswith(("https://", "http://"))
+ and "blob.core.windows.net" in model_path
+ )
-def _download_model_from_hf(output_dir: Path, model_id: str = DEFAULT_HF_MODEL_ID) -> str:
+def _download_model_from_hf(
+ output_dir: Path, model_id: str = DEFAULT_HF_MODEL_ID
+) -> str:
"""Download model from HuggingFace Hub if not already present."""
if snapshot_download is None:
raise ImportError(
@@ -63,13 +68,15 @@ def _download_model_from_hf(output_dir: Path, model_id: str = DEFAULT_HF_MODEL_I
def _extract_model_name(model_url: str) -> str:
"""Extract model name from URL for consistent naming."""
- url_parts = model_url.rstrip('/').split('/')
+ url_parts = model_url.rstrip("/").split("/")
return url_parts[-1] if url_parts else model_url
def _cache_model(model_url: str) -> str:
if AzFolder is None:
- raise RuntimeError("Azure support not available. Install aztool or run without --cache.")
+ raise RuntimeError(
+ "Azure support not available. Install aztool or run without --cache."
+ )
cache_root = Path(args.cache_dir or os.path.expanduser("~/.cache/vllm_models"))
cache_root.mkdir(parents=True, exist_ok=True)
@@ -120,8 +127,18 @@ def _prepare_cached_model(model_url: str) -> str:
raise FileNotFoundError(f"Local model directory not found: {model_url}")
return str(model_path.resolve())
+
class AzVllm:
- def __init__(self, model_url, port, device_id, max_n_images, dtype='auto', enforce_eager=False, use_external_endpoint=False):
+ def __init__(
+ self,
+ model_url,
+ port,
+ device_id,
+ max_n_images,
+ dtype="auto",
+ enforce_eager=False,
+ use_external_endpoint=False,
+ ):
self.model_az = None
self.local_model_path = None
self.vllm = None
@@ -141,7 +158,9 @@ def __init__(self, model_url, port, device_id, max_n_images, dtype='auto', enfor
if not model_path.exists():
# Auto-download from HuggingFace if path doesn't exist
logging.warning(f"Local model directory not found: {model_url}")
- logging.info(f"Attempting to download {DEFAULT_HF_MODEL_ID} from HuggingFace...")
+ logging.info(
+ f"Attempting to download {DEFAULT_HF_MODEL_ID} from HuggingFace..."
+ )
self.local_model_path = _download_model_from_hf(model_path)
else:
self.local_model_path = str(model_path.resolve())
@@ -150,7 +169,7 @@ def __init__(self, model_url, port, device_id, max_n_images, dtype='auto', enfor
def __enter__(self):
# No-op if using external endpoint
if self.use_external_endpoint:
- print('Using external endpoint, skipping VLLM startup')
+ print("Using external endpoint, skipping VLLM startup")
return self
if self.model_az:
@@ -162,33 +181,35 @@ def __enter__(self):
for file in files:
print(f"\t{os.path.join(root, file)}")
self.vllm = VLLM(
- model_path = self.context.path,
- port = self.port,
- device_id = self.device_id,
- max_n_images = self.max_n_images,
- dtype = self.dtype,
- enforce_eager = self.enforce_eager
+ model_path=self.context.path,
+ port=self.port,
+ device_id=self.device_id,
+ max_n_images=self.max_n_images,
+ dtype=self.dtype,
+ enforce_eager=self.enforce_eager,
)
self.vllm.start()
- print('VLLM has started')
+ print("VLLM has started")
elif self.local_model_path:
- print(f"VLLM using on-disk model at path {self.local_model_path}, contents:")
+ print(
+ f"VLLM using on-disk model at path {self.local_model_path}, contents:"
+ )
### sometimes need to ls the directory or else huggingface will complain a config.json doesn't exist
for root, dirs, files in os.walk(self.local_model_path):
for file in files:
print(f"\t{os.path.join(root, file)}")
self.vllm = VLLM(
- model_path = self.local_model_path,
- port = self.port,
- device_id = self.device_id,
- max_n_images = self.max_n_images,
- dtype = self.dtype,
- enforce_eager = self.enforce_eager
+ model_path=self.local_model_path,
+ port=self.port,
+ device_id=self.device_id,
+ max_n_images=self.max_n_images,
+ dtype=self.dtype,
+ enforce_eager=self.enforce_eager,
)
self.vllm.start()
- print('VLLM has started')
+ print("VLLM has started")
return self
-
+
def __exit__(self, exc_type, exc_val, exc_tb):
if self.vllm:
if self.vllm and (self.vllm.status == Status.Running):
@@ -196,6 +217,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
if self.context:
self.context.unmount()
+
@asynccontextmanager
async def lifespan(app: FastAPI):
cached_vllm: Optional[VLLM] = None
@@ -214,17 +236,18 @@ async def lifespan(app: FastAPI):
device_id=args.device_id,
max_n_images=args.max_n_images,
dtype=args.dtype,
- enforce_eager=args.enforce_eager
+ enforce_eager=args.enforce_eager,
)
cached_vllm.start()
else:
az_vllm = AzVllm(
- model_url = args.model_url,
- port = args.vllm_port,
- device_id = args.device_id,
- max_n_images = args.max_n_images,
- dtype = args.dtype,
- enforce_eager = args.enforce_eager)
+ model_url=args.model_url,
+ port=args.vllm_port,
+ device_id=args.device_id,
+ max_n_images=args.max_n_images,
+ dtype=args.dtype,
+ enforce_eager=args.enforce_eager,
+ )
az_vllm.__enter__()
app.state.resolved_model_path = args.model_url
app.state.model_name = _extract_model_name(args.model_url)
@@ -239,7 +262,7 @@ async def lifespan(app: FastAPI):
app.state.model_name = None
-app = FastAPI(lifespan = lifespan)
+app = FastAPI(lifespan=lifespan)
@app.post("/v1/chat/completions")
@@ -247,21 +270,18 @@ async def post_v1_chat_completions(request: Request):
body = await request.body()
async with httpx.AsyncClient() as client:
resp = await client.post(
- f'http://localhost:{args.vllm_port}/v1/chat/completions',
+ f"http://localhost:{args.vllm_port}/v1/chat/completions",
content=body,
headers=dict(request.headers),
- timeout=None
+ timeout=None,
)
return Response(
- content=resp.content,
- status_code=resp.status_code,
- headers=resp.headers
+ content=resp.content, status_code=resp.status_code, headers=resp.headers
)
@app.get("/model")
async def get_model():
-
return {"model": _extract_model_name(args.model_url), "model_url": args.model_url}
@@ -271,12 +291,37 @@ async def get_model():
parser.add_argument("--port", type=int, default=5000, help="port")
parser.add_argument("--vllm_port", type=int, default=5001, help="vllm port")
parser.add_argument("--device_id", type=str, default="0", help="device id")
- parser.add_argument("--max_n_images", type=int, default=3, help="Maximum number of images to process")
- parser.add_argument('--dtype', type=str, choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], default='auto', help='Data type for VLLM model (default: auto)')
- parser.add_argument('--enforce_eager', action='store_true', help='Enforce eager execution mode for compatibility')
- parser.add_argument('--cache', action='store_true', help='Enable caching / local path serving instead of Azure mount')
- parser.add_argument('--cache_dir', type=str, default=None, help='Directory to cache downloaded models (default: ~/.cache/vllm_models)')
+ parser.add_argument(
+ "--max_n_images",
+ type=int,
+ default=3,
+ help="Maximum number of images to process",
+ )
+ parser.add_argument(
+ "--dtype",
+ type=str,
+ choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
+ default="auto",
+ help="Data type for VLLM model (default: auto)",
+ )
+ parser.add_argument(
+ "--enforce_eager",
+ action="store_true",
+ help="Enforce eager execution mode for compatibility",
+ )
+ parser.add_argument(
+ "--cache",
+ action="store_true",
+ help="Enable caching / local path serving instead of Azure mount",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="Directory to cache downloaded models (default: ~/.cache/vllm_models)",
+ )
args = parser.parse_args()
import uvicorn
- uvicorn.run(app, host="0.0.0.0", port=args.port)
\ No newline at end of file
+
+ uvicorn.run(app, host="0.0.0.0", port=args.port)
diff --git a/src/fara/vllm/vllm_facade.py b/src/fara/vllm/vllm_facade.py
index 73cb8d4..ae745e6 100644
--- a/src/fara/vllm/vllm_facade.py
+++ b/src/fara/vllm/vllm_facade.py
@@ -5,32 +5,39 @@
import threading
import time
+
class Status(Enum):
NotStarted = 0
Running = 1
Stopped = 2
+
class VLLM:
- cmd_template = ' '.join([
- "python -O -u -m vllm.entrypoints.openai.api_server",
- "--host={host}",
- "--port={port}",
- "--model={model_dir}",
- "--served-model-name {model_name}",
- "--tensor-parallel-size {tensor_parallel_size}",
- "--gpu-memory-utilization 0.95",
- "--trust-remote-code",
- "--dtype {dtype}"
- ])
- def __init__(self,
- model_path,
- max_n_images,
- device_id = "0",
- host = "0.0.0.0",
- port = 5000,
- model_name = "gpt-4o-mini-2024-07-18",
- dtype = "auto",
- enforce_eager = False):
+ cmd_template = " ".join(
+ [
+ "python -O -u -m vllm.entrypoints.openai.api_server",
+ "--host={host}",
+ "--port={port}",
+ "--model={model_dir}",
+ "--served-model-name {model_name}",
+ "--tensor-parallel-size {tensor_parallel_size}",
+ "--gpu-memory-utilization 0.95",
+ "--trust-remote-code",
+ "--dtype {dtype}",
+ ]
+ )
+
+ def __init__(
+ self,
+ model_path,
+ max_n_images,
+ device_id="0",
+ host="0.0.0.0",
+ port=5000,
+ model_name="gpt-4o-mini-2024-07-18",
+ dtype="auto",
+ enforce_eager=False,
+ ):
self.model_path = model_path
self.device_id = device_id
self.host = host
@@ -43,10 +50,12 @@ def __init__(self,
# new versions of vllm require dictionary-like arguments for this
# see https://docs.vllm.ai/en/latest/configuration/engine_args.html#multimodalconfig
self.cmd += f" --limit-mm-per-prompt.image {self.max_n_images}"
- if enforce_eager: # Most helpful for float32 cases when attention backends are incompatible
+ if (
+ enforce_eager
+ ): # Most helpful for float32 cases when attention backends are incompatible
self.cmd += " --enforce-eager"
self.model_name = model_name
- self.tensor_parallel_size = len(str(device_id).split(','))
+ self.tensor_parallel_size = len(str(device_id).split(","))
self.status = Status.NotStarted
self.process = None
self.logs = []
@@ -57,12 +66,13 @@ def endpoint(self):
def start(self):
def _drain(pipe):
- for line in iter(pipe.readline, ''):
+ for line in iter(pipe.readline, ""):
self.logs.append(line)
- print(line, end='')
+ print(line, end="")
+
env = os.environ.copy()
- env['CUDA_VISIBLE_DEVICES'] = self.device_id
- env['NCCL_DEBUG'] = "TRACE"
+ env["CUDA_VISIBLE_DEVICES"] = self.device_id
+ env["NCCL_DEBUG"] = "TRACE"
self.process = subprocess.Popen(
self.cmd.format(
host=self.host,
@@ -70,13 +80,13 @@ def _drain(pipe):
model_dir=self.model_path,
model_name=self.model_name,
tensor_parallel_size=self.tensor_parallel_size,
- dtype=self.dtype
+ dtype=self.dtype,
).split(),
- stdout = subprocess.PIPE,
- stderr = subprocess.STDOUT,
- text = True,
- shell = False,
- env = env
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ text=True,
+ shell=False,
+ env=env,
)
t = threading.Thread(target=_drain, args=(self.process.stdout,), daemon=True)
t.start()
@@ -88,9 +98,9 @@ def _drain(pipe):
if "Application startup complete." in line:
logging.info("VLLM process started successfully.")
self.status = Status.Running
- return True
-
+ return True
+
def stop(self):
if self.process:
self.process.terminate()
- self.status = Status.Stopped
\ No newline at end of file
+ self.status = Status.Stopped