Skip to content

Commit

Permalink
feat(integration + agents-api): removed browserbase changes + added hotfix for groq models
Browse files Browse the repository at this point in the history
  • Loading branch information
Vedantsahai18 committed Nov 27, 2024
1 parent 6207d58 commit 181dc96
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 64 deletions.
12 changes: 9 additions & 3 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,16 @@ async def prompt_step(context: StepContext) -> StepOutcome:
},
}
formatted_tools.append(tool)

# For non-Claude models, we don't need to send tools
# FIXME: Enable formatted_tools once format-tools PR is merged.
if not is_claude_model:
# HOTFIX: for groq calls, litellm expects tool_calls_id not to be in the messages
formatted_tools = None

# HOTFIX: for groq calls, litellm expects tool_calls_id not to be in the messages
# FIXME: This is a temporary fix. We need to update the agent-api to use the new tool calling format
# FIXME: Enable formatted_tools once format-tools PR is merged.
is_groq_model = agent_model.lower().startswith("llama-3.1")
if is_groq_model:
prompt = [
{
k: v
Expand All @@ -178,7 +185,6 @@ async def prompt_step(context: StepContext) -> StepOutcome:
}
for message in prompt
]
formatted_tools = None

# Use litellm for other models
completion_data: dict = {
Expand Down
12 changes: 8 additions & 4 deletions agents-api/agents_api/routers/sessions/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,15 @@ async def chat(
}
formatted_tools.append(tool)

# If not using Claude model,

# If not using Claude model
# FIXME: Enable formatted_tools once format-tools PR is merged.
if not is_claude_model:
# HOTFIX: for groq calls, litellm expects tool_calls_id not to be in the messages
formatted_tools = None

# HOTFIX: for groq calls, litellm expects tool_calls_id not to be in the messages
# FIXME: This is a temporary fix. We need to update the agent-api to use the new tool calling format
is_groq_model = settings["model"].lower().startswith("llama-3.1")
if is_groq_model:
messages = [
{
k: v
Expand All @@ -180,7 +185,6 @@ async def chat(
}
for message in messages
]
formatted_tools = None

# Use litellm for other models
model_response = await litellm.acompletion(
Expand Down
1 change: 0 additions & 1 deletion integrations-service/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ x--shared-environment: &shared-environment
CLOUDINARY_API_SECRET: ${CLOUDINARY_API_SECRET}
CLOUDINARY_CLOUD_NAME: ${CLOUDINARY_CLOUD_NAME}
MAILGUN_PASSWORD: ${MAILGUN_PASSWORD}
MAX_POOL_SIZE: ${MAX_POOL_SIZE}

services:
integrations:
Expand Down
1 change: 0 additions & 1 deletion integrations-service/integrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,3 @@
cloudinary_api_secret = env.str("CLOUDINARY_API_SECRET", default=None)
cloudinary_cloud_name = env.str("CLOUDINARY_CLOUD_NAME", default=None)
mailgun_password = env.str("MAILGUN_PASSWORD", default=None)
max_pool_size = env.int("MAX_POOL_SIZE", default=5)
45 changes: 19 additions & 26 deletions integrations-service/integrations/utils/integrations/browserbase.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import asyncio
import os
import tempfile
from functools import lru_cache
from typing import Optional

import httpx
from beartype import beartype
Expand Down Expand Up @@ -40,19 +37,21 @@
from ...models.browserbase import BrowserbaseExtensionOutput


# Cache client instances
@lru_cache(maxsize=50)
def get_browserbase_client(
api_key: str,
project_id: str,
api_url: Optional[str] = None,
connect_url: Optional[str] = None,
) -> Browserbase:
def get_browserbase_client(setup: BrowserbaseSetup) -> Browserbase:
setup.api_key = (
browserbase_api_key if setup.api_key == "DEMO_API_KEY" else setup.api_key
)
setup.project_id = (
browserbase_project_id
if setup.project_id == "DEMO_PROJECT_ID"
else setup.project_id
)

return Browserbase(
api_key=api_key,
project_id=project_id,
api_url=api_url,
connect_url=connect_url,
api_key=setup.api_key,
project_id=setup.project_id,
api_url=setup.api_url,
connect_url=setup.connect_url,
)


Expand All @@ -65,18 +64,12 @@ def get_browserbase_client(
async def list_sessions(
setup: BrowserbaseSetup, arguments: BrowserbaseListSessionsArguments
) -> BrowserbaseListSessionsOutput:
client = get_browserbase_client(
api_key=setup.api_key
if setup.api_key != "DEMO_API_KEY"
else browserbase_api_key,
project_id=setup.project_id
if setup.project_id != "DEMO_PROJECT_ID"
else browserbase_project_id,
api_url=setup.api_url,
connect_url=setup.connect_url,
)
client = get_browserbase_client(setup)

# FIXME: Implement status filter
# Run the list_sessions method
sessions: list[Session] = client.list_sessions()

sessions = await asyncio.to_thread(client.list_sessions)
return BrowserbaseListSessionsOutput(sessions=sessions)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,43 +13,16 @@
Keyboard,
Mouse,
Page,
PlaywrightContextManager,
async_playwright,
)
from tenacity import retry, stop_after_attempt, wait_exponential

from ...autogen.Tools import RemoteBrowserArguments, RemoteBrowserSetup
from ...env import max_pool_size
from ...models import RemoteBrowserOutput

CURSOR_PATH = Path(__file__).parent / "assets" / "cursor-small.png"

# Add connection pooling
_browser_pool = {}


async def cleanup_browser_pool():
"""Remove disconnected browsers from the pool."""
for connect_url, browser in list(_browser_pool.items()):
if not browser.is_connected():
await browser.close() # Ensure the browser is closed
del _browser_pool[connect_url]


async def get_browser(connect_url: str):
await cleanup_browser_pool() # Clean up before getting a new browser

if connect_url not in _browser_pool:
if len(_browser_pool) >= max_pool_size:
# Remove the oldest entry to make space
oldest_url = next(iter(_browser_pool))
await _browser_pool[oldest_url].close()
del _browser_pool[oldest_url]

p = await async_playwright().start()
_browser_pool[connect_url] = await p.chromium.connect_over_cdp(connect_url)

return _browser_pool[connect_url]


class PlaywrightActions:
"""Class to handle browser automation actions using Playwright."""
Expand Down Expand Up @@ -407,10 +380,12 @@ async def perform_action(
async def perform_action(
setup: RemoteBrowserSetup, arguments: RemoteBrowserArguments
) -> RemoteBrowserOutput:
p: PlaywrightContextManager = await async_playwright().start()
connect_url = setup.connect_url if setup.connect_url else arguments.connect_url
browser = await get_browser(connect_url)
browser = await p.chromium.connect_over_cdp(connect_url)

automation = PlaywrightActions(browser, width=setup.width, height=setup.height)

await automation.initialize()

return await automation.perform_action(
Expand Down

0 comments on commit 181dc96

Please sign in to comment.