Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Updating memory and fixing bugs #394

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/code/targets/6_multi_modal_targets.py
Original file line number Diff line number Diff line change
@@ -92,7 +92,7 @@

# You can use the following to show the image
if image_location != "content blocked":
image_bytes = await azure_sql_memory._storage_io.read_file(image_location) # type: ignore
image_bytes = await azure_sql_memory.storage_io.read_file(image_location) # type: ignore

image_stream = io.BytesIO(image_bytes)
image = Image.open(image_stream)
2 changes: 1 addition & 1 deletion pyrit/common/display_response.py
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@ async def display_response(response_piece: PromptRequestPiece, memory: MemoryInt
and is_in_ipython_session()
):
image_location = response_piece.converted_value
image_bytes = await memory._storage_io.read_file(image_location)
image_bytes = await memory.storage_io.read_file(image_location)

image_stream = io.BytesIO(image_bytes)
image = Image.open(image_stream)
4 changes: 2 additions & 2 deletions pyrit/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingData
from pyrit.memory.memory_models import PromptMemoryEntry, EmbeddingDataEntry
from pyrit.memory.memory_interface import MemoryInterface

from pyrit.memory.azure_sql_memory import AzureSQLMemory
@@ -13,7 +13,7 @@
__all__ = [
"AzureSQLMemory",
"DuckDBMemory",
"EmbeddingData",
"EmbeddingDataEntry",
"MemoryInterface",
"MemoryEmbedding",
"MemoryExporter",
44 changes: 26 additions & 18 deletions pyrit/memory/azure_sql_memory.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@

from pyrit.common import default_values
from pyrit.common.singleton import Singleton
from pyrit.memory.memory_models import EmbeddingData, Base, PromptMemoryEntry, ScoreEntry
from pyrit.memory.memory_models import EmbeddingDataEntry, Base, PromptMemoryEntry, ScoreEntry
from pyrit.memory.memory_interface import MemoryInterface
from pyrit.models.prompt_request_piece import PromptRequestPiece
from pyrit.models.score import Score
@@ -49,8 +49,6 @@ def __init__(
sas_token: Optional[str] = None,
verbose: bool = False,
):
super(AzureSQLMemory, self).__init__()

self._connection_string = default_values.get_required_value(
env_var_name=self.AZURE_SQL_DB_CONNECTION_STRING, passed_value=connection_string
)
@@ -63,8 +61,6 @@ def __init__(
)
except ValueError:
self._sas_token = None # To use delegation SAS
# Handle for Azure Blob Storage when using Azure SQL memory.
self._storage_io = AzureBlobStorageIO(container_url=self._container_url, sas_token=self._sas_token)

self.results_path = self._container_url

@@ -76,6 +72,12 @@ def __init__(
self.SessionFactory = sessionmaker(bind=self.engine)
self._create_tables_if_not_exist()

super(AzureSQLMemory, self).__init__()

def _init_storage_io(self):
# Handle for Azure Blob Storage when using Azure SQL memory.
self.storage_io = AzureBlobStorageIO(container_url=self._container_url, sas_token=self._sas_token)

def _create_auth_token(self) -> AccessToken:
azure_credentials = DefaultAzureCredential()
return azure_credentials.get_token(self.TOKEN_URL)
@@ -137,13 +139,13 @@ def _create_tables_if_not_exist(self):
except Exception as e:
logger.error(f"Error during table creation: {e}")

def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None:
def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingDataEntry]) -> None:
"""
Inserts embedding data into memory storage
"""
self._insert_entries(entries=embedding_data)

def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[Base]:
def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[PromptRequestPiece]:
"""
Retrieves a list of PromptMemoryEntry Base objects that have the specified orchestrator ID.

@@ -152,13 +154,15 @@ def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[Ba
Can be retrieved by calling orchestrator.get_identifier()["id"]

Returns:
list[Base]: A list of PromptMemoryEntry Base objects matching the specified orchestrator ID.
list[PromptRequestPiece]: A list of PromptMemoryEntry Base objects matching the specified orchestrator ID.
"""
try:
sql_condition = text(
"ISJSON(orchestrator_identifier) = 1 AND JSON_VALUE(orchestrator_identifier, '$.id') = :json_id"
).bindparams(json_id=str(orchestrator_id))
result = self.query_entries(PromptMemoryEntry, conditions=sql_condition) # type: ignore
entries = self.query_entries(PromptMemoryEntry, conditions=sql_condition) # type: ignore
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]

