Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,12 @@ jobs:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
run: |
poetry run test-cov
poetry run test-verbose

- name: Run tests
if: matrix.connection != 'plain' || matrix.redis-stack-version != 'latest'
run: |
SKIP_VECTORIZERS=True SKIP_RERANKERS=True poetry run test-cov
SKIP_VECTORIZERS=True SKIP_RERANKERS=True poetry run test-verbose

- name: Run notebooks
if: matrix.connection == 'plain' && matrix.redis-stack-version == 'latest'
Expand Down
6 changes: 3 additions & 3 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,17 @@ To run Testcontainers-based tests you need a local Docker installation such as:

Tests w/ vectorizers:
```bash
poetry run test-cov
poetry run test-verbose
```

Tests w/out vectorizers:
```bash
SKIP_VECTORIZERS=true poetry run test-cov
SKIP_VECTORIZERS=true poetry run test-verbose
```

Tests w/out rerankers:
```bash
SKIP_RERANKERS=true poetry run test-cov
SKIP_RERANKERS=true poetry run test-verbose
```

### Documentation
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ check-types:
lint: format check-types

test:
SKIP_RERANKERS=true SKIP_VECTORIZERS=true poetry run test-cov
SKIP_RERANKERS=true SKIP_VECTORIZERS=true poetry run test-verbose

test-all:
poetry run test-cov
poetry run test-verbose

check: lint test

Expand Down
49 changes: 40 additions & 9 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,55 @@


@pytest.fixture(scope="session", autouse=True)
def redis_container():
# Set the default Redis version if not already set
def redis_container(request):
"""
Create a unique Compose project for each xdist worker by setting
COMPOSE_PROJECT_NAME. That prevents collisions on container/volume names.
"""
# In xdist, the config has "workerid" in workerinput. For the main (non-xdist)
# process, 'workerid' is often 'master' or something similar.
worker_id = request.config.workerinput.get("workerid", "master")

# Set the Compose project name so containers do not clash across workers
os.environ["COMPOSE_PROJECT_NAME"] = f"redis_test_{worker_id}"
os.environ.setdefault("REDIS_VERSION", "edge")

compose = DockerCompose("tests", compose_file_name="docker-compose.yml", pull=True)
compose = DockerCompose(
context="tests",
compose_file_name="docker-compose.yml",
pull=True,
)

compose.start()

# If you mapped the container port 6379:6379 in docker-compose.yml,
# you might have collisions across workers. If you rely on ephemeral
# host ports, remove the `ports:` block in docker-compose.yml and do:
redis_host, redis_port = compose.get_service_host_and_port("redis", 6379)
redis_url = f"redis://{redis_host}:{redis_port}"
os.environ["REDIS_URL"] = redis_url
#redis_url = f"redis://{redis_host}:{redis_port}"
#os.environ["REDIS_URL"] = redis_url

yield compose

compose.stop()
# Optionally, clean up the COMPOSE_PROJECT_NAME you set:
os.environ.pop("COMPOSE_PROJECT_NAME", None)


@pytest.fixture(scope="session")
def redis_url():
return os.getenv("REDIS_URL", "redis://localhost:6379")
def redis_url(redis_container):
"""
Use the `DockerCompose` fixture to get host/port of the 'redis' service
on container port 6379 (mapped to an ephemeral port on the host).
"""
host, port = redis_container.get_service_host_and_port("redis", 6379)
return f"redis://{host}:{port}"

@pytest.fixture
async def async_client(redis_url):
"""
An async Redis client that uses the dynamic `redis_url`.
"""
client = await RedisConnectionFactory.get_async_redis_connection(redis_url)
yield client
try:
Expand All @@ -38,8 +66,11 @@ async def async_client(redis_url):
raise

@pytest.fixture
def client():
conn = RedisConnectionFactory.get_redis_connection(os.environ["REDIS_URL"])
def client(redis_url):
"""
A sync Redis client that uses the dynamic `redis_url`.
"""
conn = RedisConnectionFactory.get_redis_connection(redis_url)
yield conn
conn.close()

