Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
.DS_Store
lab2/assets/layers/python/
254 changes: 154 additions & 100 deletions lab2/assets/lambda/db_connections/prompt_lambda/prompt_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,66 @@
#########################
# LIBRARIES & LOGGER
#########################
import ast
import json
import logging
import os
import sys
from datetime import datetime, timezone

import boto3
import boto3.dynamodb
import boto3.dynamodb.table
import boto3.dynamodb.types

LOGGER = logging.Logger("DDB LAMBDA", level=logging.DEBUG)
HANDLER = logging.StreamHandler(sys.stdout)
HANDLER.setFormatter(logging.Formatter("%(levelname)s | %(name)s | %(message)s"))
HANDLER.setFormatter(logging.Formatter(
"%(levelname)s | %(name)s | %(message)s"))
LOGGER.addHandler(HANDLER)

MODELS_MAPPING = {
"Bedrock: Amazon Titan": "amazon.titan-text-express-v1",
"Bedrock: Claude V2": "anthropic.claude-v2",
"Bedrock: Claude 3 Sonnet": "anthropic.claude-3-sonnet-20240229-v1:0"
}


def create_bedrock_agent_client():
"""
Creates a Bedrock Agent client using the specified region and configuration.

Returns:
A tuple containing the Bedrock client and the expiration time (which is None).
"""
LOGGER.info("Using bedrock agent client from same account.")
bedrock_client = boto3.client(
service_name="bedrock-agent",
region_name=os.environ["BEDROCK_REGION"]
)
expiration = None
LOGGER.info("Successfully set bedrock agent client")

return bedrock_client, expiration


BEDROCK_AGENT_CLIENT, EXPIRATION = create_bedrock_agent_client()


def verify_bedrock_agent_client():
"""
Verifies the Bedrock client by checking if the token has expired or not.

Returns:
bool: True if the Bedrock client is verified, False otherwise.
"""
if EXPIRATION is not None:
now = datetime.now(timezone.utc)
LOGGER.info(
f"Bedrock token expires in {(EXPIRATION - now).total_seconds()}s")
if (EXPIRATION - now).total_seconds() < 60:
return False
return True


