diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1fcf653..95b9cd5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,9 +20,10 @@ repos: hooks: - id: autoflake args: ['--in-place', '--remove-all-unused-imports'] + language_version: python3.8 - repo: https://github.com/ambv/black rev: 22.10.0 hooks: - id: black - language_version: python3.9 + language_version: python3.8 diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 4b6e3e7..2435714 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -13,6 +13,7 @@ import vecs PYTEST_DB = "postgresql://postgres:password@localhost:5611/vecs_db" +PYTEST_SCHEMA = "test_schema" @pytest.fixture(scope="session") @@ -95,6 +96,7 @@ def clean_db(maybe_start_pg: None) -> Generator[str, None, None]: eng = create_engine(PYTEST_DB) with eng.begin() as connection: connection.execute(text("drop schema if exists vecs cascade;")) + connection.execute(text(f"drop schema if exists {PYTEST_SCHEMA} cascade;")) yield PYTEST_DB eng.dispose() diff --git a/src/tests/test_client.py b/src/tests/test_client.py index 6e8694f..db88c85 100644 --- a/src/tests/test_client.py +++ b/src/tests/test_client.py @@ -29,11 +29,17 @@ def test_get_collection(client: vecs.Client) -> None: def test_list_collections(client: vecs.Client) -> None: + """ + Test list_collections returns appropriate results for default schema (vecs) and custom schema + """ assert len(client.list_collections()) == 0 client.get_or_create_collection(name="docs", dimension=384) client.get_or_create_collection(name="books", dimension=1586) + client.get_or_create_collection(name="movies", schema="test_schema", dimension=384) collections = client.list_collections() + collections_test_schema = client.list_collections(schema="test_schema") assert len(collections) == 2 + assert len(collections_test_schema) == 1 def test_delete_collection(client: vecs.Client) -> None: diff --git a/src/tests/test_collection.py 
b/src/tests/test_collection.py index e7c4b38..be8ec38 100644 --- a/src/tests/test_collection.py +++ b/src/tests/test_collection.py @@ -815,3 +815,69 @@ def test_hnsw_unavailable_error(client: vecs.Client) -> None: bar = client.get_or_create_collection(name="bar", dimension=dim) with pytest.raises(ArgError): bar.create_index(method=IndexMethod.hnsw) + + +def test_get_or_create_with_schema(client: vecs.Client) -> None: + """ + Test that get_or_create_collection works when specifying custom schema + """ + + dim = 384 + + collection_1 = client.get_or_create_collection( + name="collection_1", schema="test_schema", dimension=dim + ) + collection_2 = client.get_or_create_collection( + name="collection_1", schema="test_schema", dimension=dim + ) + + assert collection_1.schema == "test_schema" + assert collection_1.schema == collection_2.schema + assert collection_1.name == collection_2.name + + +def test_upsert_with_schema(client: vecs.Client) -> None: + n_records = 100 + dim = 384 + + movies1 = client.get_or_create_collection( + name="ping", schema="test_schema", dimension=dim + ) + movies2 = client.get_or_create_collection(name="ping", schema="vecs", dimension=dim) + + # collection initially empty + assert len(movies1) == 0 + assert len(movies2) == 0 + + records = [ + ( + f"vec{ix}", + vec, + { + "genre": random.choice(["action", "rom-com", "drama"]), + "year": int(50 * random.random()) + 1970, + }, + ) + for ix, vec in enumerate(np.random.random((n_records, dim))) + ] + + # insert works + movies1.upsert(records) + assert len(movies1) == n_records + + movies2.upsert(records) + assert len(movies2) == n_records + + # upserting overwrites + new_record = ("vec0", np.zeros(384), {}) + movies1.upsert([new_record]) + db_record = movies1["vec0"] + assert db_record[0] == new_record[0] + assert np.array_equal(db_record[1], new_record[1]) + assert db_record[2] == new_record[2] + + movies2.upsert([new_record]) + db_record = movies2["vec0"] + assert db_record[0] == new_record[0] + assert np.array_equal(db_record[1], new_record[1]) + assert db_record[2] == 
new_record[2] diff --git a/src/vecs/__init__.py b/src/vecs/__init__.py index 55f80e5..c226523 100644 --- a/src/vecs/__init__.py +++ b/src/vecs/__init__.py @@ -24,5 +24,8 @@ def create_client(connection_string: str) -> Client: - """Creates a client from a Postgres connection string""" - return Client(connection_string) + """ + Creates a client from a Postgres connection string. + The schema is chosen per collection (e.g. `get_or_create_collection(schema=...)`) and defaults to `vecs`. + """ + return Client(connection_string=connection_string) diff --git a/src/vecs/client.py b/src/vecs/client.py index 89bb3e3..be9430a 100644 --- a/src/vecs/client.py +++ b/src/vecs/client.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, List, Optional from deprecated import deprecated -from sqlalchemy import MetaData, create_engine, text +from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker from vecs.adapter import Adapter @@ -53,12 +53,10 @@ def __init__(self, connection_string: str): Args: connection_string (str): A string representing the database connection information. - Returns: None """ self.engine = create_engine(connection_string) - self.meta = MetaData(schema="vecs") self.Session = sessionmaker(self.engine) with self.Session() as sess: @@ -84,6 +82,7 @@ def get_or_create_collection( self, name: str, *, + schema: str = "vecs", dimension: Optional[int] = None, adapter: Optional[Adapter] = None, ) -> Collection: @@ -113,6 +112,7 @@ def get_or_create_collection( dimension=dimension or adapter_dimension, # type: ignore client=self, adapter=adapter, + schema=schema, ) return collection._create_if_not_exists() @@ -182,18 +182,18 @@ def get_collection(self, name: str) -> Collection: self, ) - def list_collections(self) -> List["Collection"]: + def list_collections(self, *, schema: str = "vecs") -> List["Collection"]: """ - List all vector collections. + List all vector collections by database schema. Returns: list[Collection]: A list of all collections. 
""" from vecs.collection import Collection - return Collection._list_collections(self) + return Collection._list_collections(self, schema) - def delete_collection(self, name: str) -> None: + def delete_collection(self, name: str, *, schema: str = "vecs") -> None: """ Delete a vector collection. @@ -201,13 +201,14 @@ def delete_collection(self, name: str) -> None: Args: name (str): The name of the collection. + schema (str): Optional, the database schema. Defaults to `vecs`. Returns: None """ from vecs.collection import Collection - Collection(name, -1, self)._drop() + Collection(name, -1, self, schema=schema)._drop() return def disconnect(self) -> None: diff --git a/src/vecs/collection.py b/src/vecs/collection.py index af35538..2442b28 100644 --- a/src/vecs/collection.py +++ b/src/vecs/collection.py @@ -159,6 +159,7 @@ def __init__( dimension: int, client: Client, adapter: Optional[Adapter] = None, + schema: Optional[str] = "vecs", ): """ Initializes a new instance of the `Collection` class. @@ -174,7 +175,12 @@ def __init__( self.client = client self.name = name self.dimension = dimension - self.table = build_table(name, client.meta, dimension) + self._schema = schema + self.schema = self.client.engine.dialect.identifier_preparer.quote_schema( + self._schema + ) + self.meta = MetaData(schema=self._schema) + self.table = build_table(name, self.meta, dimension) self._index: Optional[str] = None self.adapter = adapter or Adapter(steps=[NoOp(dimension=dimension)]) @@ -195,6 +201,10 @@ def __init__( "Dimensions reported by adapter, dimension, and collection do not match" ) + with self.client.Session() as sess: + with sess.begin(): + sess.execute(text(f"create schema if not exists {self.schema};")) + def __repr__(self): """ Returns a string representation of the `Collection` instance. 
@@ -235,7 +245,7 @@ def _create_if_not_exists(self): join pg_attribute pa on pc.oid = pa.attrelid where - pc.relnamespace = 'vecs'::regnamespace + pc.relnamespace = '{self.schema}'::regnamespace and pc.relkind = 'r' and pa.attname = 'vec' and not pc.relname ^@ '_' @@ -285,11 +295,12 @@ def _create(self): unique_string = str(uuid.uuid4()).replace("-", "_")[0:7] with self.client.Session() as sess: + sess.execute(text(f"create schema if not exists {self.schema};")) sess.execute( text( f""" create index ix_meta_{unique_string} - on vecs."{self.table.name}" + on {self.schema}."{self.table.name}" using gin ( metadata jsonb_path_ops ) """ ) @@ -562,7 +573,7 @@ def query( return sess.execute(stmt).fetchall() or [] @classmethod - def _list_collections(cls, client: "Client") -> List["Collection"]: + def _list_collections(cls, client: "Client", schema: str) -> List["Collection"]: """ PRIVATE @@ -570,13 +581,14 @@ def _list_collections(cls, client: "Client") -> List["Collection"]: Args: client (Client): The database client. + schema (str): The database schema to query. Returns: - List[Collection]: A list of all existing collections. + List[Collection]: A list of all existing collections within the specified schema. 
""" query = text( - """ + f""" select relname as table_name, atttypmod as embedding_dim @@ -585,7 +597,7 @@ def _list_collections(cls, client: "Client") -> List["Collection"]: join pg_attribute pa on pc.oid = pa.attrelid where - pc.relnamespace = 'vecs'::regnamespace + pc.relnamespace = '{schema}'::regnamespace and pc.relkind = 'r' and pa.attname = 'vec' and not pc.relname ^@ '_' @@ -636,13 +648,13 @@ def index(self) -> Optional[str]: if self._index is None: query = text( - """ + f""" select relname as table_name from pg_class pc where - pc.relnamespace = 'vecs'::regnamespace + pc.relnamespace = '{self.schema}'::regnamespace and relname ilike 'ix_vector%' and pc.relkind = 'i' """ @@ -760,7 +772,9 @@ def create_index( with sess.begin(): if self.index is not None: if replace: - sess.execute(text(f'drop index vecs."{self.index}";')) + sess.execute( + text(f'drop index {self.schema}."{self.index}";') + ) self._index = None else: raise ArgError("replace is set to False but an index exists") @@ -787,7 +801,7 @@ def create_index( text( f""" create index ix_{ops}_ivfflat_nl{n_lists}_{unique_string} - on vecs."{self.table.name}" + on {self.schema}."{self.table.name}" using ivfflat (vec {ops}) with (lists={n_lists}) """ ) @@ -806,7 +820,7 @@ def create_index( text( f""" create index ix_{ops}_hnsw_m{m}_efc{ef_construction}_{unique_string} - on vecs."{self.table.name}" + on {self.schema}."{self.table.name}" using hnsw (vec {ops}) WITH (m={m}, ef_construction={ef_construction}); """ )