Expand Down
1,168 changes: 558 additions & 610 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ cohere = { version = ">=4.44", optional = true }
mistralai = { version = ">=1.0.0", optional = true }
boto3 = { version = ">=1.34.0", optional = true }
voyageai = { version = ">=0.2.2", optional = true }
pytest-xdist = {extras = ["psutil"], version = "^3.6.1"}

[tool.poetry.extras]
openai = ["openai"]
Expand All @@ -50,7 +51,6 @@ black = ">=20.8b1"
isort = ">=5.6.4"
pylint = "3.1.0"
pytest = "8.1.1"
pytest-cov = "5.0.0"
pytest-asyncio = "0.23.6"
mypy = "1.9.0"
types-redis = "*"
Expand Down Expand Up @@ -81,7 +81,6 @@ check-lint = "scripts:check_lint"
check-mypy = "scripts:check_mypy"
test = "scripts:test"
test-verbose = "scripts:test_verbose"
test-cov = "scripts:test_cov"
cov = "scripts:cov"
test-notebooks = "scripts:test_notebooks"
build-docs = "scripts:build_docs"
Expand Down
1 change: 1 addition & 0 deletions redisvl/schema/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def as_redis_field(self) -> RedisField:

class FlatVectorField(BaseField):
"Vector field with a FLAT index (brute force nearest neighbors search)"

type: str = Field(default="vector", const=True)
attrs: FlatVectorFieldAttributes

Expand Down
6 changes: 4 additions & 2 deletions scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ def check_mypy():


def test():
subprocess.run(["python", "-m", "pytest", "--log-level=CRITICAL"], check=True)
subprocess.run(["python", "-m", "pytest", "-n", "6", "--log-level=CRITICAL"], check=True)


def test_verbose():
subprocess.run(
["python", "-m", "pytest", "-vv", "-s", "--log-level=CRITICAL"], check=True
["python", "-m", "pytest", "-n", "6", "-vv", "-s", "--log-level=CRITICAL"], check=True
)


Expand All @@ -44,6 +44,8 @@ def test_cov():
"python",
"-m",
"pytest",
"-n",
"6",
"-vv",
"--cov=./redisvl",
"--cov-report=xml",
Expand Down
4 changes: 0 additions & 4 deletions tests/integration/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@
EXPECTED_LIB_NAME = f"redis-py(redisvl_v{__version__})"


def test_get_address_from_env(redis_url):
assert get_address_from_env() == redis_url


