diff --git a/python/pyproject.toml b/python/pyproject.toml index 24e23ec8628a..8056443d1d7c 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -138,6 +138,9 @@ realtime = [ "websockets >= 13, < 16", "aiortc>=1.9.0", ] +sql = [ + "pyodbc >= 5.2" +] [tool.uv] prerelease = "if-necessary-or-explicit" diff --git a/python/samples/concepts/memory/complex_memory.py b/python/samples/concepts/memory/complex_memory.py index 4816cda886dc..211213f88cfd 100644 --- a/python/samples/concepts/memory/complex_memory.py +++ b/python/samples/concepts/memory/complex_memory.py @@ -29,6 +29,7 @@ from semantic_kernel.connectors.memory.postgres import PostgresCollection from semantic_kernel.connectors.memory.qdrant import QdrantCollection from semantic_kernel.connectors.memory.redis import RedisHashsetCollection, RedisJsonCollection +from semantic_kernel.connectors.memory.sql_server import SqlServerCollection from semantic_kernel.connectors.memory.weaviate import WeaviateCollection from semantic_kernel.data import ( VectorizableTextSearchMixin, @@ -120,7 +121,7 @@ class DataModelList: # Depending on the vector database, the index kind and distance function may need to be adjusted # since not all combinations are supported by all databases. # The values below might need to be changed for your collection to work. -distance_function = DistanceFunction.EUCLIDEAN_SQUARED_DISTANCE +distance_function = DistanceFunction.COSINE_DISTANCE index_kind = IndexKind.FLAT DataModel = get_data_model("array", index_kind, distance_function) @@ -147,11 +148,14 @@ class DataModelList: # The chroma collection is currently only available for in-memory versions # Client-Server mode and Chroma Cloud are not yet supported. # More info on Chroma here: https://docs.trychroma.com/docs/overview/introduction +# - faiss: Faiss - in-memory with optimized indexes. +# - pinecone: Pinecone +# - sql_server: SQL Server, can connect to any SQL Server compatible database, like Azure SQL. # This is represented as a mapping from the collection name to a # function which returns the collection. # Using a function allows for lazy initialization of the collection, # so that settings for unused collections do not cause validation errors. -collections: dict[str, Callable[[], VectorStoreRecordCollection[str, DataModel]]] = { +collections: dict[str, Callable[[], VectorStoreRecordCollection]] = { "ai_search": lambda: AzureAISearchCollection[str, DataModel]( data_model_type=DataModel, ), @@ -204,6 +208,10 @@ class DataModelList: collection_name=collection_name, data_model_type=DataModel, ), + "sql_server": lambda: SqlServerCollection[str, DataModel]( + data_model_type=DataModel, + collection_name=collection_name, + ), } diff --git a/python/semantic_kernel/connectors/memory/sql_server.py b/python/semantic_kernel/connectors/memory/sql_server.py new file mode 100644 index 000000000000..684f230b8708 --- /dev/null +++ b/python/semantic_kernel/connectors/memory/sql_server.py @@ -0,0 +1,1037 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +import asyncio +import json +import logging +import re +import struct +import sys +from collections.abc import AsyncIterable, Sequence +from contextlib import contextmanager +from io import StringIO +from itertools import chain +from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, TypeVar + +from azure.identity.aio import DefaultAzureCredential +from pydantic import SecretStr, ValidationError, field_validator + +from semantic_kernel.data.const import DISTANCE_FUNCTION_DIRECTION_HELPER, DistanceFunction, IndexKind +from semantic_kernel.data.record_definition import ( + VectorStoreRecordDataField, + VectorStoreRecordDefinition, + VectorStoreRecordKeyField, + VectorStoreRecordVectorField, +) +from semantic_kernel.data.text_search import AnyTagsEqualTo, EqualTo, KernelSearchResults +from semantic_kernel.data.vector_search import ( + VectorizedSearchMixin, + VectorSearchFilter, + VectorSearchOptions, + VectorSearchResult, +) +from semantic_kernel.data.vector_storage import VectorStore, VectorStoreRecordCollection +from semantic_kernel.exceptions import VectorStoreOperationException +from semantic_kernel.exceptions.vector_store_exceptions import ( + VectorSearchExecutionException, + VectorStoreInitializationException, +) +from semantic_kernel.kernel_pydantic import KernelBaseSettings +from semantic_kernel.kernel_types import OneOrMany +from semantic_kernel.utils.feature_stage_decorator import experimental + +if sys.version_info >= (3, 12): + from typing import override # pragma: no cover +else: + from typing_extensions import override # pragma: no cover +if sys.version_info >= (3, 11): + from typing import Self # pragma: no cover +else: + from typing_extensions import Self # pragma: no cover + +if TYPE_CHECKING: + from pyodbc import Connection + + +logger = logging.getLogger(__name__) + +TKey = TypeVar("TKey", str, int) +TModel = TypeVar("TModel") + +# maximum number of parameters for SQL Server +# The actual limit is 2100, but we leave some space +SQL_PARAMETER_SAFETY_MAX_COUNT: Final[int] = 2000 +SQL_PARAMETER_MAX_COUNT: Final[int] = 2100 +SCORE_FIELD_NAME: Final[str] = "_vector_distance_value" +DISTANCE_FUNCTION_MAP = { + DistanceFunction.COSINE_DISTANCE: "cosine", + DistanceFunction.EUCLIDEAN_DISTANCE: "euclidean", + DistanceFunction.DOT_PROD: "dot", +} + +__all__ = ["SqlServerCollection", "SqlServerStore"] + +# region: Settings + + +@experimental +class SqlSettings(KernelBaseSettings): + """SQL settings. + + The settings are first loaded from environment variables with + the prefix 'SQL_SERVER_'. + If the environment variables are not found, the settings can + be loaded from a .env file with the encoding 'utf-8'. + If the settings are not found in the .env file, the settings + are ignored; however, validation will fail alerting that the + settings are missing. + + Required settings for prefix 'SQL_SERVER_': + - connection_string: str - The connection string of the SQL Server, including for Azure SQL. + For SQL Server: the connection string should include the server name, database name, user ID, and password. + For example: "Driver={ODBC Driver 18 for SQL Server};Server=server_name;Database=database_name;UID=user_id;PWD=password;" + For Azure SQL: This value can be found in the Keys & Endpoint section when examining + your resource from the Azure portal. 
+ The advice is to use a password-less setup, see + https://learn.microsoft.com/en-us/azure/azure-sql/database/azure-sql-passwordless-migration-python?view=azuresql&preserve-view=true&tabs=sign-in-azure-cli%2Cazure-portal-create%2Cazure-portal-assign%2Capp-service-identity#update-the-local-connection-configuration for more info. + (Env var name: SQL_SERVER_CONNECTION_STRING) + """ # noqa: E501 + + env_prefix: ClassVar[str] = "SQL_SERVER_" + + connection_string: SecretStr + + @field_validator("connection_string", mode="before") + @classmethod + def validate_connection_string(cls, value: str) -> str: + """Validate the connection string. + + The LongAsMax=yes option is added to the connection string if it is not present. + This is needed to supply vectors as query parameters. + + """ + if "LongAsMax=yes" not in value: + if value.endswith(";"): + value = value[:-1] + return f"{value};LongAsMax=yes;" + return value + + +# region: SQL Command and Query Builder + + +@experimental +class QueryBuilder: + """A class that helps you build strings for SQL queries.""" + + def __init__(self, initial_string: "QueryBuilder | str | None" = None): + """Initialize the StringBuilder with an empty StringIO object.""" + self._file_str = StringIO() + if initial_string: + self._file_str.write(str(initial_string)) + + def append(self, string: str, suffix: str | None = None): + """Append a string to the StringBuilder.""" + self._file_str.write(string) + if suffix: + self._file_str.write(suffix) + + def append_list(self, strings: Sequence[str], sep: str = ", ", suffix: str | None = None): + """Append a list of strings to the StringBuilder. + + Optionally set the separator (default: `, `) and a suffix (default is None). + """ + if not strings: + return + for string in strings[:-1]: + self.append(string, suffix=sep) + self.append(strings[-1], suffix=suffix) + + def append_table_name( + self, schema: str, table_name: str, prefix: str = "", suffix: str | None = None, newline: bool = False + ) -> None: + """Append a table name to the StringBuilder with schema. + + This includes square brackets around the schema and table name. + And spaces around the table name. + + Args: + schema: The schema name. + table_name: The table name. + prefix: Optional prefix to add before the table name. + suffix: Optional suffix to add after the table name. + newline: Whether to add a newline after the table name or suffix. + """ + self.append(f"{prefix} [{schema}].[{table_name}] {suffix or ''}", suffix="\n" if newline else "") + + def remove_last(self, number_of_chars: int): + """Remove the last number_of_chars from the StringBuilder.""" + current_pos = self._file_str.tell() + if current_pos >= number_of_chars: + self._file_str.seek(current_pos - number_of_chars) + self._file_str.truncate() + + @contextmanager + def in_parenthesis(self, prefix: str | None = None, suffix: str | None = None): + """Context manager to add parentheses around a block of strings. + + Args: + prefix: Optional prefix to add before the opening parenthesis. + suffix: Optional suffix to add after the closing parenthesis. 
+ + """ + self.append(f"{prefix or ''} (") + yield + self.append(f") {suffix or ''}") + + @contextmanager + def in_logical_group(self): + """Create a logical group with BEGIN and END.""" + self.append("BEGIN", suffix="\n") + yield + self.append("\nEND", suffix="\n") + + def __str__(self): + """Return the string representation of the StringBuilder.""" + return self._file_str.getvalue() + + +@experimental +class SqlCommand: + """A class that represents a SQL command with parameters.""" + + def __init__( + self, + query: QueryBuilder | str | None = None, + ): + """Initialize the SqlCommand. + + This only allows for creation of the query string, use the add_parameter + and add_parameters methods to add parameters to the command. + + Args: + query: The SQL command string or QueryBuilder object. + + """ + self.query = QueryBuilder(query) + self.parameters: list[str] = [] + + def add_parameter(self, value: str) -> None: + """Add a parameter to the SqlCommand.""" + if (len(self.parameters) + 1) > SQL_PARAMETER_MAX_COUNT: + raise VectorStoreOperationException("The maximum number of parameters is 2100.") + self.parameters.append(value) + + def add_parameters(self, values: Sequence[str] | tuple[str, ...]) -> None: + """Add multiple parameters to the SqlCommand.""" + if (len(self.parameters) + len(values)) > SQL_PARAMETER_MAX_COUNT: + raise VectorStoreOperationException(f"The maximum number of parameters is {SQL_PARAMETER_MAX_COUNT}.") + self.parameters.extend(values) + + def __str__(self): + """Return the string representation of the SqlCommand.""" + if self.parameters: + logger.debug("This command has parameters.") + return str(self.query) + + def to_execute(self) -> tuple[str, tuple[str, ...]]: + """Return the command and parameters for execute.""" + return str(self.query), tuple(self.parameters) + + +async def _get_mssql_connection(settings: SqlSettings) -> "Connection": + """Get a connection to the SQL Server database, optionally with Entra Auth.""" + from pyodbc import connect + + mssql_connection_string = settings.connection_string.get_secret_value() + if any(s in mssql_connection_string.lower() for s in ["uid"]): + attrs_before: dict | None = None + else: + async with DefaultAzureCredential(exclude_interactive_browser_credential=False) as credential: + # Get the access token + token_bytes = (await credential.get_token("https://database.windows.net/.default")).token.encode( + "UTF-16-LE" + ) + token_struct = struct.pack(f" Self: + # If the connection pool was not provided, create a new one. + if not self.connection: + if not self.settings: # pragma: no cover + # this should never happen, but just in case + raise VectorStoreInitializationException("No connection or settings provided.") + self.connection = await _get_mssql_connection(self.settings) + self.connection.__enter__() + return self + + @override + async def __aexit__(self, *args): + # Only close the connection if it was created by the collection. + if self.managed_client and self.connection: + self.connection.close() + self.connection = None + + @override + async def _inner_upsert( + self, + records: Sequence[dict[str, Any]], + **kwargs: Any, + ) -> Sequence[TKey]: + """Upsert records into the database. + + Args: + records: The records, the format is specific to the store. + **kwargs: Additional arguments, to be passed to the store. + + Returns: + The keys of the upserted records. 
+ """ + if self.connection is None: + raise VectorStoreOperationException("connection is not available, use the collection as a context manager.") + if not records: + return [] + data_fields = [ + field + for field in self.data_model_definition.fields.values() + if isinstance(field, VectorStoreRecordDataField) + ] + vector_fields = self.data_model_definition.vector_fields + schema, table = self._get_schema_and_table() + # Check how many parameters are likely to be passed + # to the command, if it exceeds the maximum, split the records + # into smaller chunks + max_records = SQL_PARAMETER_SAFETY_MAX_COUNT // len(self.data_model_definition.fields) + batches = [] + for i in range(0, len(records), max_records): + batches.append(records[i : i + max_records]) + keys = [] + for batch in batches: + command = _build_merge_query( + schema, table, self.data_model_definition.key_field, data_fields, vector_fields, batch + ) + with self.connection.cursor() as cur: + cur.execute(*command.to_execute()) + while cur.nextset(): + keys.extend([row[0] for row in cur.fetchall()]) + if not keys: + raise VectorStoreOperationException("No keys were returned from the merge query.") + return keys + + @override + async def _inner_get(self, keys: Sequence[TKey], **kwargs: Any) -> OneOrMany[dict[str, Any]] | None: + """Get records from the database. + + Args: + keys: The keys to get. + **kwargs: Additional arguments. + + Returns: + The records from the store, not deserialized. + """ + if not keys: + return None + query = _build_select_query( + *self._get_schema_and_table(), + self.data_model_definition.key_field, + [ + field + for field in self.data_model_definition.fields.values() + if isinstance(field, VectorStoreRecordDataField) + ], + self.data_model_definition.vector_fields if kwargs.get("include_vectors", True) else None, + keys, + ) + records = [record async for record in self._fetch_records(query)] + return records if records else None + + @override + async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None: + """Delete the records with the given keys. + + Args: + keys: The keys. + **kwargs: Additional arguments. + """ + if self.connection is None: + raise VectorStoreOperationException("connection is not available, use the collection as a context manager.") + + if not keys: + return + query = _build_delete_query( + *self._get_schema_and_table(), + self.data_model_definition.key_field, + keys, + ) + with self.connection.cursor() as cur: + cur.execute(*query.to_execute()) + + @override + def _serialize_dicts_to_store_models(self, records: Sequence[dict[str, Any]], **kwargs: Any) -> Sequence[Any]: + """Serialize a list of dicts of the data to the store model. + + Pass the records through without modification. + """ + return records + + @override + def _deserialize_store_models_to_dicts(self, records: Sequence[Any], **kwargs: Any) -> Sequence[dict[str, Any]]: + """Deserialize the store models to a list of dicts. + + Pass the records through without modification. + """ + return records + + @override + async def create_collection( + self, *, create_if_not_exists: bool = True, queries: list[str] | None = None, **kwargs: Any + ) -> None: + """Create a SQL table based on the data model. + + Alternatively, you can pass a list of queries to execute. + If supplied, only the queries will be executed. + + Args: + create_if_not_exists: Whether to create the table if it does not exist, default is True. + This means, that by default the table will only be created if it does not exist. 
+ So if there is a existing table with the same name, it will not be created or modified. + queries: A list of SQL queries to execute. + **kwargs: Additional arguments. + + """ + if self.connection is None: + raise VectorStoreOperationException("Connection is not available, use the collection as a context manager.") + + if queries: + with self.connection.cursor() as cursor: + for query in queries: + cursor.execute(query) + return + + data_fields = [ + field + for field in self.data_model_definition.fields.values() + if isinstance(field, VectorStoreRecordDataField) + ] + create_table_query = _build_create_table_query( + *self._get_schema_and_table(), + key_field=self.data_model_definition.key_field, + data_fields=data_fields, + vector_fields=self.data_model_definition.vector_fields, + if_not_exists=create_if_not_exists, + ) + with self.connection.cursor() as cursor: + cursor.execute(*create_table_query.to_execute()) + logger.info(f"SqlServer table '{self.collection_name}' created successfully.") + + def _get_schema_and_table(self) -> tuple[str, str]: + """Get the schema and table name from the collection name.""" + if "." in self.collection_name: + schema, table = self.collection_name.split(".", maxsplit=1) + else: + schema = "dbo" + table = self.collection_name + return schema, table + + @override + async def does_collection_exist(self, **kwargs: Any) -> bool: + """Check if the collection exists.""" + if self.connection is None: + raise VectorStoreOperationException("connection is not available, use the collection as a context manager.") + + with self.connection.cursor() as cursor: + cursor.execute(*_build_select_table_name_query(*self._get_schema_and_table()).to_execute()) + row = cursor.fetchone() + return bool(row) + + @override + async def delete_collection(self, **kwargs: Any) -> None: + """Delete the collection.""" + if self.connection is None: + raise VectorStoreOperationException("connection is not available, use the collection as a context manager.") + + with self.connection.cursor() as cur: + cur.execute(*_build_delete_table_query(*self._get_schema_and_table()).to_execute()) + logger.debug(f"SqlServer table '{self.collection_name}' deleted successfully.") + + @override + async def _inner_search( + self, + options: VectorSearchOptions, + search_text: str | None = None, + vectorizable_text: str | None = None, + vector: list[float | int] | None = None, + **kwargs: Any, + ) -> KernelSearchResults[VectorSearchResult[TModel]]: + if vector is not None: + query = _build_search_query( + *self._get_schema_and_table(), + self.data_model_definition.key_field, + [ + field + for field in self.data_model_definition.fields.values() + if isinstance(field, VectorStoreRecordDataField) + ], + self.data_model_definition.vector_fields, + vector, + options, + ) + elif search_text: + raise VectorSearchExecutionException("Text search not supported.") + elif vectorizable_text: + raise VectorSearchExecutionException("Vectorizable text search not supported.") + + return KernelSearchResults( + results=self._get_vector_search_results_from_results(self._fetch_records(query), options), + total_count=None, + ) + + async def _fetch_records(self, query: SqlCommand) -> AsyncIterable[dict[str, Any]]: + if self.connection is None: + raise VectorStoreOperationException("connection is not available, use the collection as a context manager.") + with self.connection.cursor() as cur: + cur.execute(*query.to_execute()) + col_names = [desc[0] for desc in cur.description] + for row in cur: + record = { + col: ( + 
json.loads(row.__getattribute__(col)) + if col in self.data_model_definition.vector_field_names + else row.__getattribute__(col) + ) + for col in col_names + } + yield record + await asyncio.sleep(0) + + @override + def _get_record_from_result(self, result: dict[str, Any]) -> dict[str, Any]: + return result + + @override + def _get_score_from_result(self, result: Any) -> float | None: + return result.pop(SCORE_FIELD_NAME, None) + + +# region: SQL Server Store + + +@experimental +class SqlServerStore(VectorStore): + """SQL Store implementation. + + This class is used to store and retrieve data from an SQL database. + It uses the SqlServerCollection class to perform the actual operations. + """ + + connection: Any | None = None + settings: SqlSettings | None = None + + def __init__( + self, + connection_string: str | None = None, + connection: "Connection | None" = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + **kwargs: Any, + ): + """Initialize the SQL Store. + + Args: + connection_string: The connection string to the database. + connection: The connection, make sure to set the `LongAsMax=yes` option on the construction string used. + env_file_path: Use the environment settings file as a fallback to environment variables. + env_file_encoding: The encoding of the environment settings file. + **kwargs: Additional arguments. + """ + managed_client = not connection + settings = None + if not connection: + try: + settings = SqlSettings.create( + connection_string=connection_string, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) + except ValidationError as e: + raise VectorStoreInitializationException( + "Invalid settings provided. Please check the connection string." + ) from e + + super().__init__(settings=settings, connection=connection, managed_client=managed_client, **kwargs) + + @override + async def __aenter__(self) -> Self: + # If the connection was not provided, create a new one. + if not self.connection: + if not self.settings: # pragma: no cover + # this should never happen, but just in case + raise VectorStoreInitializationException("Settings must be provided to establish a connection.") + self.connection = await _get_mssql_connection(self.settings) + self.connection.__enter__() + return self + + @override + async def __aexit__(self, *args): + # Only close the connection if it was created by the store. + if self.managed_client and self.connection: + self.connection.close() + self.connection = None + + @override + async def list_collection_names(self, **kwargs) -> Sequence[str]: + """List the collection names in the database. + + Args: + **kwargs: Additional arguments. + + Returns: + A list of collection names. 
+ """ + if self.connection is None: + raise VectorStoreOperationException("connection is not available, use the store as a context manager.") + with self.connection.cursor() as cur: + cur.execute(*_build_select_table_names_query(schema=kwargs.get("schema")).to_execute()) + rows = cur.fetchall() + return [row[0] for row in rows] + + @override + def get_collection( + self, + collection_name: str, + data_model_type: type[object], + data_model_definition: VectorStoreRecordDefinition | None = None, + **kwargs: Any, + ) -> "VectorStoreRecordCollection": + self.vector_record_collections[collection_name] = SqlServerCollection( + collection_name=collection_name, + data_model_type=data_model_type, + data_model_definition=data_model_definition, + connection=self.connection, + settings=self.settings, + **kwargs, + ) + return self.vector_record_collections[collection_name] + + +# region: Query Build Functions + + +def _python_type_to_sql(python_type_str: str | None, is_key: bool = False) -> str | None: + """Convert a string representation of a Python type to a SQL data type. + + Args: + python_type_str: The string representation of the Python type (e.g., "int", "List[str]"). + is_key: Whether the type is a key field. + + Returns: + Corresponding SQL data type as a string, if found. If the type is not found, return None. + """ + if python_type_str is None: + raise VectorStoreOperationException("property type cannot be None") + # Basic type mapping from Python types (in string form) to SQL types + type_mapping = { + "str": "nvarchar(max)" if not is_key else "nvarchar(255)", + "int": "int", + "float": "float", + "bool": "bit", + "dict": "json", + "datetime": "datetime2", + "bytes": "binary", + } + + # Regular expression to detect lists, e.g., "List[str]" or "List[int]" + list_pattern = re.compile(r"(?i)List\[(.*)\]") + + # Check if the type is a list + match = list_pattern.match(python_type_str) + if match: + # Extract the inner type of the list and convert it to a SQL array type + element_type_str = match.group(1) + sql_element_type = _python_type_to_sql(element_type_str) + return f"{sql_element_type}[]" + + # Handle basic types + if python_type_str in type_mapping: + return type_mapping[python_type_str] + + return None + + +def _cast_value(value: Any) -> str: + """Add a cast check to the value.""" + if value is None: + return "NULL" + match value: + case str(): + return value + case bool(): + return "1" if value else "0" + case int() | float(): + return f"{value!s}" + case list() | dict(): + return f"{json.dumps(value)}" + case bytes(): + return f"CONVERT(VARBINARY(MAX), '{value.hex()}')" + case _: + raise VectorStoreOperationException(f"Unsupported type: {type(value)}") + + +def _add_cast_check(placeholder: str, value: Any) -> str: + """Add a cast check to the value.""" + if isinstance(value, bytes): + return f"CONVERT(VARBINARY(MAX), {placeholder})" + return placeholder + + +def _build_create_table_query( + schema: str, + table: str, + key_field: VectorStoreRecordKeyField, + data_fields: list[VectorStoreRecordDataField], + vector_fields: list[VectorStoreRecordVectorField], + if_not_exists: bool = False, +) -> SqlCommand: + """Build the CREATE TABLE query based on the data model.""" + command = SqlCommand() + if if_not_exists: + command.query.append_table_name( + schema, table, prefix="IF OBJECT_ID(N'", suffix="', N'U') IS NULL", newline=True + ) + with command.query.in_logical_group(): + command.query.append_table_name(schema, table, prefix="CREATE TABLE", newline=True) + with 
command.query.in_parenthesis(suffix=";"): + # add the key field + command.query.append( + f'"{key_field.name}" {_python_type_to_sql(key_field.property_type, is_key=True)} NOT NULL,\n' + ) + # add the data fields + [ + command.query.append(f'"{field.name}" {_python_type_to_sql(field.property_type)} NULL,\n') + for field in data_fields + ] + # add the vector fields + for field in vector_fields: + if field.dimensions is None: + raise VectorStoreOperationException(f"Vector dimensions are not defined for field '{field.name}'") + if field.index_kind is not None and field.index_kind != IndexKind.FLAT: + # Only FLAT index kind is supported + # None is also accepted, which means no explicit index kind + # is set, so implicit default is used + raise VectorStoreOperationException( + f"Index kind '{field.index_kind}' is not supported for field '{field.name}'" + ) + command.query.append(f'"{field.name}" VECTOR({field.dimensions}) NULL,\n') + # set the primary key + with command.query.in_parenthesis("PRIMARY KEY", "\n"): + command.query.append(key_field.name) + return command + + +def _build_delete_table_query( + schema: str, + table: str, +) -> SqlCommand: + """Build the DELETE TABLE query based on the data model.""" + command = SqlCommand("DROP TABLE IF EXISTS") + command.query.append_table_name(schema, table, suffix=";") + return command + + +def _build_select_table_names_query( + schema: str | None = None, +) -> SqlCommand: + """Build the SELECT TABLE NAMES query based on the data model.""" + command = SqlCommand() + if schema: + command.query.append( + "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " + "WHERE TABLE_TYPE = 'BASE TABLE' " + "AND (@schema is NULL or TABLE_SCHEMA = ?);" + ) + command.add_parameter(schema) + else: + command.query.append("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE';") + return command + + +def _build_select_table_name_query( + schema: str, + table: str, +) -> SqlCommand: + """Build the SELECT TABLE NAMES query based on the data model.""" + command = SqlCommand( + "SELECT TABLE_NAME" + " FROM INFORMATION_SCHEMA.TABLES" + " WHERE TABLE_TYPE = 'BASE TABLE'" + " AND (@schema is NULL or TABLE_SCHEMA = ?)" + " AND TABLE_NAME = ?" + ) + command.add_parameter(schema) + command.add_parameter(table) + return command + + +def _add_field_names( + command: SqlCommand, + key_field: VectorStoreRecordKeyField, + data_fields: list[VectorStoreRecordDataField], + vector_fields: list[VectorStoreRecordVectorField] | None, + table_identifier: str | None = None, +) -> None: + """Add the field names to the query builder. + + Args: + command: The SqlCommand object to add the field names to. + key_field: The key field. + data_fields: The data fields. + vector_fields: The vector fields. + table_identifier: The table identifier to prefix the field names with, if not given, + the field name is used as is. + If passed, then it is used with a dot separating the table name and field name. 
+ + """ + fields = chain([key_field], data_fields, vector_fields or []) + if table_identifier: + strings = [f"{table_identifier}.{field.name}" for field in fields] + else: + strings = [field.name for field in fields] + command.query.append_list(strings) + + +def _build_merge_query( + schema: str, + table: str, + key_field: VectorStoreRecordKeyField, + data_fields: list[VectorStoreRecordDataField], + vector_fields: list[VectorStoreRecordVectorField], + records: Sequence[dict[str, Any]], +) -> SqlCommand: + """Build the MERGE TABLE query based on the data model.""" + command = SqlCommand() + # Declare a temp table to store the keys that are updated + command.query.append( + "DECLARE @UpsertedKeys TABLE (KeyColumn " + f"{_python_type_to_sql(key_field.property_type or 'str', is_key=True)});\n" + ) + # start the MERGE statement + command.query.append_table_name(schema, table, prefix="MERGE INTO", suffix="AS t", newline=True) + # add the USING VALUES clause + with command.query.in_parenthesis(prefix="USING"): + command.query.append(" VALUES ") + for record in records: + with command.query.in_parenthesis(suffix=",\n"): + query_list = [] + param_list = [] + for field in chain([key_field], data_fields, vector_fields): + value = record.get(field.name) + # add the field name to the query list + query_list.append(_add_cast_check("?", value)) + # add the field value to the parameter list + param_list.append(_cast_value(value)) + command.query.append_list(query_list) + command.add_parameters(param_list) + command.query.remove_last(2) # remove the last comma and newline + # with the table column names + with command.query.in_parenthesis("AS s", " "): + _add_field_names(command, key_field, data_fields, vector_fields) + # add the ON clause + with command.query.in_parenthesis("ON", "\n"): + command.query.append(f"t.{key_field.name} = s.{key_field.name}") + # Set the Matched clause + command.query.append("WHEN MATCHED THEN\n") + command.query.append("UPDATE SET ") + command.query.append_list( + [f"t.{field.name} = s.{field.name}" for field in chain(data_fields, vector_fields)], suffix="\n" + ) + # Set the Not Matched clause + command.query.append("WHEN NOT MATCHED THEN\n") + with command.query.in_parenthesis("INSERT", " "): + _add_field_names(command, key_field, data_fields, vector_fields) + # add the closing parenthesis + with command.query.in_parenthesis("VALUES", " \n"): + _add_field_names(command, key_field, data_fields, vector_fields, table_identifier="s") + # add the closing parenthesis + command.query.append(f"OUTPUT inserted.{key_field.name} INTO @UpsertedKeys (KeyColumn);\n") + command.query.append("SELECT KeyColumn FROM @UpsertedKeys;\n") + return command + + +def _build_select_query( + schema: str, + table: str, + key_field: VectorStoreRecordKeyField, + data_fields: list[VectorStoreRecordDataField], + vector_fields: list[VectorStoreRecordVectorField] | None, + keys: Sequence[Any], +) -> SqlCommand: + """Build the SELECT query based on the data model.""" + command = SqlCommand() + # start the SELECT statement + command.query.append("SELECT\n") + # add the data and vector fields + _add_field_names(command, key_field, data_fields, vector_fields) + # add the FROM clause + command.query.append_table_name(schema, table, prefix=" FROM", newline=True) + # add the WHERE clause + if keys: + command.query.append(f"WHERE {key_field.name} IN\n") + with command.query.in_parenthesis(): + # add the keys + command.query.append_list(["?"] * len(keys)) + command.add_parameters([_cast_value(key) for key in keys]) + 
command.query.append(";") + return command + + +def _build_delete_query( + schema: str, + table: str, + key_field: VectorStoreRecordKeyField, + keys: Sequence[Any], +) -> SqlCommand: + """Build the DELETE query based on the data model.""" + command = SqlCommand("DELETE FROM") + # start the DELETE statement + command.query.append_table_name(schema, table) + # add the WHERE clause + command.query.append(f"WHERE [{key_field.name}] IN") + with command.query.in_parenthesis(): + # add the keys + command.query.append_list(["?"] * len(keys)) + command.add_parameters([_cast_value(key) for key in keys]) + command.query.append(";") + return command + + +def _build_filter(command: SqlCommand, filters: VectorSearchFilter): + """Build the filter query based on the data model.""" + if not filters.filters: + return + command.query.append("WHERE ") + for filter in filters.filters: + match filter: + case EqualTo(): + command.query.append(f"[{filter.field_name}] = ? AND\n") + command.add_parameter(_cast_value(filter.value)) + case AnyTagsEqualTo(): + command.query.append(f"? IN [{filter.field_name}] AND\n") + command.add_parameter(_cast_value(filter.value)) + # remove the last AND + command.query.remove_last(4) + command.query.append("\n") + + +def _build_search_query( + schema: str, + table: str, + key_field: VectorStoreRecordKeyField, + data_fields: list[VectorStoreRecordDataField], + vector_fields: list[VectorStoreRecordVectorField], + vector: list[float], + options: VectorSearchOptions, +) -> SqlCommand: + """Build the SELECT query based on the data model.""" + # start the SELECT statement + command = SqlCommand("SELECT ") + # add the data and vector fields + _add_field_names(command, key_field, data_fields, vector_fields if options.include_vectors else None) + # add the vector search clause + vector_field: VectorStoreRecordVectorField | None = None + if options.vector_field_name: + vector_field = next( + (field for field in vector_fields if field.name == options.vector_field_name), + None, + ) + elif len(vector_fields) == 1: + vector_field = vector_fields[0] + if not vector_field: + raise VectorStoreOperationException("Vector field not specified.") + + asc: bool = True + if vector_field.distance_function: + distance_function = DISTANCE_FUNCTION_MAP.get(vector_field.distance_function) + if not distance_function: + raise VectorStoreOperationException(f"Distance function '{vector_field.distance_function}' not supported.") + asc = DISTANCE_FUNCTION_DIRECTION_HELPER[vector_field.distance_function](0, 1) + else: + distance_function = "cosine" + + command.query.append( + f", VECTOR_DISTANCE('{distance_function}', {vector_field.name}, CAST(? 
AS VECTOR({vector_field.dimensions}))) as {SCORE_FIELD_NAME}\n", # noqa: E501 + ) + command.add_parameter(_cast_value(vector)) + # add the FROM clause + command.query.append_table_name(schema, table, prefix=" FROM", newline=True) + # add the WHERE clause + _build_filter(command, options.filter) + # add the ORDER BY clause + command.query.append(f"ORDER BY {SCORE_FIELD_NAME} {'ASC' if asc else 'DESC'}\n") + command.query.append(f"OFFSET {options.skip} ROWS FETCH NEXT {options.top} ROWS ONLY;") + return command diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 7a5c31e02914..f88a3aa753b9 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -321,7 +321,7 @@ class MyDataModel: property_type=vector_property_type, ), ] = None - id: Annotated[str, VectorStoreRecordKeyField()] = field(default_factory=lambda: str(uuid4())) + id: Annotated[str, VectorStoreRecordKeyField(property_type="str")] = field(default_factory=lambda: str(uuid4())) content: Annotated[ str, VectorStoreRecordDataField(has_embedding=True, embedding_property_name="vector", property_type="str") ] = "content1" @@ -362,10 +362,11 @@ def data_model_definition( ) -> VectorStoreRecordDefinition: return VectorStoreRecordDefinition( fields={ - "id": VectorStoreRecordKeyField(), + "id": VectorStoreRecordKeyField(property_type="str"), "content": VectorStoreRecordDataField( has_embedding=True, embedding_property_name="vector", + property_type="str", ), "vector": VectorStoreRecordVectorField( dimensions=dimensions, diff --git a/python/tests/unit/connectors/memory/conftest.py b/python/tests/unit/connectors/memory/conftest.py index 95a9d6e7d42f..325723a3939d 100644 --- a/python/tests/unit/connectors/memory/conftest.py +++ b/python/tests/unit/connectors/memory/conftest.py @@ -111,3 +111,27 @@ def pinecone_unit_test_env(monkeypatch, exclude_list, override_env_param_dict): monkeypatch.delenv(key, raising=False) return env_vars + + +@fixture +def sql_server_unit_test_env(monkeypatch, exclude_list, override_env_param_dict): + """Fixture to set environment variables for SQL Server.""" + if exclude_list is None: + exclude_list = [] + + if override_env_param_dict is None: + override_env_param_dict = {} + + env_vars = { + "SQL_SERVER_CONNECTION_STRING": "Driver={ODBC Driver 18 for SQL Server};Server=localhost;Database=testdb;User Id=testuser;Password=example;" # noqa: E501 + } + + env_vars.update(override_env_param_dict) + + for key, value in env_vars.items(): + if key not in exclude_list: + monkeypatch.setenv(key, value) + else: + monkeypatch.delenv(key, raising=False) + + return env_vars diff --git a/python/tests/unit/connectors/memory/sql_server/test_sql_server.py b/python/tests/unit/connectors/memory/sql_server/test_sql_server.py new file mode 100644 index 000000000000..fae623f95c42 --- /dev/null +++ b/python/tests/unit/connectors/memory/sql_server/test_sql_server.py @@ -0,0 +1,552 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +import json +import sys +from dataclasses import dataclass +from typing import NamedTuple +from unittest.mock import AsyncMock, MagicMock, NonCallableMagicMock, patch + +from pytest import fixture, mark, param, raises + +from semantic_kernel.connectors.memory.sql_server import ( + QueryBuilder, + SqlCommand, + SqlServerCollection, + SqlServerStore, + _build_create_table_query, + _build_delete_query, + _build_delete_table_query, + _build_merge_query, + _build_search_query, + _build_select_query, + _build_select_table_names_query, +) +from semantic_kernel.data.const import DistanceFunction, IndexKind +from semantic_kernel.data.record_definition import ( + VectorStoreRecordDataField, + VectorStoreRecordKeyField, + VectorStoreRecordVectorField, +) +from semantic_kernel.data.vector_search import VectorSearchFilter, VectorSearchOptions +from semantic_kernel.exceptions.vector_store_exceptions import ( + VectorStoreInitializationException, + VectorStoreOperationException, +) + + +class TestQueryBuilder: + def test_query_builder_append(self): + qb = QueryBuilder() + qb.append("SELECT * FROM") + qb.append(" table", suffix=";") + result = str(qb).strip() + assert result == "SELECT * FROM table;" + + def test_query_builder_append_list(self): + qb = QueryBuilder() + qb.append_list(["id", "name", "age"], sep=", ", suffix=";") + result = str(qb).strip() + assert result == "id, name, age;" + + def test_query_builder_append_table_name(self): + qb = QueryBuilder() + qb.append_table_name("dbo", "Users", prefix="SELECT * FROM", suffix=";", newline=False) + result = str(qb).strip() + assert result == "SELECT * FROM [dbo].[Users] ;" + + def test_query_builder_remove_last(self): + qb = QueryBuilder("SELECT * FROM table;") + qb.remove_last(1) # remove trailing semicolon + result = str(qb).strip() + assert result == "SELECT * FROM table" + + def test_query_builder_in_parenthesis(self): + qb = QueryBuilder("INSERT INTO table") + with qb.in_parenthesis(): + qb.append("id, name, age") + result = str(qb).strip() + assert result == "INSERT INTO table (id, name, age)" + + def test_query_builder_in_parenthesis_with_prefix_suffix(self): + qb = QueryBuilder() + with qb.in_parenthesis(prefix="VALUES", suffix=";"): + qb.append_list(["1", "'John'", "30"]) + result = str(qb).strip() + assert result == "VALUES (1, 'John', 30) ;" + + def test_query_builder_in_logical_group(self): + qb = QueryBuilder() + with qb.in_logical_group(): + qb.append("UPDATE Users SET name = 'John'") + result = str(qb).strip() + lines = result.splitlines() + assert lines[0] == "BEGIN" + assert lines[1] == "UPDATE Users SET name = 'John'" + assert lines[2] == "END" + + +class TestSqlCommand: + def test_sql_command_initial_query(self): + cmd = SqlCommand("SELECT 1") + assert str(cmd.query) == "SELECT 1" + + def test_sql_command_add_parameter(self): + cmd = SqlCommand("SELECT * FROM Test WHERE id = ?") + cmd.add_parameter("42") + assert cmd.parameters[0] == "42" + + def test_sql_command_add_parameters(self): + cmd = SqlCommand("SELECT * FROM Test WHERE id = ?") + cmd.add_parameters(["42", "43"]) + assert cmd.parameters[0] == "42" + assert cmd.parameters[1] == "43" + + def test_parameter_limit(self): + cmd = SqlCommand() + cmd.add_parameters(["42"] * 2100) + with raises(VectorStoreOperationException): + cmd.add_parameter("43") + with raises(VectorStoreOperationException): + cmd.add_parameters(["43", "44"]) + + +class TestQueryBuildFunctions: + def test_build_create_table_query(self): + schema = "dbo" + table = "Test" + key_field = 
VectorStoreRecordKeyField(name="id", property_type="str") + data_fields = [ + VectorStoreRecordDataField(name="name", property_type="str"), + VectorStoreRecordDataField(name="age", property_type="int"), + ] + vector_fields = [ + VectorStoreRecordVectorField(name="embedding", property_type="float", dimensions=1536), + ] + cmd = _build_create_table_query(schema, table, key_field, data_fields, vector_fields) + assert not cmd.parameters + cmd_str = str(cmd.query) + assert ( + cmd_str + == 'BEGIN\nCREATE TABLE [dbo].[Test] \n ("id" nvarchar(255) NOT NULL,\n"name" nvarchar(max) NULL,\n"age" ' + 'int NULL,\n"embedding" VECTOR(1536) NULL,\nPRIMARY KEY (id) \n) ;\nEND\n' + ) + + def test_delete_table_query(self): + schema = "dbo" + table = "Test" + cmd = _build_delete_table_query(schema, table) + assert str(cmd.query) == f"DROP TABLE IF EXISTS [{schema}].[{table}] ;" + + @mark.parametrize("schema", ["dbo", None]) + def test_build_select_table_names_query(self, schema): + cmd = _build_select_table_names_query(schema) + if schema: + assert cmd.parameters == [schema] + assert str(cmd) == ( + "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " + "WHERE TABLE_TYPE = 'BASE TABLE' " + "AND (@schema is NULL or TABLE_SCHEMA = ?);" + ) + else: + assert str(cmd) == "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE';" + + def test_build_merge_query(self): + schema = "dbo" + table = "Test" + key_field = VectorStoreRecordKeyField(name="id", property_type="str") + data_fields = [ + VectorStoreRecordDataField(name="name", property_type="str"), + VectorStoreRecordDataField(name="age", property_type="int"), + ] + vector_fields = [ + VectorStoreRecordVectorField(name="embedding", property_type="float", dimensions=5), + ] + records = [ + { + "id": "test", + "name": "name", + "age": 50, + "embedding": [0.1, 0.2, 0.3, 0.4, 0.5], + } + ] + cmd = _build_merge_query(schema, table, key_field, data_fields, vector_fields, records) + assert cmd.parameters[0] == records[0]["id"] + assert cmd.parameters[1] == records[0]["name"] + assert cmd.parameters[2] == str(records[0]["age"]) + assert cmd.parameters[3] == json.dumps(records[0]["embedding"]) + str_cmd = str(cmd) + assert str_cmd == ( + "DECLARE @UpsertedKeys TABLE (KeyColumn nvarchar(255));\nMERGE INTO [dbo].[Test] AS t\nUSING ( " + "VALUES (?, ?, ?, ?) ) AS s (id, name, age, embedding) ON (t.id = s.id) \nWHEN MATCHED THEN\nUPDATE " + "SET t.name = s.name, t.age = s.age, t.embedding = s.embedding\nWHEN NOT MATCHED THEN\nINSERT " + "(id, name, age, embedding) VALUES (s.id, s.name, s.age, s.embedding) \nOUTPUT inserted.id " + "INTO @UpsertedKeys (KeyColumn);\nSELECT KeyColumn FROM @UpsertedKeys;\n" + ) + + def test_build_select_query(self): + schema = "dbo" + table = "Test" + key_field = VectorStoreRecordKeyField(name="id", property_type="str") + data_fields = [ + VectorStoreRecordDataField(name="name", property_type="str"), + VectorStoreRecordDataField(name="age", property_type="int"), + ] + vector_fields = [ + VectorStoreRecordVectorField(name="embedding", property_type="float", dimensions=5), + ] + keys = ["test"] + cmd = _build_select_query(schema, table, key_field, data_fields, vector_fields, keys) + assert cmd.parameters == ["test"] + str_cmd = str(cmd) + assert str_cmd == "SELECT\nid, name, age, embedding FROM [dbo].[Test] \nWHERE id IN\n (?) 
;" + + def test_build_delete_query(self): + schema = "dbo" + table = "Test" + key_field = VectorStoreRecordKeyField(name="id", property_type="str") + keys = ["test"] + cmd = _build_delete_query(schema, table, key_field, keys) + str_cmd = str(cmd) + assert cmd.parameters[0] == "test" + assert str_cmd == "DELETE FROM [dbo].[Test] WHERE [id] IN (?) ;" + + def test_build_search_query(self): + schema = "dbo" + table = "Test" + key_field = VectorStoreRecordKeyField(name="id", property_type="str") + data_fields = [ + VectorStoreRecordDataField(name="name", property_type="str"), + VectorStoreRecordDataField(name="age", property_type="int"), + ] + vector_fields = [ + VectorStoreRecordVectorField( + name="embedding", + property_type="float", + dimensions=5, + distance_function=DistanceFunction.COSINE_DISTANCE, + ), + ] + vector = [0.1, 0.2, 0.3, 0.4, 0.5] + options = VectorSearchOptions( + vector_field_name="embedding", + filter=VectorSearchFilter.equal_to("age", "30").any_tag_equal_to("name", "test"), + ) + cmd = _build_search_query(schema, table, key_field, data_fields, vector_fields, vector, options) + assert cmd.parameters[0] == json.dumps(vector) + assert cmd.parameters[1] == "30" + assert cmd.parameters[2] == "test" + str_cmd = str(cmd) + assert ( + str_cmd == "SELECT id, name, age, VECTOR_DISTANCE('cosine', embedding, CAST(? AS VECTOR(5))) as " + "_vector_distance_value\n FROM [dbo].[Test] \nWHERE [age] = ? AND\n? IN [name] \nORDER BY " + "_vector_distance_value ASC\nOFFSET 0 ROWS FETCH NEXT 3 ROWS ONLY;" + ) + + +@fixture +async def mock_connection(*args, **kwargs): + return MagicMock() + + +@mark.parametrize( + "connection_string", + [ + param( + "Driver={ODBC Driver 18 for SQL Server};Server=localhost;Database=testdb;uid=testuserLongAsMax=yes;", + id="with uid", + ), + param( + "Driver={ODBC Driver 18 for SQL Server};Server=localhost;Database=testdb;LongAsMax=yes;", id="credential" + ), + ], +) +async def test_get_mssql_connection(connection_string): + mock_pyodbc = NonCallableMagicMock() + sys.modules["pyodbc"] = mock_pyodbc + + with patch("pyodbc.connect") as patched_connection: + from azure.identity.aio import DefaultAzureCredential + + from semantic_kernel.connectors.memory.sql_server import SqlSettings, _get_mssql_connection + + token = MagicMock() + token.token.return_value = "test_token" + token.token.encode.return_value = b"test_token" + credential = AsyncMock(spec=DefaultAzureCredential) + credential.__aenter__.return_value = credential + credential.get_token.return_value = token + + settings = SqlSettings.create(connection_string=connection_string) + with patch("semantic_kernel.connectors.memory.sql_server.DefaultAzureCredential", return_value=credential): + connection = await _get_mssql_connection(settings) + assert connection is not None + assert isinstance(connection, MagicMock) + if "uid" in connection_string: + assert patched_connection.call_args.kwargs["attrs_before"] is None + else: + assert patched_connection.call_args.kwargs["attrs_before"] == { + 1256: b"\n\x00\x00\x00test_token", + } + + +class TestSqlServerStore: + async def test_create_store(self, sql_server_unit_test_env): + store = SqlServerStore() + assert store is not None + assert store.settings is not None + assert store.settings.connection_string is not None + assert "LongAsMax=yes;" in store.settings.connection_string.get_secret_value() + + with patch("semantic_kernel.connectors.memory.sql_server._get_mssql_connection") as mock_get_connection: + mock_get_connection.return_value = AsyncMock() + await 
store.__aenter__() + assert store.connection is not None + + @mark.parametrize( + "override_env_param_dict", + [ + { + "SQL_SERVER_CONNECTION_STRING": "Driver={ODBC Driver 18 for SQL Server};Server=localhost;Database=testdb;User Id=testuser;Password=example;LongAsMax=yes;" # noqa: E501 + } + ], + indirect=True, + ) + def test_create_store_with_long_as_max(self, sql_server_unit_test_env): + store = SqlServerStore() + assert store is not None + assert store.settings is not None + assert store.settings.connection_string is not None + + @mark.parametrize("exclude_list", ["SQL_SERVER_CONNECTION_STRING"], indirect=True) + def test_create_without_connection_string(self, sql_server_unit_test_env): + with raises(VectorStoreInitializationException): + SqlServerStore(env_file_path="test.env") + + def test_get_collection(self, sql_server_unit_test_env, data_model_definition): + store = SqlServerStore() + collection = store.get_collection("test", data_model_type=dict, data_model_definition=data_model_definition) + assert collection is not None + + async def test_list_collection_names(self, sql_server_unit_test_env, mock_connection): + async with SqlServerStore(connection=mock_connection) as store: + mock_connection.cursor.return_value.__enter__.return_value.fetchall.return_value = [ + ["Test1"], + ["Test2"], + ] + collection_names = await store.list_collection_names() + assert collection_names == ["Test1", "Test2"] + + async def test_no_connection(self, sql_server_unit_test_env): + store = SqlServerStore() + with raises(VectorStoreOperationException): + await store.list_collection_names() + + +class TestSqlServerCollection: + @mark.parametrize("exclude_list", ["SQL_SERVER_CONNECTION_STRING"], indirect=True) + def test_create_without_connection_string(self, sql_server_unit_test_env, data_model_definition): + with raises(VectorStoreInitializationException): + SqlServerCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + env_file_path="test.env", + ) + + async def test_create(self, sql_server_unit_test_env, data_model_definition): + collection = SqlServerCollection( + collection_name="test", data_model_type=dict, data_model_definition=data_model_definition + ) + assert collection is not None + assert collection.collection_name == "test" + assert collection.settings is not None + assert collection.settings.connection_string is not None + + with patch("semantic_kernel.connectors.memory.sql_server._get_mssql_connection") as mock_get_connection: + mock_get_connection.return_value = AsyncMock() + await collection.__aenter__() + assert collection.connection is not None + + async def test_upsert( + self, + sql_server_unit_test_env, + mock_connection, + data_model_definition, + ): + collection = SqlServerCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + connection=mock_connection, + ) + record = {"id": "1", "content": "test", "vector": [0.1, 0.2, 0.3, 0.4, 0.5]} + mock_connection.cursor.return_value.__enter__.return_value.nextset.side_effect = [True, False] + mock_connection.cursor.return_value.__enter__.return_value.fetchall.return_value = [ + ["1"], + ] + await collection.upsert(record) + mock_connection.cursor.return_value.__enter__.return_value.execute.assert_called_with( + ( + "DECLARE @UpsertedKeys TABLE (KeyColumn nvarchar(255));\nMERGE INTO [dbo].[test] AS t\nUSING ( VALUES" + " (?, ?, ?) 
) AS s (id, content, vector) ON (t.id = s.id) \nWHEN MATCHED THEN\nUPDATE SET t.content" + " = s.content, t.vector = s.vector\nWHEN NOT MATCHED THEN\nINSERT (id, content, vector) VALUES (s.id, " + "s.content, s.vector) \nOUTPUT inserted.id INTO @UpsertedKeys (KeyColumn);\nSELECT KeyColumn " + "FROM @UpsertedKeys;\n" + ), + ("1", "test", json.dumps([0.1, 0.2, 0.3, 0.4, 0.5])), + ) + + async def test_get( + self, + sql_server_unit_test_env, + mock_connection, + data_model_definition, + ): + class MockRow(NamedTuple): + id: str + content: str + vector: str + + mock_cursor = MagicMock() + mock_connection.cursor.return_value.__enter__.return_value = mock_cursor + + collection = SqlServerCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + connection=mock_connection, + ) + key = "1" + + row = MockRow("1", "test", "[0.1, 0.2, 0.3, 0.4, 0.5]") + mock_cursor.description = [["id"], ["content"], ["vector"]] + + mock_cursor.__iter__.return_value = [row] + record = await collection.get(key) + mock_cursor.execute.assert_called_with( + "SELECT\nid, content, vector FROM [dbo].[test] \nWHERE id IN\n (?) ;", ("1",) + ) + assert record["id"] == "1" + assert record["content"] == "test" + assert record["vector"] == [0.1, 0.2, 0.3, 0.4, 0.5] + + async def test_delete( + self, + sql_server_unit_test_env, + mock_connection, + data_model_definition, + ): + collection = SqlServerCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + connection=mock_connection, + ) + key = "1" + await collection.delete(key) + mock_connection.cursor.return_value.__enter__.return_value.execute.assert_called_with( + "DELETE FROM [dbo].[test] WHERE [id] IN (?) ;", ("1",) + ) + + async def test_search( + self, + sql_server_unit_test_env, + mock_connection, + data_model_definition, + ): + mock_cursor = MagicMock() + mock_connection.cursor.return_value.__enter__.return_value = mock_cursor + data_model_definition.fields["vector"].distance_function = DistanceFunction.COSINE_DISTANCE + collection = SqlServerCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + connection=mock_connection, + ) + vector = [0.1, 0.2, 0.3, 0.4, 0.5] + options = VectorSearchOptions(vector_field_name="vector", filter=VectorSearchFilter.equal_to("content", "test")) + + @dataclass + class MockRow: + id: str + content: str + _vector_distance_value: float + + row = MockRow("1", "test", 0.1) + mock_cursor.description = [["id"], ["content"], ["_vector_distance_value"]] + + mock_cursor.__iter__.return_value = [row] + search_result = await collection.vectorized_search(vector, options) + async for record in search_result.results: + assert record.record["id"] == "1" + assert record.record["content"] == "test" + assert record.score == 0.1 + mock_cursor.execute.assert_called_with( + ( + "SELECT id, content, VECTOR_DISTANCE('cosine', vector, CAST(? AS VECTOR(5))) as " + "_vector_distance_value\n FROM [dbo].[test] \nWHERE [content] = ? 
\nORDER BY _vector_distance_value " + "ASC\nOFFSET 0 ROWS FETCH NEXT 3 ROWS ONLY;" + ), + (json.dumps(vector), "test"), + ) + + async def test_create_collection( + self, + sql_server_unit_test_env, + mock_connection, + data_model_definition, + ): + data_model_definition.fields["vector"].index_kind = IndexKind.FLAT + collection = SqlServerCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + connection=mock_connection, + ) + await collection.create_collection() + mock_connection.cursor.return_value.__enter__.return_value.execute.assert_called_with( + ( + "IF OBJECT_ID(N' [dbo].[test] ', N'U') IS NULL\nBEGIN\nCREATE TABLE [dbo].[test] \n (\"id\" nvarchar" + '(255) NOT NULL,\n"content" nvarchar(max) NULL,\n"vector" VECTOR(5) NULL,\nPRIMARY KEY (id) \n) ;' + "\nEND\n" + ), + (), + ) + + async def test_delete_collection( + self, + sql_server_unit_test_env, + mock_connection, + data_model_definition, + ): + collection = SqlServerCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + connection=mock_connection, + ) + await collection.delete_collection() + mock_connection.cursor.return_value.__enter__.return_value.execute.assert_called_with( + "DROP TABLE IF EXISTS [dbo].[test] ;", () + ) + + async def test_no_connection(self, sql_server_unit_test_env, data_model_definition): + collection = SqlServerCollection( + collection_name="test", + data_model_type=dict, + data_model_definition=data_model_definition, + ) + with raises(VectorStoreOperationException): + await collection.create_collection() + with raises(VectorStoreOperationException): + await collection.delete_collection() + with raises(VectorStoreOperationException): + await collection.does_collection_exist() + with raises(VectorStoreOperationException): + await collection.upsert({"id": "1", "content": "test", "vector": [0.1, 0.2, 0.3, 0.4, 0.5]}) + with raises(VectorStoreOperationException): + await collection.get("1") + with raises(VectorStoreOperationException): + await collection.delete("1") + with raises(VectorStoreOperationException): + await collection.vectorized_search([0.1, 0.2, 0.3, 0.4, 0.5], VectorSearchOptions()) diff --git a/python/uv.lock b/python/uv.lock index 7cb66e367035..c0928c23f77a 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -1485,7 +1485,7 @@ wheels = [ [[package]] name = "google-cloud-aiplatform" -version = "1.85.0" +version = "1.86.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "docstring-parser", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -1501,9 +1501,9 @@ dependencies = [ { name = "shapely", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/36/44/395e85e58d6cc266213f5679f785b78eedca8b4d257048da19fa56b1e18b/google_cloud_aiplatform-1.85.0.tar.gz", hash = "sha256:8f4845f02c0fe77903342d250c7522f870333e2f9738ccd197c8e7cc3e95f11d", size = 8773739 } +sdist = { url = "https://files.pythonhosted.org/packages/56/78/a9fd14966ff5c44db1d36dd5db9c44513a106adafd2570d928589754d049/google_cloud_aiplatform-1.86.0.tar.gz", hash = "sha256:45fff84c75c6f66105efa1c6caf0ea87fddc85298c834ee38f4163cf793510c4", size = 8999957 } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/7c/7b/b308cdc8d6cc58acc656c7e9d4eb753992b277b29d885ab1c39d3f14c650/google_cloud_aiplatform-1.85.0-py3-none-any.whl", hash = "sha256:5064519a81aa355bcba708795c0b122b08af90612eec1b2a3b892dbdb98a81b4", size = 7334994 }, + { url = "https://files.pythonhosted.org/packages/e5/01/a204dce0e40e1bdfb7de7a28f9db56e9bdf738e31f3cdf89f8601171b76a/google_cloud_aiplatform-1.86.0-py2.py3-none-any.whl", hash = "sha256:fcd155a0e77fdc12a5c477af92fa0b7e8ea1b1d1fcece35ad07f160008dedc7e", size = 7519175 }, ] [[package]] @@ -4307,6 +4307,38 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7d/64/11d87df61cdca4fef90388af592247e17f3d31b15a909780f186d2739592/pymongo-4.11.3-cp313-cp313t-win_amd64.whl", hash = "sha256:07d40b831590bc458b624f421849c2b09ad2b9110b956f658b583fe01fe01c01", size = 987855 }, ] +[[package]] +name = "pyodbc" +version = "5.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a0/36/a1ac7d23a1611e7ccd4d27df096f3794e8d1e7faa040260d9d41b6fc3185/pyodbc-5.2.0.tar.gz", hash = "sha256:de8be39809c8ddeeee26a4b876a6463529cd487a60d1393eb2a93e9bcd44a8f5", size = 116908 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/30/01/05c4a4ec122c4a8a37fa1be5bdbf6fb23724a2ee3b1b771bb46f710158a9/pyodbc-5.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eb0850e3e3782f57457feed297e220bb20c3e8fd7550d7a6b6bb96112bd9b6fe", size = 72483 }, + { url = "https://files.pythonhosted.org/packages/73/22/ba718cc5508bdfbb53e1906018d7f597be37241c769dda8a48f52af96fe3/pyodbc-5.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0dae0fb86078c87acf135dbe5afd3c7d15d52ab0db5965c44159e84058c3e2fb", size = 71794 }, + { url = "https://files.pythonhosted.org/packages/24/e4/9d859ea3642059c10a6644a00ccb1f8b8e02c1e4f49ab34250db1273c2c5/pyodbc-5.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6493b9c7506ca964b80ad638d0dc82869df7058255d71f04fdd1405e88bcb36b", size = 332850 }, + { url = "https://files.pythonhosted.org/packages/b9/a7/98c3555c10cfeb343ec7eea69ecb17476aa3ace72131ea8a4a1f8250318c/pyodbc-5.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e04de873607fb960e71953c164c83e8e5d9291ce0d69e688e54947b254b04902", size = 336009 }, + { url = "https://files.pythonhosted.org/packages/24/c1/d5b16dd62eb70f281bc90cdc1e3c46af7acda3f0f6afb34553206506ccb2/pyodbc-5.2.0-cp310-cp310-win32.whl", hash = "sha256:74135cb10c1dcdbd99fe429c61539c232140e62939fa7c69b0a373cc552e4a08", size = 62407 }, + { url = "https://files.pythonhosted.org/packages/f5/12/22c83669abee4ca5915aa89172cf1673b58ca05f44dabeb8b0bac9b7fecc/pyodbc-5.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:d287121eeaa562b9ab3d4c52fa77c793dfedd127049273eb882a05d3d67a8ce8", size = 68874 }, + { url = "https://files.pythonhosted.org/packages/8f/a2/5907ce319a571eb1e271d6a475920edfeacd92da1021bb2a15ed1b7f6ac1/pyodbc-5.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4627779f0a608b51ce2d2fe6d1d395384e65ca36248bf9dbb6d7cf2c8fda1cab", size = 72536 }, + { url = "https://files.pythonhosted.org/packages/e1/b8/bd438ab2bb9481615142784b0c9778079a87ae1bca7a0fe8aabfc088aa9f/pyodbc-5.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4d997d3b6551273647825c734158ca8a6f682df269f6b3975f2499c01577ddec", size = 71825 }, + { url = "https://files.pythonhosted.org/packages/8b/82/cf71ae99b511a7f20c380ce470de233a0291fa3798afa74e0adc8fad1675/pyodbc-5.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:5102007a8c78dd2fc1c1b6f6147de8cfc020f81013e4b46c33e66aaa7d1bf7b1", size = 342304 }, + { url = "https://files.pythonhosted.org/packages/43/ea/03fe042f4a390df05e753ddd21c6cab006baae1eee71ce230f6e2a883944/pyodbc-5.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e3cbc7075a46c411b531ada557c4aef13d034060a70077717124cabc1717e2d", size = 346186 }, + { url = "https://files.pythonhosted.org/packages/f9/80/48178bb50990147adb72ec9e377e94517a0dfaf2f2a6e3fe477d9a33671f/pyodbc-5.2.0-cp311-cp311-win32.whl", hash = "sha256:de1ee7ec2eb326b7be5e2c4ce20d472c5ef1a6eb838d126d1d26779ff5486e49", size = 62418 }, + { url = "https://files.pythonhosted.org/packages/7c/6b/f0ad7d8a535d58f35f375ffbf367c68d0ec54452a431d23b0ebee4cd44c6/pyodbc-5.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:113f904b9852c12f10c7a3288f5a3563ecdbbefe3ccc829074a9eb8255edcd29", size = 68871 }, + { url = "https://files.pythonhosted.org/packages/26/26/104525b728fedfababd3143426b9d0008c70f0d604a3bf5d4773977d83f4/pyodbc-5.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:be43d1ece4f2cf4d430996689d89a1a15aeb3a8da8262527e5ced5aee27e89c3", size = 73014 }, + { url = "https://files.pythonhosted.org/packages/4f/7d/bb632488b603bcd2a6753b858e8bc7dd56146dd19bd72003cc09ae6e3fc0/pyodbc-5.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9f7badd0055221a744d76c11440c0856fd2846ed53b6555cf8f0a8893a3e4b03", size = 72515 }, + { url = "https://files.pythonhosted.org/packages/ab/38/a1b9bfe5a7062672268553c2d6ff93676173b0fb4bd583e8c4f74a0e296f/pyodbc-5.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad633c52f4f4e7691daaa2278d6e6ebb2fe4ae7709e610e22c7dd1a1d620cf8b", size = 348561 }, + { url = "https://files.pythonhosted.org/packages/71/82/ddb1c41c682550116f391aa6cab2052910046a30d63014bbe6d09c4958f4/pyodbc-5.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97d086a8f7a302b74c9c2e77bedf954a603b19168af900d4d3a97322e773df63", size = 353962 }, + { url = "https://files.pythonhosted.org/packages/e5/29/fec0e739d0c1cab155843ed71d0717f5e1694effe3f28d397168f48bcd92/pyodbc-5.2.0-cp312-cp312-win32.whl", hash = "sha256:0e4412f8e608db2a4be5bcc75f9581f386ed6a427dbcb5eac795049ba6fc205e", size = 63050 }, + { url = "https://files.pythonhosted.org/packages/21/7f/3a47e022a97b017ffb73351a1061e4401bcb5aa4fc0162d04f4e5452e4fc/pyodbc-5.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:b1f5686b142759c5b2bdbeaa0692622c2ebb1f10780eb3c174b85f5607fbcf55", size = 69485 }, + { url = "https://files.pythonhosted.org/packages/90/be/e5f8022ec57a7ea6aa3717a3f307a44c3b012fce7ad6ec91aad3e2a56978/pyodbc-5.2.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:26844d780045bbc3514d5c2f0d89e7fda7df7db0bd24292eb6902046f5730885", size = 72982 }, + { url = "https://files.pythonhosted.org/packages/5c/0e/71111e4f53936b0b99731d9b6acfc8fc95660533a1421447a63d6e519112/pyodbc-5.2.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:26d2d8fd53b71204c755abc53b0379df4e23fd9a40faf211e1cb87e8a32470f0", size = 72515 }, + { url = "https://files.pythonhosted.org/packages/a5/09/3c06bbc1ebb9ae15f53cefe10774809b67da643883287ba1c44ba053816a/pyodbc-5.2.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a27996b6d27e275dfb5fe8a34087ba1cacadfd1439e636874ef675faea5149d9", size = 347470 }, + { url = "https://files.pythonhosted.org/packages/a4/35/1c7efd4665e7983169d20175014f68578e0edfcbc4602b0bafcefa522c4a/pyodbc-5.2.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:eaf42c4bd323b8fd01f1cd900cca2d09232155f9b8f0b9bcd0be66763588ce64", size = 353025 }, + { url = "https://files.pythonhosted.org/packages/6d/c9/736d07fa33572abdc50d858fd9e527d2c8281f3acbb90dff4999a3662edd/pyodbc-5.2.0-cp313-cp313-win32.whl", hash = "sha256:207f16b7e9bf09c591616429ebf2b47127e879aad21167ac15158910dc9bbcda", size = 63052 }, + { url = "https://files.pythonhosted.org/packages/73/2a/3219c8b7fa3788fc9f27b5fc2244017223cf070e5ab370f71c519adf9120/pyodbc-5.2.0-cp313-cp313-win_amd64.whl", hash = "sha256:96d3127f28c0dacf18da7ae009cd48eac532d3dcc718a334b86a3c65f6a5ef5c", size = 69486 }, +] + [[package]] name = "pyopenssl" version = "25.0.0" @@ -5238,6 +5270,9 @@ redis = [ { name = "redisvl", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "types-redis", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] +sql = [ + { name = "pyodbc", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, +] usearch = [ { name = "pyarrow", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "usearch", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -5283,7 +5318,7 @@ requires-dist = [ { name = "defusedxml", specifier = "~=0.7" }, { name = "faiss-cpu", marker = "extra == 'faiss'", specifier = ">=1.10.0" }, { name = "flask-dapr", marker = "extra == 'dapr'", specifier = ">=1.14.0" }, - { name = "google-cloud-aiplatform", marker = "extra == 'google'", specifier = "==1.85.0" }, + { name = "google-cloud-aiplatform", marker = "extra == 'google'", specifier = "==1.86.0" }, { name = "google-generativeai", marker = "extra == 'google'", specifier = "~=0.8" }, { name = "ipykernel", marker = "extra == 'notebooks'", specifier = "~=6.29" }, { name = "jinja2", specifier = "~=3.1" }, @@ -5305,15 +5340,16 @@ requires-dist = [ { name = "psycopg", extras = ["binary", "pool"], marker = "extra == 'postgres'", specifier = "~=3.2" }, { name = "pyarrow", marker = "extra == 'usearch'", specifier = ">=12.0,<20.0" }, { name = "pybars4", specifier = "~=0.9" }, - { name = "pydantic", specifier = ">=2.0,!=2.10.0,!=2.10.1,!=2.10.2,!=2.10.3,<2.11" }, + { name = "pydantic", specifier = ">=2.0,!=2.10.0,!=2.10.1,!=2.10.2,!=2.10.3,<2.12" }, { name = "pydantic-settings", specifier = "~=2.0" }, { name = "pymilvus", marker = "extra == 'milvus'", specifier = ">=2.3,<2.6" }, { name = "pymongo", marker = "extra == 'mongo'", specifier = ">=4.8.0,<4.12" }, + { name = "pyodbc", marker = "extra == 'sql'", specifier = ">=5.2" }, { name = "qdrant-client", marker = "extra == 'qdrant'", specifier = "~=1.9" }, { name = "redis", extras = ["hiredis"], marker = "extra == 'redis'", specifier = "~=5.0" }, { name = "redisvl", marker = "extra == 'redis'", specifier = ">=0.3.6" }, { name = "scipy", specifier = ">=1.15.1" }, - { name = "sentence-transformers", marker = "extra == 'hugging-face'", specifier = ">=2.2,<4.0" }, + { name = "sentence-transformers", marker = "extra == 'hugging-face'", specifier = ">=2.2,<5.0" }, { name = "torch", marker = "extra == 'hugging-face'", specifier = "==2.6.0" }, { name = "transformers", extras = ["torch"], marker = "extra == 'hugging-face'", specifier = "~=4.28" }, { name = "types-redis", marker = "extra == 'redis'", specifier = "~=4.6.0.20240425" }, @@ -5323,7 +5359,7 @@ requires-dist = [ { name = "websockets", specifier = ">=13,<16" }, { name = "websockets", marker = "extra == 'realtime'", 
specifier = ">=13,<16" },
 ]
-provides-extras = ["anthropic", "autogen", "aws", "azure", "chroma", "dapr", "faiss", "google", "hugging-face", "milvus", "mistralai", "mongo", "notebooks", "ollama", "onnx", "pandas", "pinecone", "postgres", "qdrant", "realtime", "redis", "usearch", "weaviate"]
+provides-extras = ["anthropic", "autogen", "aws", "azure", "chroma", "dapr", "faiss", "google", "hugging-face", "milvus", "mistralai", "mongo", "notebooks", "ollama", "onnx", "pandas", "pinecone", "postgres", "qdrant", "realtime", "redis", "sql", "usearch", "weaviate"]
 
 [package.metadata.requires-dev]
 dev = [