return result
except Exception as e:
logger.exception(
@@ -177,10 +181,14 @@ def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> li
list[PromptRequestPiece]: A list of PromptRequestPieces with the specified conversation ID.
"""
try:
return self.query_entries(
entries = self.query_entries(
PromptMemoryEntry,
conditions=PromptMemoryEntry.conversation_id == conversation_id,
) # type: ignore

result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result

except Exception as e:
logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}")
return []
@@ -206,11 +214,11 @@ def dispose_engine(self):
self.engine.dispose()
logger.info("Engine disposed successfully.")

def get_all_embeddings(self) -> list[EmbeddingData]:
def get_all_embeddings(self) -> list[EmbeddingDataEntry]:
"""
Fetches all entries from the specified table and returns them as model instances.
"""
result = self.query_entries(EmbeddingData)
result = self.query_entries(EmbeddingDataEntry)
return result

def get_all_prompt_pieces(self) -> list[PromptRequestPiece]:
@@ -231,10 +239,12 @@ def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[Prom
list[PromptRequestPiece]: A list of PromptRequestPiece with the specified conversation ID.
"""
try:
return self.query_entries(
entries = self.query_entries(
PromptMemoryEntry,
conditions=PromptMemoryEntry.id.in_(prompt_ids),
) # type: ignore
)
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(
f"Unexpected error: Failed to retrieve ConversationData with orchestrator {prompt_ids}. {e}"
@@ -266,10 +276,8 @@ def get_prompt_request_piece_by_memory_labels(
# for safe parameter passing, preventing SQL injection
sql_condition = text(conditions).bindparams(**{key: str(value) for key, value in memory_labels.items()})

result: list[PromptRequestPiece] = self.query_entries(
PromptMemoryEntry, conditions=sql_condition
) # type: ignore

entries = self.query_entries(PromptMemoryEntry, conditions=sql_condition)
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(
34 changes: 20 additions & 14 deletions pyrit/memory/duckdb_memory.py
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@
from sqlalchemy.engine.base import Engine
from contextlib import closing

from pyrit.memory.memory_models import EmbeddingData, PromptMemoryEntry, Base, ScoreEntry
from pyrit.memory.memory_models import EmbeddingDataEntry, PromptMemoryEntry, Base, ScoreEntry
from pyrit.memory.memory_interface import MemoryInterface
from pyrit.common.path import RESULTS_PATH
from pyrit.common.singleton import Singleton
@@ -46,12 +46,15 @@ def __init__(
else:
self.db_path = Path(db_path or Path(RESULTS_PATH, self.DEFAULT_DB_FILE_NAME)).resolve()
self.results_path = str(RESULTS_PATH)
# Handles disk-based storage for DuckDB local memory.
self._storage_io = DiskStorageIO()

self.engine = self._create_engine(has_echo=verbose)
self.SessionFactory = sessionmaker(bind=self.engine)
self._create_tables_if_not_exist()

def _init_storage_io(self):
# Handles disk-based storage for DuckDB local memory.
self.storage_io = DiskStorageIO()

def _create_engine(self, *, has_echo: bool) -> Engine:
"""Creates the SQLAlchemy engine for DuckDB.

@@ -91,11 +94,11 @@ def get_all_prompt_pieces(self) -> list[PromptRequestPiece]:
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result

def get_all_embeddings(self) -> list[EmbeddingData]:
def get_all_embeddings(self) -> list[EmbeddingDataEntry]:
"""
Fetches all entries from the specified table and returns them as model instances.
"""
result: list[EmbeddingData] = self.query_entries(EmbeddingData)
result: list[EmbeddingDataEntry] = self.query_entries(EmbeddingDataEntry)
return result

def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> list[PromptRequestPiece]:
@@ -109,9 +112,10 @@ def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> li
list[PromptRequestPiece]: A list of PromptRequestPieces with the specified conversation ID.
"""
try:
result: list[PromptRequestPiece] = self.query_entries(
entries = self.query_entries(
PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id
) # type: ignore
)
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}")
@@ -128,10 +132,12 @@ def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[Prom
list[PromptRequestPiece]: A list of PromptRequestPiece with the specified conversation ID.
"""
try:
return self.query_entries(
entries = self.query_entries(
PromptMemoryEntry,
conditions=PromptMemoryEntry.id.in_(prompt_ids),
) # type: ignore
)
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(
f"Unexpected error: Failed to retrieve ConversationData with orchestrator {prompt_ids}. {e}"
@@ -156,9 +162,8 @@ def get_prompt_request_piece_by_memory_labels(
try:
conditions = [PromptMemoryEntry.labels.op("->>")(key) == value for key, value in memory_labels.items()]
query_condition = and_(*conditions)
result: list[PromptRequestPiece] = self.query_entries(
PromptMemoryEntry, conditions=query_condition
) # type: ignore
entries = self.query_entries(PromptMemoryEntry, conditions=query_condition)
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(
@@ -178,10 +183,11 @@ def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: str) -> list[Pr
list[PromptRequestPiece]: A list of PromptRequestPiece objects matching the specified orchestrator ID.
"""
try:
result: list[PromptRequestPiece] = self.query_entries(
entries = self.query_entries(
PromptMemoryEntry,
conditions=PromptMemoryEntry.orchestrator_identifier.op("->>")("id") == orchestrator_id,
) # type: ignore
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result
except Exception as e:
logger.exception(
@@ -196,7 +202,7 @@ def add_request_pieces_to_memory(self, *, request_pieces: Sequence[PromptRequest
"""
self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in request_pieces])

def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None:
def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingDataEntry]) -> None:
"""
Inserts embedding data into memory storage
"""
6 changes: 3 additions & 3 deletions pyrit/memory/memory_embedding.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
from typing import Optional
from pyrit.embedding import AzureTextEmbedding
from pyrit.models import PromptRequestPiece, EmbeddingSupport
from pyrit.memory.memory_models import EmbeddingData
from pyrit.memory.memory_models import EmbeddingDataEntry


class MemoryEmbedding:
@@ -21,7 +21,7 @@ def __init__(self, *, embedding_model: Optional[EmbeddingSupport]):
raise ValueError("embedding_model must be set.")
self.embedding_model = embedding_model

def generate_embedding_memory_data(self, *, prompt_request_piece: PromptRequestPiece) -> EmbeddingData:
def generate_embedding_memory_data(self, *, prompt_request_piece: PromptRequestPiece) -> EmbeddingDataEntry:
"""
Generates metadata for a chat memory entry.

@@ -32,7 +32,7 @@ def generate_embedding_memory_data(self, *, prompt_request_piece: PromptRequestP
ConversationMemoryEntryMetadata: The generated metadata.
"""
if prompt_request_piece.converted_value_data_type == "text":
embedding_data = EmbeddingData(
embedding_data = EmbeddingDataEntry(
embedding=self.embedding_model.generate_text_embedding(text=prompt_request_piece.converted_value)
.data[0]
.embedding,
15 changes: 11 additions & 4 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
group_conversation_request_pieces_by_sequence,
)

from pyrit.memory.memory_models import EmbeddingData
from pyrit.memory.memory_models import EmbeddingDataEntry
from pyrit.memory.memory_embedding import default_memory_embedding_factory, MemoryEmbedding
from pyrit.memory.memory_exporter import MemoryExporter
from pyrit.models.storage_io import StorageIO
@@ -33,13 +33,14 @@ class MemoryInterface(abc.ABC):
"""

memory_embedding: MemoryEmbedding = None
_storage_io: StorageIO = None
storage_io: StorageIO = None
results_path: str = None

def __init__(self, embedding_model=None):
self.memory_embedding = embedding_model
# Initialize the MemoryExporter instance
self.exporter = MemoryExporter()
self._init_storage_io()

def enable_embedding(self, embedding_model=None):
self.memory_embedding = default_memory_embedding_factory(embedding_model=embedding_model)
@@ -54,11 +55,17 @@ def get_all_prompt_pieces(self) -> Sequence[PromptRequestPiece]:
"""

@abc.abstractmethod
def get_all_embeddings(self) -> Sequence[EmbeddingData]:
def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]:
"""
Loads all EmbeddingData from the memory storage handler.
"""

@abc.abstractmethod
def _init_storage_io(self):
"""
Initialize the storage IO handler storage_io.
"""

@abc.abstractmethod
def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> MutableSequence[PromptRequestPiece]:
"""
@@ -91,7 +98,7 @@ def add_request_pieces_to_memory(self, *, request_pieces: Sequence[PromptRequest
"""

@abc.abstractmethod
def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None:
def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingDataEntry]) -> None:
"""
Inserts embedding data into memory storage
"""
2 changes: 1 addition & 1 deletion pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
@@ -126,7 +126,7 @@ def __str__(self):
return f": {self.role}: {self.converted_value}"


class EmbeddingData(Base): # type: ignore
class EmbeddingDataEntry(Base): # type: ignore
"""
Represents the embedding data associated with conversation entries in the database.
Each embedding is linked to a specific conversation entry via an id
Loading