def test_unpack_redis_modules():
module_list = [
{
Expand Down
24 changes: 17 additions & 7 deletions tests/integration/test_llmcache.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import os
import warnings
from collections import namedtuple
from time import sleep, time
import warnings

import pytest
from pydantic.v1 import ValidationError
Expand Down Expand Up @@ -894,33 +894,43 @@ def test_bad_dtype_connecting_to_existing_cache(redis_url):
)


def test_vectorizer_dtype_mismatch():
def test_vectorizer_dtype_mismatch(redis_url):
with pytest.raises(ValueError):
SemanticCache(
name="test_dtype_mismatch",
dtype="float32",
vectorizer=HFTextVectorizer(dtype="float16"),
redis_url=redis_url,
overwrite=True,
)


def test_invalid_vectorizer():
def test_invalid_vectorizer(redis_url):
with pytest.raises(TypeError):
SemanticCache(
name="test_invalid_vectorizer",
vectorizer="invalid_vectorizer", # type: ignore
redis_url=redis_url,
overwrite=True,
)


def test_passes_through_dtype_to_default_vectorizer():
def test_passes_through_dtype_to_default_vectorizer(redis_url):
# The default is float32, so we should see float64 if we pass it in.
cache = SemanticCache(
name="test_pass_through_dtype", dtype="float64", overwrite=True
name="test_pass_through_dtype",
dtype="float64",
redis_url=redis_url,
overwrite=True,
)
assert cache._vectorizer.dtype == "float64"


def test_deprecated_dtype_argument():
def test_deprecated_dtype_argument(redis_url):
with pytest.warns(DeprecationWarning):
SemanticCache(name="test_deprecated_dtype", dtype="float32", overwrite=True)
SemanticCache(
name="test_deprecated_dtype",
dtype="float32",
redis_url=redis_url,
overwrite=True,
)
4 changes: 2 additions & 2 deletions tests/integration/test_search_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def filter_query():


@pytest.fixture
def index(sample_data):
def index(sample_data, redis_url):
fields_spec = [
{"name": "credit_score", "type": "tag"},
{"name": "user", "type": "tag"},
Expand Down Expand Up @@ -47,7 +47,7 @@ def index(sample_data):
index = SearchIndex.from_dict(json_schema)

# connect to local redis instance
index.connect(os.environ["REDIS_URL"])
index.connect(redis_url=redis_url)

# create the index (no data yet)
index.create(overwrite=True)
Expand Down
12 changes: 8 additions & 4 deletions tests/integration/test_semantic_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,42 +315,46 @@ def test_bad_dtype_connecting_to_exiting_router(redis_url, routes):
)


def test_vectorizer_dtype_mismatch(routes):
def test_vectorizer_dtype_mismatch(routes, redis_url):
with pytest.raises(ValueError):
SemanticRouter(
name="test_dtype_mismatch",
routes=routes,
dtype="float32",
vectorizer=HFTextVectorizer(dtype="float16"),
redis_url=redis_url,
overwrite=True,
)


def test_invalid_vectorizer(routes):
def test_invalid_vectorizer(routes, redis_url):
with pytest.raises(TypeError):
SemanticRouter(
name="test_invalid_vectorizer",
vectorizer="invalid_vectorizer", # type: ignore
redis_url=redis_url,
overwrite=True,
)


def test_passes_through_dtype_to_default_vectorizer(routes):
def test_passes_through_dtype_to_default_vectorizer(routes, redis_url):
# The default is float32, so we should see float64 if we pass it in.
router = SemanticRouter(
name="test_pass_through_dtype",
routes=routes,
dtype="float64",
redis_url=redis_url,
overwrite=True,
)
assert router.vectorizer.dtype == "float64"


def test_deprecated_dtype_argument(routes):
def test_deprecated_dtype_argument(routes, redis_url):
with pytest.warns(DeprecationWarning):
SemanticRouter(
name="test_deprecated_dtype",
routes=routes,
dtype="float32",
redis_url=redis_url,
overwrite=True,
)
19 changes: 13 additions & 6 deletions tests/integration/test_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,33 +589,40 @@ def test_bad_dtype_connecting_to_exiting_session(redis_url):
)


def test_vectorizer_dtype_mismatch():
def test_vectorizer_dtype_mismatch(redis_url):
with pytest.raises(ValueError):
SemanticSessionManager(
name="test_dtype_mismatch",
dtype="float32",
vectorizer=HFTextVectorizer(dtype="float16"),
redis_url=redis_url,
overwrite=True,
)


def test_invalid_vectorizer():
def test_invalid_vectorizer(redis_url):
with pytest.raises(TypeError):
SemanticSessionManager(
name="test_invalid_vectorizer",
vectorizer="invalid_vectorizer", # type: ignore
redis_url=redis_url,
overwrite=True,
)


def test_passes_through_dtype_to_default_vectorizer():
def test_passes_through_dtype_to_default_vectorizer(redis_url):
# The default is float32, so we should see float64 if we pass it in.
cache = SemanticSessionManager(
name="test_pass_through_dtype", dtype="float64", overwrite=True
name="test_pass_through_dtype",
dtype="float64",
redis_url=redis_url,
overwrite=True,
)
assert cache._vectorizer.dtype == "float64"


def test_deprecated_dtype_argument():
def test_deprecated_dtype_argument(redis_url):
with pytest.warns(DeprecationWarning):
SemanticSessionManager(name="float64 session", dtype="float64", overwrite=True)
SemanticSessionManager(
name="float64 session", dtype="float64", redis_url=redis_url, overwrite=True
)
Loading