Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add support for initializing vecs client with custom schema #63

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import vecs

PYTEST_DB = "postgresql://postgres:password@localhost:5611/vecs_db"

PYTEST_SCHEMA = "test_schema"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The best way to test that escaping is done correctly in all places is to use a crazy schema name.

I tested basic operations with the schema name "esCape Me!" and the test suite fails.

To reproduce that, try:

foo: Collection = client.get_or_create_collection(name="foo", schema="esCape Me!", dimension=5)

and you'll get

sqlalchemy.exc.ProgrammingError: (psycopg2.errors.InvalidSchemaName) schema ""esCape Me!"" does not exist


@pytest.fixture(scope="session")
def maybe_start_pg() -> Generator[None, None, None]:
Expand Down Expand Up @@ -94,12 +94,12 @@ def maybe_start_pg() -> Generator[None, None, None]:
def clean_db(maybe_start_pg: None) -> Generator[str, None, None]:
eng = create_engine(PYTEST_DB)
with eng.begin() as connection:
connection.execute(text("drop schema if exists vecs cascade;"))
connection.execute(text(f"drop schema if exists {PYTEST_SCHEMA} cascade;"))
yield PYTEST_DB
eng.dispose()


@pytest.fixture(scope="function")
def client(clean_db: str) -> Generator[vecs.Client, None, None]:
client_ = vecs.create_client(clean_db)
client_ = vecs.create_client(clean_db, PYTEST_SCHEMA)
yield client_
7 changes: 6 additions & 1 deletion src/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import pytest

import vecs

def test_create_client(clean_db) -> None:
    """Check that a client exposes the schema it was created with.

    Args:
        clean_db: Connection string supplied by the ``clean_db`` fixture.
    """
    # Without a schema argument the client should report "vecs".
    default_client = vecs.create_client(clean_db)
    assert default_client.schema == "vecs"

    # A schema passed positionally should be stored on the client unchanged.
    custom_client = vecs.create_client(clean_db, "my_schema")
    assert custom_client.schema == "my_schema"

def test_extracts_vector_version(client: vecs.Client) -> None:
# pgvector version is successfully extracted
Expand Down
4 changes: 2 additions & 2 deletions src/vecs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@
]


def create_client(connection_string: str) -> Client:
def create_client(connection_string: str, schema: str="vecs") -> Client:
"""Creates a client from a Postgres connection string"""
return Client(connection_string)
return Client(connection_string=connection_string, schema=schema)
13 changes: 7 additions & 6 deletions src/vecs/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,24 @@ class Client:
vx.disconnect()
"""

def __init__(self, connection_string: str):
def __init__(self, connection_string: str, schema: str):
"""
Initialize a Client instance.

Args:
connection_string (str): A string representing the database connection information.

schema (str): A string representing the database schema to connect to.
Returns:
None
"""
self.schema = schema
self.engine = create_engine(connection_string)
self.meta = MetaData(schema="vecs")
self.meta = MetaData(schema=self.schema)
self.Session = sessionmaker(self.engine)

with self.Session() as sess:
with sess.begin():
sess.execute(text("create schema if not exists vecs;"))
sess.execute(text(f"create schema if not exists {self.schema};"))
sess.execute(text("create extension if not exists vector;"))
self.vector_version: str = sess.execute(
text(
Expand Down Expand Up @@ -105,7 +106,7 @@ def get_or_create_collection(
CollectionAlreadyExists: If a collection with the same name already exists
"""
from vecs.collection import Collection

adapter_dimension = adapter.exported_dimension if adapter else None

collection = Collection(
Expand Down Expand Up @@ -162,7 +163,7 @@ def get_collection(self, name: str) -> Collection:
join pg_attribute pa
on pc.oid = pa.attrelid
where
pc.relnamespace = 'vecs'::regnamespace
pc.relnamespace = '{self.schema}'::regnamespace
and pc.relkind = 'r'
and pa.attname = 'vec'
and not pc.relname ^@ '_'
Expand Down
18 changes: 9 additions & 9 deletions src/vecs/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _create_if_not_exists(self):
join pg_attribute pa
on pc.oid = pa.attrelid
where
pc.relnamespace = 'vecs'::regnamespace
pc.relnamespace = '{self.client.schema}'::regnamespace
and pc.relkind = 'r'
and pa.attname = 'vec'
and not pc.relname ^@ '_'
Expand Down Expand Up @@ -289,7 +289,7 @@ def _create(self):
text(
f"""
create index ix_meta_{unique_string}
on vecs."{self.table.name}"
on {self.client.schema}."{self.table.name}"
using gin ( metadata jsonb_path_ops )
"""
)
Expand Down Expand Up @@ -576,7 +576,7 @@ def _list_collections(cls, client: "Client") -> List["Collection"]:
"""

query = text(
"""
f"""
select
relname as table_name,
atttypmod as embedding_dim
Expand All @@ -585,7 +585,7 @@ def _list_collections(cls, client: "Client") -> List["Collection"]:
join pg_attribute pa
on pc.oid = pa.attrelid
where
pc.relnamespace = 'vecs'::regnamespace
pc.relnamespace = '{client.schema}'::regnamespace
and pc.relkind = 'r'
and pa.attname = 'vec'
and not pc.relname ^@ '_'
Expand Down Expand Up @@ -636,13 +636,13 @@ def index(self) -> Optional[str]:

if self._index is None:
query = text(
"""
f"""
select
relname as table_name
from
pg_class pc
where
pc.relnamespace = 'vecs'::regnamespace
pc.relnamespace = '{self.client.schema}'::regnamespace
and relname ilike 'ix_vector%'
and pc.relkind = 'i'
"""
Expand Down Expand Up @@ -760,7 +760,7 @@ def create_index(
with sess.begin():
if self.index is not None:
if replace:
sess.execute(text(f'drop index vecs."{self.index}";'))
sess.execute(text(f'drop index "{self.client.schema}"."{self.index}";'))
self._index = None
else:
raise ArgError("replace is set to False but an index exists")
Expand All @@ -787,7 +787,7 @@ def create_index(
text(
f"""
create index ix_{ops}_ivfflat_nl{n_lists}_{unique_string}
on vecs."{self.table.name}"
on {self.client.schema}."{self.table.name}"
using ivfflat (vec {ops}) with (lists={n_lists})
"""
)
Expand All @@ -806,7 +806,7 @@ def create_index(
text(
f"""
create index ix_{ops}_hnsw_m{m}_efc{ef_construction}_{unique_string}
on vecs."{self.table.name}"
on {self.client.schema}."{self.table.name}"
using hnsw (vec {ops}) WITH (m={m}, ef_construction={ef_construction});
"""
)
Expand Down
Loading