def remove_dynamodb_type_descriptors(item):
return {k: list(v.values())[0] for k, v in item.items()}
Expand All @@ -33,8 +80,9 @@ def lambda_handler(event, context):
Lambda handler
"""
LOGGER.info("Starting execution of lambda_handler()")
LOGGER.info(f"Boto version: {boto3.__version__}")

### PREPARATIONS
# PREPARATIONS
# Convert the 'body' string to a dictionary
LOGGER.info(f"The incoming payload event:{event}")

Expand All @@ -52,20 +100,53 @@ def lambda_handler(event, context):
table_name = os.environ["TABLE_NAME"]
LOGGER.info("boto3 dynamo established!")

if not verify_bedrock_agent_client():
LOGGER.info("Bedrock agent client expired, will refresh token.")
global BEDROCK_AGENT_CLIENT, EXPIRATION
BEDROCK_AGENT_CLIENT, EXPIRATION = create_bedrock_agent_client()

if payload_type == "PUT":
# Extract the 'item' dictionary from the body
item = body.get("item", {})
item = body.get("item")
# get fixed model params
MODEL_ID = MODELS_MAPPING[item.get("model")]
LOGGER.info(f"MODEL_ID: {MODEL_ID}")
# create the prompt in bedrock
created_prompt = BEDROCK_AGENT_CLIENT.create_prompt(
name=f"{item.get('session_id')}",
variants=[
{
"name": "variant-001",
"modelId": MODEL_ID,
"templateType": "TEXT",
"inferenceConfiguration": {
"text": {
"temperature": float(item.get("temperature")),
"maxTokens": int(item.get("answer_length")),
}
},
"templateConfiguration": {
"text": {
"text": item.get("Prompt Template")
}
}
}
],
defaultVariant="variant-001"
)
prompt_id = created_prompt.get("id")
# Create a version
versioned_prompt = BEDROCK_AGENT_CLIENT.create_prompt_version(
promptIdentifier=prompt_id
)

# Construct the item data as a dictionary
item_data = {
"session_id": {"S": item.get("session_id", "")},
"user_id": {"S": item.get("user_id", "")},
"timestamp": {"S": item.get("timestamp", "")},
"model": {"S": item.get("model", "")},
"answer_length": {"N": item.get("answer_length", "")},
"temperature": {"N": item.get("temperature", "")},
"Prompt Template": {"S": item.get("Prompt Template", "")},
"Prompt": {"S": item.get("Prompt", "")},
"Output": {"S": item.get("Output", "")},
"user_id": {"S": item.get("user_id")},
"prompt_id": {"S": prompt_id},
"prompt_version": {"S": versioned_prompt.get("version")},
"prompt": {"S": item.get("Prompt")},
"output": {"S": item.get("Output")}
}
LOGGER.info(f"Item data: {item_data}")

Expand All @@ -82,113 +163,86 @@ def lambda_handler(event, context):
if payload_type == "GET":
# Extract filter parameters
filter_params = body.get("filter_params", {})
print(filter_params)
user_ids_filter_str = filter_params.get("user_ids_filter", "")
if user_ids_filter_str != "":
user_ids_filter = ast.literal_eval(user_ids_filter_str)
else:
user_ids_filter = []

# retrieving filter values
industry_filter = filter_params.get("industry_filter")
language_filter = filter_params.get("language_filter")
task_filter = filter_params.get("task_filter")
technique_filter = filter_params.get("technique_filter")
ai_model_filter = filter_params.get("ai_model_filter")

user_id = filter_params.get("user_id")
user_ids_filter.append(user_id)
# Create FilterExpression and ExpressionAttributeValues
filter_expressions = []
expression_attribute_values = {}
# Use an expression attribute names dictionary to handle reserved words
expression_attribute_names = {}

# Incorporating user_ids_filter for possiblle filter expansion in UI
if user_ids_filter:
user_id_placeholders = [f":user_id_{i}" for i, _ in enumerate(user_ids_filter)]
filter_expressions.append(f"user_id IN ({', '.join(user_id_placeholders)})")
for placeholder, user_id in zip(user_id_placeholders, user_ids_filter):
expression_attribute_values[placeholder] = {"S": user_id}

if industry_filter and industry_filter != "ALL":
filter_expressions.append("Industry = :industry")
expression_attribute_values[":industry"] = {"S": industry_filter}

# Incorporating language_filter
if language_filter and language_filter != "ALL":
filter_expressions.append("#Language = :language")
expression_attribute_values[":language"] = {"S": language_filter}
expression_attribute_names["#Language"] = "Language"

# Incorporating task_filter
if task_filter and task_filter != "ALL":
filter_expressions.append("Task = :task")
expression_attribute_values[":task"] = {"S": task_filter}

# Incorporating technique_filter
if technique_filter and technique_filter != "ALL":
filter_expressions.append("Technique = :technique")
expression_attribute_values[":technique"] = {"S": technique_filter}

# Incorporating ai_model_filter
if ai_model_filter and ai_model_filter != "ALL":
filter_expressions.append("model = :model")
expression_attribute_values[":model"] = {"S": ai_model_filter}

filter_expression_string = " AND ".join(filter_expressions)
print(f"FilterExpressions = {filter_expression_string}")
print(f"Expression Attributes: {expression_attribute_values}")
ai_model_filter = filter_params.get("ai_model_filter")
MODEL_ID = MODELS_MAPPING[ai_model_filter]
LOGGER.info(f"user_id: {user_id}, ai_model_filter: {MODEL_ID}")

try:
if expression_attribute_names != {}:
response = dynamodb.scan(
TableName=table_name,
Limit=1000,
FilterExpression=filter_expression_string if filter_expressions else None,
ExpressionAttributeValues=expression_attribute_values if expression_attribute_values else None,
ExpressionAttributeNames=expression_attribute_names if expression_attribute_names else None,
)
else:
# Use scan with FilterExpression to get items. Limit to 1000 items.
response = dynamodb.scan(
TableName=table_name,
Limit=1000,
FilterExpression=filter_expression_string if filter_expressions else None,
ExpressionAttributeValues=expression_attribute_values if expression_attribute_values else None,
)

items = response.get("Items", {})
# query all items from dynamodb using the user_id partition
response = dynamodb.query(
TableName=table_name,
KeyConditionExpression="user_id = :user_id",
ExpressionAttributeValues={
":user_id": {"S": user_id}
}
)

items = response.get("Items")
LOGGER.info(f"Retrieved {len(items)} items")
clean_item_list = []
for item in items:
clean_item = remove_dynamodb_type_descriptors(item)
clean_item_list.append(clean_item)
LOGGER.info(f"Items {items}")

# iterate through prompt_ids, user_ids, prompt_versions to get the prompts and put them in a json object
prompts = []
for item in items:
# get the prompt from bedrock agent using the prompt_id and version
prompt_id = item.get("prompt_id").get("S")
prompt_version = item.get("prompt_version").get("S")
bedrock_prompt_catalog = BEDROCK_AGENT_CLIENT.get_prompt(
promptIdentifier=prompt_id, promptVersion=prompt_version)
LOGGER.info(
f"bedrock_prompt_catalog: {json.dumps(bedrock_prompt_catalog, indent=4, sort_keys=True, default=str)}")
defaultVariants = [variant for variant in bedrock_prompt_catalog['variants'] if variant["name"]
== bedrock_prompt_catalog["defaultVariant"] and variant["modelId"] == MODEL_ID]
# when array is 0 continue to next else extract default variant
if len(defaultVariants) == 0:
continue
defaultVariant = defaultVariants[0]
prompt_template_string = defaultVariant['templateConfiguration']['text']['text']
prompts.append({
"Prompt Template": prompt_template_string,
"Prompt": item.get("prompt").get("S"),
"Output": item.get("output").get("S"),
"model": ai_model_filter,
"answer_length": defaultVariant['inferenceConfiguration']['text']['maxTokens'],
"temperature": defaultVariant['inferenceConfiguration']['text']['temperature'],
})

LOGGER.info(
f"bedrock_prompt: {json.dumps(prompts)}")
# Return items in response
return {"statusCode": 200, "body": json.dumps(clean_item_list)}
return {"statusCode": 200, "body": json.dumps(prompts)}
except Exception as e:
LOGGER.error(f"Error retrieving items: {e}")
return {"statusCode": 500, "body": json.dumps(f"Error retrieving items from DynamoDB: {str(e)}")}

elif payload_type == "DELETE":
# Extract the 'session_id' from the body
print("made it into the delete section! ")
session_id = body["item"]["session_id"]
LOGGER.debug(f"session_id {session_id}, {type(session_id)}")
user_id = body["item"]["user_id"]
prompt_id = body["item"]["prompt_id"]
prompt_version = body["item"]["prompt_version"]

if not session_id:
LOGGER.error("session_id is required for DELETE operation")
return {"statusCode": 400, "body": json.dumps("session_id is required for DELETE operation")}
if not user_id:
LOGGER.error("user_id is required for DELETE operation")
return {"statusCode": 400, "body": json.dumps("user_id is required for DELETE operation")}

try:
response = dynamodb.delete_item(
bedrock_prompt_catalog = BEDROCK_AGENT_CLIENT.delete_prompt_version(
promptId=prompt_id,
version=prompt_version
)
dynamodb_response = dynamodb.delete_item(
TableName=table_name,
Key={"session_id": {"S": session_id}}, # "S" indicates that the datatype is a string.
# "S" indicates that the datatype is a string.
Key={"user_id": {"S": user_id}, "prompt_id": {"S": prompt_id}},
)

LOGGER.info(f"Deleted item with session_id: {session_id}")
LOGGER.info(
f"Deleted item with user_id: {user_id} and prompt_id {prompt_id}")

return {"statusCode": 200, "body": json.dumps(f"Item with session_id {session_id} successfully deleted")}
return {"statusCode": 200, "body": json.dumps(f"Item with user_id {user_id} and prompt_id {prompt_id} successfully deleted")}
except Exception as e:
LOGGER.error(f"Error deleting item: {e}")
return {"statusCode": 500, "body": json.dumps(f"Error deleting item from DynamoDB: {str(e)}")}
Expand Down
25 changes: 12 additions & 13 deletions lab2/infra/bdrk_reinvent_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@

from typing import Any, Dict

from aws_cdk import Aws
from aws_cdk import CfnOutput as output
from aws_cdk import RemovalPolicy, Stack, Tags

from aws_cdk import Stack, Tags
from constructs import Construct

from infra.constructs.bdrk_reinvent_api import bdrk_reinventAPIConstructs
from infra.constructs.bdrk_reinvent_layers import bdrk_reinventLambdaLayers
from infra.stacks.bdrk_reinvent_streamlit import bdrk_reinventStreamlitStack
Expand All @@ -25,27 +22,28 @@ class bdrkReinventStack(Stack):
def __init__(self, scope: Construct, stack_name: str, config: Dict[str, Any], **kwargs) -> None: # noqa: C901
super().__init__(scope, stack_name, **kwargs)

## **************** Lambda layers ****************
# **************** Lambda layers ****************

self.layers = bdrk_reinventLambdaLayers(self, f"{stack_name}-layers", stack_name=stack_name)
self.layers = bdrk_reinventLambdaLayers(
self, f"{stack_name}-layers", stack_name=stack_name)

## ********** Bedrock configs ***********
# ********** Bedrock configs ***********
bedrock_region = kwargs["env"].region
bedrock_role_arn = None
bedrock_role_arn = ""

if "bedrock" in config:
if "region" in config["bedrock"]:
bedrock_region = (
kwargs["env"].region if config["bedrock"]["region"] == "None" else config["bedrock"]["region"]
)

## ********** Authentication configs ***********
# ********** Authentication configs ***********
mfa_enabled = True
if "authentication" in config:
if "MFA" in config["authentication"]:
mfa_enabled = config["authentication"]["MFA"]

## **************** API Constructs ****************
# **************** API Constructs ****************
self.api_constructs = bdrk_reinventAPIConstructs(
self,
f"{stack_name}-API",
Expand All @@ -56,7 +54,7 @@ def __init__(self, scope: Construct, stack_name: str, config: Dict[str, Any], **
mfa_enabled=mfa_enabled,
)

## **************** Streamlit NestedStack ****************
# **************** Streamlit NestedStack ****************
if config["streamlit"]["deploy_streamlit"]:
self.streamlit_constructs = bdrk_reinventStreamlitStack(
self,
Expand All @@ -70,7 +68,8 @@ def __init__(self, scope: Construct, stack_name: str, config: Dict[str, Any], **
cover_image_login_url=config["streamlit"]["cover_image_login_url"],
assistant_avatar=config["streamlit"]["assistant_avatar"],
open_to_public_internet=config["streamlit"]["open_to_public_internet"],
ip_address_allowed=config["streamlit"].get("ip_address_allowed"),
ip_address_allowed=config["streamlit"].get(
"ip_address_allowed"),
custom_header_name=config["cloudfront"]["custom_header_name"],
custom_header_value=config["cloudfront"]["custom_header_value"],
)
Expand All @@ -87,6 +86,6 @@ def __init__(self, scope: Construct, stack_name: str, config: Dict[str, Any], **
value=self.streamlit_constructs.cloudfront.domain_name,
)

## **************** Tags ****************
# **************** Tags ****************
Tags.of(self).add("StackName", stack_name)
Tags.of(self).add("Team", "Bedrock Workshop")
Loading