AWS Bedrock support #357

Open · wants to merge 3 commits into base: main
5 changes: 5 additions & 0 deletions backend/config.py
@@ -5,6 +5,11 @@

ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", None)

# AWS
AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY", None)
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY", None)
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME", "us-west-2")

# Debugging-related

SHOULD_MOCK_AI_RESPONSE = bool(os.environ.get("MOCK", False))
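For local runs, the three new variables can be set in backend/.env before starting the server. Below is a minimal sketch of a startup check (a hypothetical helper, not part of this PR) that fails fast when Bedrock is only half-configured. Note that AWS_ACCESS_KEY is this project's own variable name; boto3's default credential chain reads AWS_ACCESS_KEY_ID instead, which is why the keys are passed to the client explicitly later in this PR.

import os

def bedrock_credentials_configured() -> bool:
    # Both halves of the credential pair must be present for Bedrock calls;
    # AWS_ACCESS_KEY is project-specific (boto3 itself looks for AWS_ACCESS_KEY_ID).
    access_key = os.environ.get("AWS_ACCESS_KEY")
    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
    return bool(access_key and secret_key)

if __name__ == "__main__":
    print("Bedrock configured:", bedrock_credentials_configured())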
34 changes: 25 additions & 9 deletions backend/evals/core.py
@@ -1,7 +1,11 @@
import os

from config import AWS_ACCESS_KEY, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME
from config import ANTHROPIC_API_KEY

from llm import Llm, stream_claude_response, stream_openai_response
from llm import (
Llm,
stream_claude_response,
stream_claude_response_aws_bedrock,
stream_openai_response,
)
from prompts import assemble_prompt
from prompts.types import Stack

@@ -10,20 +14,32 @@ async def generate_code_core(image_url: str, stack: Stack, model: Llm) -> str:
prompt_messages = assemble_prompt(image_url, stack)
openai_api_key = os.environ.get("OPENAI_API_KEY")
anthropic_api_key = ANTHROPIC_API_KEY
aws_access_key = AWS_ACCESS_KEY
aws_secret_access_key = AWS_SECRET_ACCESS_KEY
aws_region_name = AWS_REGION_NAME
openai_base_url = None

async def process_chunk(content: str):
pass

if model == Llm.CLAUDE_3_SONNET:
if not anthropic_api_key:
raise Exception("Anthropic API key not found")

completion = await stream_claude_response(
prompt_messages,
api_key=anthropic_api_key,
callback=lambda x: process_chunk(x),
)
if not anthropic_api_key and not (aws_access_key and aws_secret_access_key):
raise Exception("No Anthropic API key or complete AWS credentials found")

if anthropic_api_key:
completion = await stream_claude_response(
prompt_messages,
api_key=anthropic_api_key,
callback=lambda x: process_chunk(x),
)
else:
completion = await stream_claude_response_aws_bedrock(
prompt_messages,
access_key=aws_access_key,
secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
callback=lambda x: process_chunk(x),
)
else:
if not openai_api_key:
raise Exception("OpenAI API key not found")
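With either credential set in place, the eval entry point now picks the backend transparently. A minimal usage sketch, run from the backend directory (the image data URL is illustrative, and "html_tailwind" assumes that value exists in prompts.types.Stack):

import asyncio

from evals.core import generate_code_core
from llm import Llm

async def main() -> None:
    html = await generate_code_core(
        image_url="data:image/png;base64,iVBORw0KGgo=",  # illustrative screenshot
        stack="html_tailwind",  # assumed Stack value
        model=Llm.CLAUDE_3_SONNET,
    )
    print(html[:200])

asyncio.run(main())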
187 changes: 186 additions & 1 deletion backend/llm.py
@@ -5,7 +5,10 @@
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk
from config import IS_DEBUG_ENABLED
from debug.DebugFileWriter import DebugFileWriter

import json
import boto3
from typing import List
from botocore.exceptions import ClientError
from utils import pprint_prompt


@@ -15,6 +18,7 @@ class Llm(Enum):
GPT_4_TURBO_2024_04_09 = "gpt-4-turbo-2024-04-09"
GPT_4O_2024_05_13 = "gpt-4o-2024-05-13"
CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
CLAUDE_3_SONNET_BEDROCK = "anthropic.claude-3-sonnet-20240229-v1:0"
CLAUDE_3_OPUS = "claude-3-opus-20240229"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"

@@ -128,6 +132,128 @@ async def stream_claude_response(

return response.content[0].text

def initialize_bedrock_client(access_key: str, secret_access_key: str, aws_region_name: str):
try:
# Initialize the Bedrock Runtime client
bedrock_runtime = boto3.client(
service_name='bedrock-runtime',
aws_access_key_id=access_key,
aws_secret_access_key=secret_access_key,
region_name=aws_region_name,
)
return bedrock_runtime
except ClientError as err:
message = err.response["Error"]["Message"]
print(f"A client error occurred: {message}")
raise
except Exception as err:
print(f"An unexpected error occurred: {err}")
raise

async def stream_bedrock_response(
bedrock_runtime,
messages: List[dict],
system_prompt: str,
model_id: str,
max_tokens: int,
content_type: str,
accept: str,
temperature: float,
callback: Callable[[str], Awaitable[None]],
) -> str:
try:
# Prepare the request body
body = json.dumps({
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": max_tokens,
"messages": messages,
"system": system_prompt,
"temperature": temperature
})

# Invoke the Bedrock Runtime API with response stream
response = bedrock_runtime.invoke_model_with_response_stream(
body=body,
modelId=model_id,
accept=accept,
contentType=content_type,
)
stream = response.get("body")

# Stream the response
final_message = ""
if stream:
for event in stream:
chunk = event.get("chunk")
if chunk:
data = chunk.get("bytes").decode()
chunk_obj = json.loads(data)
if chunk_obj["type"] == "content_block_delta":
text = chunk_obj["delta"]["text"]
await callback(text)
final_message += text

return final_message

except ClientError as err:
message = err.response["Error"]["Message"]
print(f"A client error occurred: {message}")
raise
except Exception as err:
print(f"An unexpected error occurred: {err}")
raise

async def stream_claude_response_aws_bedrock(
messages: List[dict],
access_key: str,
secret_access_key: str,
aws_region_name: str,
callback: Callable[[str], Awaitable[None]],
) -> str:
bedrock_runtime = initialize_bedrock_client(access_key, secret_access_key, aws_region_name)

# Set model parameters
model_id = Llm.CLAUDE_3_SONNET_BEDROCK.value
max_tokens = 4096
content_type = 'application/json'
accept = '*/*'
temperature = 0.0

# Translate OpenAI messages to Claude messages
system_prompt = cast(str, messages[0].get("content"))
claude_messages = [dict(message) for message in messages[1:]]
for message in claude_messages:
if not isinstance(message["content"], list):
continue

for content in message["content"]: # type: ignore
if content["type"] == "image_url":
content["type"] = "image"

# Extract base64 data and media type from data URL
# Example base64 data URL: data:image/png;base64,iVBOR...
image_data_url = cast(str, content["image_url"]["url"])
media_type = image_data_url.split(";")[0].split(":")[1]
base64_data = image_data_url.split(",")[1]

# Remove OpenAI parameter
del content["image_url"]

content["source"] = {
"type": "base64",
"media_type": media_type,
"data": base64_data,
}

return await stream_bedrock_response(
bedrock_runtime,
claude_messages,
system_prompt,
model_id,
max_tokens,
content_type,
accept,
temperature,
callback,
)
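A minimal usage sketch for the new entry point (placeholder credentials; the image entry shows the OpenAI-style shape that the translation loop above converts into a Claude base64 source block):

import asyncio

from llm import stream_claude_response_aws_bedrock

async def main() -> None:
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this screenshot."},
                {
                    "type": "image_url",  # rewritten to a base64 "source" block above
                    "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="},  # illustrative
                },
            ],
        },
    ]

    async def on_chunk(text: str) -> None:
        print(text, end="", flush=True)

    completion = await stream_claude_response_aws_bedrock(
        messages,
        access_key="AKIA...",  # placeholder
        secret_access_key="...",  # placeholder
        aws_region_name="us-west-2",
        callback=on_chunk,
    )
    print("\nLength:", len(completion))

asyncio.run(main())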

async def stream_claude_response_native(
system_prompt: str,
@@ -216,3 +342,62 @@ async def stream_claude_response_native(
raise Exception("No HTML response found in AI response")
else:
return response.content[0].text

async def stream_claude_response_native_aws_bedrock(
system_prompt: str,
messages: list[Any],
access_key: str,
secret_access_key: str,
aws_region_name: str,
callback: Callable[[str], Awaitable[None]],
include_thinking: bool = False,
model: Llm = Llm.CLAUDE_3_SONNET_BEDROCK,
) -> str:
bedrock_runtime = initialize_bedrock_client(access_key, secret_access_key, aws_region_name)

# Set model parameters
# Note: the model argument is not honored yet; the Sonnet Bedrock model ID
# is always used, even when a caller requests Claude 3 Opus.
model_id = Llm.CLAUDE_3_SONNET_BEDROCK.value
max_tokens = 4096
content_type = 'application/json'
accept = '*/*'
temperature = 0.0

# Multi-pass flow
current_pass_num = 1
max_passes = 2

prefix = "<thinking>"

while current_pass_num <= max_passes:
current_pass_num += 1

# Set up message depending on whether we have a <thinking> prefix
messages_to_send = (
messages + [{"role": "assistant", "content": prefix}]
if include_thinking
else messages
)

response_text = await stream_bedrock_response(
bedrock_runtime,
messages_to_send,
system_prompt,
model_id,
max_tokens,
content_type,
accept,
temperature,
callback,
)

# Set up messages array for next pass
messages += [
{"role": "assistant", "content": str(prefix) + response_text},
{
"role": "user",
"content": "You've done a good job with a first draft. Improve this further based on the original instructions so that the app is fully functional and looks like the original video of the app we're trying to replicate.",
},
]

return response_text
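The loop above implements a draft-then-refine exchange: pass one streams a first draft behind a "<thinking>" prefill, and pass two re-sends the conversation with the draft folded in plus a fixed improvement request. A condensed sketch of how the messages list evolves (contents abbreviated):

# Pass 1 request (include_thinking=True): the assistant turn is pre-filled so
# the model starts by reasoning inside <thinking> before emitting code.
messages_pass_1 = [
    {"role": "user", "content": "Replicate the app shown in this video."},  # abbreviated
    {"role": "assistant", "content": "<thinking>"},
]

draft = "...first-draft HTML..."  # abbreviated streamed output of pass 1

# Pass 2 request: the full draft becomes an assistant turn, followed by the
# fixed refinement instruction, and the prefill is appended again.
messages_pass_2 = [
    {"role": "user", "content": "Replicate the app shown in this video."},  # abbreviated
    {"role": "assistant", "content": "<thinking>" + draft},
    {"role": "user", "content": "You've done a good job with a first draft. Improve this further..."},
    {"role": "assistant", "content": "<thinking>"},
]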
81 changes: 50 additions & 31 deletions backend/routes/generate_code.py
@@ -2,14 +2,14 @@
import traceback
from fastapi import APIRouter, WebSocket
import openai
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE
from config import ANTHROPIC_API_KEY, IS_PROD, SHOULD_MOCK_AI_RESPONSE, AWS_ACCESS_KEY, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME
from custom_types import InputMode
from llm import (
Llm,
convert_frontend_str_to_llm,
stream_claude_response,
stream_claude_response_native,
stream_openai_response,
stream_openai_response,
stream_claude_response_aws_bedrock,
stream_claude_response_native_aws_bedrock,
)
from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion
@@ -25,7 +25,6 @@
from video.utils import extract_tag_content, assemble_claude_prompt_video
from ws.constants import APP_ERROR_WEB_SOCKET_CODE # type: ignore


router = APIRouter()


@@ -55,7 +54,7 @@ async def stream_code(websocket: WebSocket):
print("Incoming websocket connection...")

async def throw_error(
message: str,
):
await websocket.send_json({"type": "error", "value": message})
await websocket.close(APP_ERROR_WEB_SOCKET_CODE)
@@ -230,33 +229,53 @@ async def process_chunk(content: str):
else:
try:
if validated_input_mode == "video":
if not anthropic_api_key:
await throw_error(
"Video only works with Anthropic models. No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
)
raise Exception("No Anthropic key")

completion = await stream_claude_response_native(
system_prompt=VIDEO_PROMPT,
messages=prompt_messages, # type: ignore
api_key=anthropic_api_key,
callback=lambda x: process_chunk(x),
model=Llm.CLAUDE_3_OPUS,
include_thinking=True,
)
exact_llm_version = Llm.CLAUDE_3_OPUS
elif code_generation_model == Llm.CLAUDE_3_SONNET:
if not anthropic_api_key:
await throw_error(
"No Anthropic API key found. Please add the environment variable ANTHROPIC_API_KEY to backend/.env or in the settings dialog"
)
raise Exception("No Anthropic key")

completion = await stream_claude_response(
prompt_messages, # type: ignore
api_key=anthropic_api_key,
callback=lambda x: process_chunk(x),
)
if not anthropic_api_key and not (AWS_ACCESS_KEY and AWS_SECRET_ACCESS_KEY):
await throw_error(
"Video only works with Anthropic models. No Anthropic API key or complete AWS credentials found. Please add the environment variable ANTHROPIC_API_KEY or AWS_ACCESS_KEY/AWS_SECRET_ACCESS_KEY to backend/.env or in the settings dialog"
)
raise Exception("No Anthropic API key or AWS credentials")

if anthropic_api_key:
completion = await stream_claude_response_native(
system_prompt=VIDEO_PROMPT,
messages=prompt_messages, # type: ignore
api_key=anthropic_api_key,
callback=lambda x: process_chunk(x),
model=Llm.CLAUDE_3_OPUS,
include_thinking=True,
)
else:
completion = await stream_claude_response_native_aws_bedrock(
system_prompt=VIDEO_PROMPT,
messages=prompt_messages, # type: ignore
access_key=AWS_ACCESS_KEY,
secret_access_key=AWS_SECRET_ACCESS_KEY,
aws_region_name=AWS_REGION_NAME,
callback=lambda x: process_chunk(x),
model=Llm.CLAUDE_3_OPUS,
include_thinking=True,
)
exact_llm_version = Llm.CLAUDE_3_OPUS
elif code_generation_model == Llm.CLAUDE_3_SONNET:
if not anthropic_api_key and not (AWS_ACCESS_KEY and AWS_SECRET_ACCESS_KEY):
await throw_error(
"No Anthropic API key or complete AWS credentials found. Please add the environment variable ANTHROPIC_API_KEY or AWS_ACCESS_KEY/AWS_SECRET_ACCESS_KEY to backend/.env or in the settings dialog"
)
raise Exception("No Anthropic API key or AWS credentials")
if anthropic_api_key:
completion = await stream_claude_response(
prompt_messages, # type: ignore
api_key=anthropic_api_key,
callback=lambda x: process_chunk(x),
)
else:
completion = await stream_claude_response_aws_bedrock(
prompt_messages, # type: ignore
access_key=AWS_ACCESS_KEY,
secret_access_key=AWS_SECRET_ACCESS_KEY,
aws_region_name=AWS_REGION_NAME,
callback=lambda x: process_chunk(x),
)
exact_llm_version = code_generation_model
else:
completion = await stream_openai_response(
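Both Claude branches above follow the same preference order: use the Anthropic API when its key is set, otherwise fall back to Bedrock, which needs both credential parts. A condensed sketch of that rule (a hypothetical helper, not part of this PR):

from typing import Optional

def pick_claude_backend(
    anthropic_api_key: Optional[str],
    aws_access_key: Optional[str],
    aws_secret_access_key: Optional[str],
) -> str:
    # Prefer the Anthropic API when its key is set; otherwise fall back to
    # AWS Bedrock, which requires both halves of the credential pair.
    if anthropic_api_key:
        return "anthropic"
    if aws_access_key and aws_secret_access_key:
        return "bedrock"
    raise Exception("No Anthropic API key or AWS credentials")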