Skip to content

Commit

Permalink
FEAT: Add Centralized DB Support Using Azure (Azure#379)
Browse files Browse the repository at this point in the history
Co-authored-by: rdheekonda <[email protected]>
  • Loading branch information
rdheekonda and rdheekonda authored Sep 23, 2024
1 parent 739d896 commit a069c88
Show file tree
Hide file tree
Showing 66 changed files with 2,350 additions and 465 deletions.
4 changes: 3 additions & 1 deletion .env_example
Original file line number Diff line number Diff line change
Expand Up @@ -129,5 +129,7 @@ AZURE_SQL_SERVER_CONNECTION_STRING="<Provide DB Azure SQL Server connection stri
# Crucible API Key. You can get yours at: https://crucible.dreadnode.io/login
CRUCIBLE_API_KEY = "<Provide Crucible API key here>"

# Azure SQL Server Connection String
# Azure SQL Server Connection String and Azure Storage Account for storing blob objects
AZURE_SQL_DB_CONNECTION_STRING = "<Provide Azure SQL DB connection string here in SQLAlchemy format>"
AZURE_STORAGE_ACCOUNT_RESULTS_CONTAINER_URL="<Azure Storage Account results container url>"
AZURE_STORAGE_ACCOUNT_RESULTS_SAS_TOKEN="<Azure Storage Account Results Container SAS URL>"
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ repos:
exclude: NOTICE.txt
- id: check-yaml
- id: check-added-large-files
args: ["--maxkb=3072"] # Set limit to 3072 KB (3 MB) for displaying images in notebooks
- id: detect-private-key

# https://black.readthedocs.io/en/stable/integrations/source_version_control.html
Expand Down
120 changes: 120 additions & 0 deletions doc/code/memory/1_duck_db_memory.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "ec85323a",
"metadata": {},
"source": [
"# Memory\n",
"\n",
"The memory DuckDB database can be thought of as a normalized source of truth. The memory module is the primary way pyrit keeps track of requests and responses to targets and scores. Most of this is done automatically. All Prompt Targets write to memory for later retrieval. All scorers also write to memory when scoring.\n",
"\n",
"The schema is found in `memory_models.py` and can be programatically viewed as follows"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "cb83267e",
"metadata": {
"execution": {
"iopub.execute_input": "2024-09-22T20:26:45.105804Z",
"iopub.status.busy": "2024-09-22T20:26:45.105804Z",
"iopub.status.idle": "2024-09-22T20:26:56.398392Z",
"shell.execute_reply": "2024-09-22T20:26:56.392296Z"
},
"lines_to_next_cell": 2
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\rdheekonda\\AppData\\Local\\anaconda3\\envs\\pyrit-dev\\lib\\site-packages\\duckdb_engine\\__init__.py:565: SAWarning: Did not recognize type 'list' of column 'embedding'\n",
" columns = self._get_columns_info(rows, domains, enums, schema) # type: ignore[attr-defined]\n",
"C:\\Users\\rdheekonda\\AppData\\Local\\anaconda3\\envs\\pyrit-dev\\lib\\site-packages\\duckdb_engine\\__init__.py:180: DuckDBEngineWarning: duckdb-engine doesn't yet support reflection on indices\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Schema for EmbeddingData:\n",
" Column id (UUID)\n",
" Column embedding (NULL)\n",
" Column embedding_type_name (VARCHAR)\n",
"Schema for PromptMemoryEntries:\n",
" Column id (UUID)\n",
" Column role (VARCHAR)\n",
" Column conversation_id (VARCHAR)\n",
" Column sequence (INTEGER)\n",
" Column timestamp (TIMESTAMP)\n",
" Column labels (VARCHAR)\n",
" Column prompt_metadata (VARCHAR)\n",
" Column converter_identifiers (VARCHAR)\n",
" Column prompt_target_identifier (VARCHAR)\n",
" Column orchestrator_identifier (VARCHAR)\n",
" Column response_error (VARCHAR)\n",
" Column original_value_data_type (VARCHAR)\n",
" Column original_value (VARCHAR)\n",
" Column original_value_sha256 (VARCHAR)\n",
" Column converted_value_data_type (VARCHAR)\n",
" Column converted_value (VARCHAR)\n",
" Column converted_value_sha256 (VARCHAR)\n",
"Schema for ScoreEntries:\n",
" Column id (UUID)\n",
" Column score_value (VARCHAR)\n",
" Column score_value_description (VARCHAR)\n",
" Column score_type (VARCHAR)\n",
" Column score_category (VARCHAR)\n",
" Column score_rationale (VARCHAR)\n",
" Column score_metadata (VARCHAR)\n",
" Column scorer_class_identifier (VARCHAR)\n",
" Column prompt_request_response_id (UUID)\n",
" Column timestamp (TIMESTAMP)\n",
" Column task (VARCHAR)\n"
]
}
],
"source": [
"from pyrit.memory import DuckDBMemory\n",
"\n",
"memory = DuckDBMemory()\n",
"memory.print_schema()\n",
"\n",
"memory.dispose_engine()"
]
},
{
"cell_type": "markdown",
"id": "db45c4c0",
"metadata": {},
"source": []
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "pyrit-311",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
File renamed without changes.
231 changes: 231 additions & 0 deletions doc/code/memory/6_azure_sql_memory.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "7b7fb6ee",
"metadata": {},
"source": [
"# Azure SQL Memory\n",
"\n",
"The memory AzureSQL database can be thought of as a normalized source of truth. The memory module is the primary way pyrit keeps track of requests and responses to targets and scores. Most of this is done automatically. All orchestrators write to memory for later retrieval. All scorers also write to memory when scoring.\n",
"\n",
"The schema is found in `memory_models.py` and can be programatically viewed as follows\n",
"\n",
"### Azure Login\n",
"\n",
"PyRIT `AzureSQLMemory` supports only **Azure Entra ID authentication** at this time. User ID/password-based login is not available.\n",
"\n",
"Please log in to your Azure account before running this notebook:\n",
"\n",
"- Use the default login:\n",
" ```bash\n",
" az login\n",
" ```\n",
"- Or, use device code login\n",
" ```bash\n",
" az login --use-device-code\n",
" ```\n",
"\n",
"### Environment Variables\n",
"\n",
"Please set the following environment variables to run AzureSQLMemory interactions:\n",
"\n",
"- `AZURE_SQL_DB_CONNECTION_STRING` = \"<Azure SQL DB connection string here in SQLAlchemy format>\"\n",
"- `AZURE_STORAGE_ACCOUNT_RESULTS_CONTAINER_URL` = \"<Azure Storage Account results container URL>\" (which uses delegation SAS) but needs login to Azure.\n",
"\n",
"To use regular key-based authentication, please also set:\n",
"\n",
"- `AZURE_STORAGE_ACCOUNT_RESULTS_SAS_TOKEN`\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b63619c4",
"metadata": {
"execution": {
"iopub.execute_input": "2024-09-22T20:29:01.611052Z",
"iopub.status.busy": "2024-09-22T20:29:01.611052Z",
"iopub.status.idle": "2024-09-22T20:29:21.644356Z",
"shell.execute_reply": "2024-09-22T20:29:21.641848Z"
},
"lines_to_next_cell": 2
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Schema for EmbeddingData:\n",
" Column id (UNIQUEIDENTIFIER)\n",
" Column embedding (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column embedding_type_name (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
"Schema for PromptMemoryEntries:\n",
" Column id (UNIQUEIDENTIFIER)\n",
" Column role (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column conversation_id (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column sequence (INTEGER)\n",
" Column timestamp (DATETIME)\n",
" Column labels (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column prompt_metadata (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column converter_identifiers (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column prompt_target_identifier (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column orchestrator_identifier (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column response_error (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column original_value_data_type (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column original_value (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column original_value_sha256 (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column converted_value_data_type (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column converted_value (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column converted_value_sha256 (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
"Schema for Prompts:\n",
" Column id (UNIQUEIDENTIFIER)\n",
" Column value (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column data_type (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column name (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column dataset_name (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column harm_categories (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column description (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column authors (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column groups (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column source (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column date_added (DATETIME)\n",
" Column added_by (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column prompt_metadata (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column parameters (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column prompt_group_id (UNIQUEIDENTIFIER)\n",
" Column sequence (INTEGER)\n",
"Schema for ScoreEntries:\n",
" Column id (UNIQUEIDENTIFIER)\n",
" Column score_value (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column score_value_description (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column score_type (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column score_category (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column score_rationale (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column score_metadata (VARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column scorer_class_identifier (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n",
" Column prompt_request_response_id (UNIQUEIDENTIFIER)\n",
" Column timestamp (DATETIME)\n",
" Column task (NVARCHAR COLLATE \"SQL_Latin1_General_CP1_CI_AS\")\n"
]
}
],
"source": [
"from pyrit.memory import AzureSQLMemory\n",
"\n",
"\n",
"memory = AzureSQLMemory()\n",
"\n",
"memory.print_schema()"
]
},
{
"cell_type": "markdown",
"id": "68e7c0fe",
"metadata": {},
"source": [
"## Basic Azure SQL Memory Programming Usage\n",
"\n",
"The `pyrit.memory.azure_sql_memory` module provides functionality to keep track of the conversation history, scoring, data, and more using Azure SQL. You can use memory to read and write data. Here is an example that retrieves a normalized conversation:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "de1a2abb",
"metadata": {
"execution": {
"iopub.execute_input": "2024-09-22T20:29:21.650035Z",
"iopub.status.busy": "2024-09-22T20:29:21.650035Z",
"iopub.status.idle": "2024-09-22T20:29:23.983721Z",
"shell.execute_reply": "2024-09-22T20:29:23.983721Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
": user: Hi, chat bot! This is my initial prompt.\n",
": assistant: Nice to meet you! This is my response.\n",
": user: Wonderful! This is my second prompt to the chat bot!\n"
]
}
],
"source": [
"from uuid import uuid4\n",
"from pyrit.memory import AzureSQLMemory\n",
"from pyrit.models import PromptRequestPiece, PromptRequestResponse\n",
"\n",
"\n",
"conversation_id = str(uuid4())\n",
"\n",
"message_list = [\n",
" PromptRequestPiece(\n",
" role=\"user\", original_value=\"Hi, chat bot! This is my initial prompt.\", conversation_id=conversation_id\n",
" ),\n",
" PromptRequestPiece(\n",
" role=\"assistant\", original_value=\"Nice to meet you! This is my response.\", conversation_id=conversation_id\n",
" ),\n",
" PromptRequestPiece(\n",
" role=\"user\",\n",
" original_value=\"Wonderful! This is my second prompt to the chat bot!\",\n",
" conversation_id=conversation_id,\n",
" ),\n",
"]\n",
"\n",
"memory = AzureSQLMemory()\n",
"\n",
"memory.add_request_response_to_memory(request=PromptRequestResponse([message_list[0]]))\n",
"memory.add_request_response_to_memory(request=PromptRequestResponse([message_list[1]]))\n",
"memory.add_request_response_to_memory(request=PromptRequestResponse([message_list[2]]))\n",
"\n",
"\n",
"entries = memory.get_conversation(conversation_id=conversation_id)\n",
"\n",
"for entry in entries:\n",
" print(entry)\n",
"\n",
"\n",
"# Cleanup memory resources\n",
"memory.dispose_engine()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "72f20a96",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"\n"
]
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit a069c88

Please sign in to comment.