def vset(self):
    """
    Return a ``VectorSet`` command namespace bound to this client.

    ``VectorSet`` is imported at call time (function scope) rather than at
    module import time.
    """
    from .vectorset import VectorSet

    return VectorSet(client=self)
class VectorSetCommands(CommandsProtocol):
    """
    Redis VectorSet commands.

    Each method assembles the argument list for one ``V*`` command and hands it
    to ``self.execute_command``. ``vemb`` additionally inspects ``self.client``
    to find the protocol version and the ``decode_responses`` setting.
    """

    def vadd(
        self,
        key: KeyT,
        vector: Union[List[float], bytes],
        element: str,
        reduce_dim: Optional[int] = None,
        cas: Optional[bool] = False,
        quantization: Optional[QuantizationOptions] = None,
        ef: Optional[Number] = None,
        attributes: Optional[Union[dict, str]] = None,
        numlinks: Optional[int] = None,
    ) -> Union[Awaitable[int], int]:
        """
        Add vector ``vector`` for element ``element`` to a vector set ``key``.

        ``vector`` is either a list of numbers (sent as ``VALUES <n> ...``) or a
        ``bytes`` blob (sent as ``FP32 <blob>``).

        ``reduce_dim`` sets the dimensions to reduce the vector to.
        If not provided, the vector is not reduced.

        ``cas`` is a boolean flag that indicates whether to use CAS
        (check-and-set style) when adding the vector.
        If not provided, CAS is not used.

        ``quantization`` sets the quantization type to use.
        If not provided, int8 quantization is used.
        The options are:
        - NOQUANT: No quantization
        - BIN: Binary quantization
        - Q8: Signed 8-bit quantization

        ``ef`` sets the exploration factor to use.
        If not provided, the default exploration factor is used.

        ``attributes`` is a dictionary or json string that contains the
        attributes to set for the vector. A dictionary is serialized with
        ``json.dumps``; a string is sent unchanged.
        If not provided (or empty), no attributes are set.

        ``numlinks`` sets the number of links to create for the vector.
        If not provided, the default number of links is used.

        Raises ``DataError`` when ``vector`` or ``element`` is empty/``None``.

        For more information see https://redis.io/commands/vadd
        """
        if not vector or not element:
            raise DataError("Both vector and element must be provided")

        pieces = []
        if reduce_dim:
            pieces.extend(["REDUCE", reduce_dim])

        # The vector itself: a binary blob or an explicit list of values.
        if isinstance(vector, bytes):
            pieces.extend(["FP32", vector])
        else:
            pieces.extend(["VALUES", len(vector), *vector])

        pieces.append(element)

        if cas:
            pieces.append("CAS")

        if quantization:
            pieces.append(quantization.value)

        if ef:
            pieces.extend(["EF", ef])

        if attributes:
            # Dicts are transformed to a json string; strings are passed as-is.
            attributes_json = (
                json.dumps(attributes) if isinstance(attributes, dict) else attributes
            )
            pieces.extend(["SETATTR", attributes_json])

        if numlinks:
            pieces.extend(["M", numlinks])

        return self.execute_command(VADD_CMD, key, *pieces)

    def vsim(
        self,
        key: KeyT,
        input: Union[List[float], bytes, str],
        with_scores: Optional[bool] = False,
        count: Optional[int] = None,
        ef: Optional[Number] = None,
        filter: Optional[str] = None,
        filter_ef: Optional[Number] = None,
        truth: Optional[bool] = False,
        no_thread: Optional[bool] = False,
    ) -> Union[
        Awaitable[Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]]],
        Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]],
    ]:
        """
        Compare a vector or element ``input`` with the other vectors in a
        vector set ``key``.

        ``input`` may be a ``bytes`` blob (sent as ``FP32``), a list of numbers
        (sent as ``VALUES <n> ...``) or anything else, which is treated as an
        element name (sent as ``ELE``).

        ``with_scores`` sets if the results should be returned with the
        similarity scores of the elements in the result.

        ``count`` sets the number of results to return.

        ``ef`` sets the exploration factor.

        ``filter`` sets filter that should be applied for the search.

        ``filter_ef`` sets the max filtering effort.

        ``truth`` when enabled forces the command to perform linear scan.

        ``no_thread`` when enabled forces the command to execute the search
        on the data structure in the main thread.

        Raises ``DataError`` when ``input`` is empty/``None``.

        For more information see https://redis.io/commands/vsim
        """
        if not input:
            raise DataError("'input' should be provided")

        pieces = []
        options = {}

        if isinstance(input, bytes):
            pieces.extend(["FP32", input])
        elif isinstance(input, list):
            pieces.extend(["VALUES", len(input), *input])
        else:
            pieces.extend(["ELE", input])

        if with_scores:
            pieces.append("WITHSCORES")
            # Passed through to the response callback so it can shape the reply.
            options[CallbacksOptions.WITHSCORES.value] = True

        if count:
            pieces.extend(["COUNT", count])

        if ef:
            pieces.extend(["EF", ef])

        if filter:
            pieces.extend(["FILTER", filter])

        if filter_ef:
            pieces.extend(["FILTER-EF", filter_ef])

        if truth:
            pieces.append("TRUTH")

        if no_thread:
            pieces.append("NOTHREAD")

        return self.execute_command(VSIM_CMD, key, *pieces, **options)

    def vdim(self, key: KeyT) -> Union[Awaitable[int], int]:
        """
        Get the dimension of a vector set.

        In the case of vectors that were populated using the `REDUCE`
        option, for random projection, the vector set will report the size of
        the projected (reduced) dimension.

        Raises `redis.exceptions.ResponseError` if the vector set doesn't exist.

        For more information see https://redis.io/commands/vdim
        """
        return self.execute_command(VDIM_CMD, key)

    def vcard(self, key: KeyT) -> Union[Awaitable[int], int]:
        """
        Get the cardinality (the number of elements) of a vector set with
        key ``key``.

        Raises `redis.exceptions.ResponseError` if the vector set doesn't exist.

        For more information see https://redis.io/commands/vcard
        """
        return self.execute_command(VCARD_CMD, key)

    def vrem(self, key: KeyT, element: str) -> Union[Awaitable[int], int]:
        """
        Remove an element from a vector set.

        For more information see https://redis.io/commands/vrem
        """
        return self.execute_command(VREM_CMD, key, element)

    def vemb(
        self, key: KeyT, element: str, raw: Optional[bool] = False
    ) -> Union[
        Awaitable[Optional[Union[List[EncodableT], Dict[str, EncodableT]]]],
        Optional[Union[List[EncodableT], Dict[str, EncodableT]]],
    ]:
        """
        Get the approximated vector of an element ``element`` from vector
        set ``key``.

        ``raw`` is a boolean flag that indicates whether to return the
        internal representation used by the vector. When set, the reply is
        requested undecoded and the response callback is told whether it may
        decode textual fields itself.

        For more information see https://redis.io/commands/vemb
        """
        options = {}
        pieces = [key, element]

        if get_protocol_version(self.client) in ["3", 3]:
            options[CallbacksOptions.RESP3.value] = True

        if raw:
            pieces.append("RAW")

            # The raw payload is binary, so the connection must not decode it.
            options[NEVER_DECODE] = True

            # Standalone clients expose ``connection_pool``; cluster clients
            # expose ``nodes_manager``. ``.get`` is used because a pool can be
            # built without a ``decode_responses`` entry.
            decode_responses = any(
                getattr(self.client, holder).connection_kwargs.get("decode_responses")
                for holder in ("connection_pool", "nodes_manager")
                if hasattr(self.client, holder)
            )
            if decode_responses:
                # allow decoding in the postprocessing callback
                # if the user set decode_responses=True
                options[CallbacksOptions.ALLOW_DECODING.value] = True

            options[CallbacksOptions.RAW.value] = True

        return self.execute_command(VEMB_CMD, *pieces, **options)

    def vlinks(
        self, key: KeyT, element: str, with_scores: Optional[bool] = False
    ) -> Union[
        Awaitable[
            Optional[
                List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]]
            ]
        ],
        Optional[List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]]],
    ]:
        """
        Returns the neighbors for each level the element ``element`` exists in
        the vector set ``key``.

        The result is a list of lists, where each list contains the neighbors
        for one level. If the element does not exist, or if the vector set does
        not exist, None is returned.

        If ``with_scores`` is set (``WITHSCORES``), the result is a list of
        dicts, where each dict contains the neighbors for one level, with the
        scores as values.

        For more information see https://redis.io/commands/vlinks
        """
        options = {}
        pieces = [key, element]

        if with_scores:
            pieces.append("WITHSCORES")
            options[CallbacksOptions.WITHSCORES.value] = True

        return self.execute_command(VLINKS_CMD, *pieces, **options)

    def vinfo(self, key: KeyT) -> Union[Awaitable[dict], dict]:
        """
        Get information about a vector set.

        For more information see https://redis.io/commands/vinfo
        """
        return self.execute_command(VINFO_CMD, key)

    def vsetattr(
        self, key: KeyT, element: str, attributes: Optional[Union[dict, str]] = None
    ) -> Union[Awaitable[int], int]:
        """
        Associate or remove JSON attributes ``attributes`` of element ``element``
        for vector set ``key``.

        ``attributes`` may be a dict (serialized with ``json.dumps``) or a
        ready-made json string. When it is ``None`` the string ``"{}"`` is sent.

        For more information see https://redis.io/commands/vsetattr
        """
        if attributes is None:
            attributes_json = "{}"
        elif isinstance(attributes, dict):
            # transform attributes to json string
            attributes_json = json.dumps(attributes)
        else:
            attributes_json = attributes

        return self.execute_command(VSETATTR_CMD, key, element, attributes_json)

    def vgetattr(
        self, key: KeyT, element: str
    ) -> Union[Awaitable[Optional[dict]], Optional[dict]]:
        """
        Retrieve the JSON attributes of an element ``element`` for vector
        set ``key``.

        If the element does not exist, or if the vector set does not exist,
        None is returned.

        For more information see https://redis.io/commands/vgetattr
        """
        return self.execute_command(VGETATTR_CMD, key, element)

    def vrandmember(
        self, key: KeyT, count: Optional[int] = None
    ) -> Union[
        Awaitable[Optional[Union[List[str], str]]], Optional[Union[List[str], str]]
    ]:
        """
        Returns random elements from a vector set ``key``.

        ``count`` is the number of elements to return.
        If ``count`` is not provided, a single element is returned as a
        single string.
        If ``count`` is positive (smaller than the number of elements
        in the vector set), the command returns a list with up to ``count``
        distinct elements from the vector set.
        If ``count`` is negative, the command returns a list with ``count``
        random elements, potentially with duplicates.
        If ``count`` is greater than the number of elements in the vector set,
        only the entire set is returned as a list.

        If the vector set does not exist, ``None`` is returned.

        For more information see https://redis.io/commands/vrandmember
        """
        pieces = [key]
        # ``0`` is a legitimate count, so only ``None`` means "not provided".
        if count is not None:
            pieces.append(count)
        return self.execute_command(VRANDMEMBER_CMD, *pieces)
def parse_vsim_result(response, **options):
    """
    Post-process a VSIM reply.

    The shape of the result depends on the WITHSCORES callback option:
    - without it, the reply is returned unchanged;
    - with it, the reply is paired up with ``pairs_to_dict`` and returned as a
      single dict mapping each element to its score converted to ``float``.

    A ``None`` reply is returned as-is.
    """
    if response is None:
        return response

    if options.get(CallbacksOptions.WITHSCORES.value):
        # Scores arrive as strings/bytes; expose them as floats.
        return {
            element: float(score)
            for element, score in pairs_to_dict(response).items()
        }

    return response
"elem1") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + with pytest.raises(redis.DataError): + await d_client.vset().vadd("myset_invalid_data", None, "elem1") + + with pytest.raises(redis.DataError): + await d_client.vset().vadd("myset_invalid_data", [12, 45], None, reduce_dim=3) + + +@skip_if_server_version_lt("7.9.0") +async def test_add_elem_with_vector(d_client): + float_array = [1, 4.32, 0.11] + # Convert the list of floats to a byte array in fp32 format + byte_array = _to_fp32_blob_array(float_array) + resp = await d_client.vset().vadd("myset", byte_array, "elem1") + assert resp == 1 + + emb = await d_client.vset().vemb("myset", "elem1") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + +@skip_if_server_version_lt("7.9.0") +async def test_add_elem_reduced_dim(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9] + resp = await d_client.vset().vadd("myset", float_array, "elem1", reduce_dim=3) + assert resp == 1 + + dim = await d_client.vset().vdim("myset") + assert dim == 3 + + +@skip_if_server_version_lt("7.9.0") +async def test_add_elem_cas(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9] + resp = await d_client.vset().vadd( + "myset", vector=float_array, element="elem1", cas=True + ) + assert resp == 1 + + emb = await d_client.vset().vemb("myset", "elem1") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + +@skip_if_server_version_lt("7.9.0") +async def test_add_elem_no_quant(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9] + resp = await d_client.vset().vadd( + "myset", + vector=float_array, + element="elem1", + quantization=QuantizationOptions.NOQUANT, + ) + assert resp == 1 + + emb = await d_client.vset().vemb("myset", "elem1") + assert _validate_quantization(float_array, emb, tolerance=0.0) + + +@skip_if_server_version_lt("7.9.0") +async def test_add_elem_bin_quant(d_client): + float_array = [1, 4.32, 0.0, 0.05, -2.9] + resp = await d_client.vset().vadd( + "myset", + vector=float_array, + 
@skip_if_server_version_lt("7.9.0")
async def test_add_elem_q8_quant(d_client):
    """VADD with Q8 quantization stores a close approximation of the vector."""
    float_array = [1, 4.32, 10.0, -21, -2.9]
    resp = await d_client.vset().vadd(
        "myset",
        vector=float_array,
        element="elem1",
        # The test is about Q8; it previously passed BIN by mistake.
        quantization=QuantizationOptions.Q8,
    )
    assert resp == 1

    emb = await d_client.vset().vemb("myset", "elem1")
    # Q8 is lossy, so compare with a tolerance instead of exact values.
    # NOTE(review): 0.1 mirrors the tolerance used by the other Q8 check in
    # this file - confirm against a live server.
    assert _validate_quantization(float_array, emb, tolerance=0.1)
@skip_if_server_version_lt("7.9.0")
async def test_add_elem_with_numlinks(d_client):
    """Elements added with ``numlinks=8`` never report more than 8 neighbours."""
    # Populate the set with 100 random 10-dimensional integer vectors.
    for idx in range(100):
        random_vector = [random.randint(0, 10) for _ in range(10)]
        await d_client.vset().vadd(
            "myset", random_vector, f"elem{idx}", numlinks=8
        )

    probe = [1, 4.32, 0.11, 0.5, 0.9, 0.1, 0.2, 0.3, 0.4, 0.5]
    added = await d_client.vset().vadd("myset", probe, "elem_numlinks", numlinks=8)
    assert added == 1

    stored = await d_client.vset().vemb("myset", "elem_numlinks")
    assert _validate_quantization(probe, stored, tolerance=0.5)

    layers = await d_client.vset().vlinks("myset", "elem_numlinks")
    for layer_neighbours in layers:
        assert len(layer_neighbours) <= 8
@skip_if_server_version_lt("7.9.0")
async def test_vsim_with_scores(d_client):
    """VSIM with scores returns a dict of element -> float score."""
    # 20 random vectors of dimension 50.
    for idx in range(20):
        vector = [random.uniform(0, 10) for _ in range(50)]
        await d_client.vset().vadd("myset", vector, f"elem{idx}", numlinks=64)

    scored = await d_client.vset().vsim("myset", input="elem1", with_scores=True)
    assert len(scored) == 10
    assert isinstance(scored, dict)

    own_score = scored["elem1"]
    assert isinstance(own_score, float)
    assert 0 <= own_score <= 1
@skip_if_server_version_lt("7.9.0")
async def test_vsim_truth_no_thread_enabled(d_client):
    """
    Compare VSIM results with and without ``truth`` and exercise ``no_thread``.

    Position by position, the score from the ``truth`` run must never be lower
    than the score from the default run, and at least one must be higher.
    """
    elements_count = 5000
    vector_dim = 30
    for i in range(1, elements_count + 1):
        float_array = [random.uniform(10 * i, 1000 * i) for _ in range(vector_dim)]
        await d_client.vset().vadd("myset", float_array, f"elem_{i}")

    await d_client.vset().vadd("myset", [-22 for _ in range(vector_dim)], "elem_man_2")

    sim_without_truth = await d_client.vset().vsim(
        "myset", input="elem_man_2", with_scores=True
    )
    sim_truth = await d_client.vset().vsim(
        "myset", input="elem_man_2", with_scores=True, truth=True
    )

    assert len(sim_without_truth) == 10
    assert len(sim_truth) == 10

    assert isinstance(sim_without_truth, dict)
    assert isinstance(sim_truth, dict)

    found_better_match = False
    # Dicts keep reply order, so the values line up rank by rank.
    for score_with_truth, score_without_truth in zip(
        sim_truth.values(), sim_without_truth.values()
    ):
        # f-string so the failing scores actually appear in the message.
        assert score_with_truth >= score_without_truth, (
            f"Score with truth [{score_with_truth}] < "
            f"score without truth [{score_without_truth}]"
        )
        if score_with_truth > score_without_truth:
            found_better_match = True

    assert found_better_match

    sim_no_thread = await d_client.vset().vsim(
        "myset", input="elem_man_2", with_scores=True, no_thread=True
    )

    assert len(sim_no_thread) == 10
    assert isinstance(sim_no_thread, dict)
@skip_if_server_version_lt("7.9.0")
async def test_vemb_bin_quantization(d_client):
    """VEMB on a BIN-quantized element: plain and raw replies."""
    await d_client.vset().vadd(
        "myset",
        [1, 4.32, 0.0, 0.05, -2.9],
        "elem",
        quantization=QuantizationOptions.BIN,
    )

    emb_bin = await d_client.vset().vemb("myset", "elem")
    assert emb_bin == [1, 1, -1, 1, -1]

    emb_bin_raw = await d_client.vset().vemb("myset", "elem", raw=True)
    assert emb_bin_raw["quantization"] == "bin"
    assert isinstance(emb_bin_raw["raw"], bytes)
    assert isinstance(emb_bin_raw["l2"], float)
    assert "range" not in emb_bin_raw
@skip_if_server_version_lt("7.9.0")
async def test_vemb_default_quantization(d_client):
    """VEMB on an element added without an explicit quantization option."""
    source_vector = [1, 5.32, 0.0, 0.25, -5]
    await d_client.vset().vadd("myset", vector=source_vector, element="elem")

    approx = await d_client.vset().vemb("myset", "elem")
    assert _validate_quantization(source_vector, approx, tolerance=0.1)

    raw_reply = await d_client.vset().vemb("myset", "elem", raw=True)
    assert raw_reply["quantization"] == "int8"
    assert isinstance(raw_reply["raw"], bytes)
    assert isinstance(raw_reply["l2"], float)
    assert isinstance(raw_reply["range"], float)
assert isinstance(emb_fp32_quant_raw["range"], float) + + +@skip_if_server_version_lt("7.9.0") +async def test_vemb_unexisting(d_client): + emb_not_existing = await d_client.vset().vemb("not_existing", "elem") + assert emb_not_existing is None + + e = [1, 5.32, 0.0, 0.25, -5] + await d_client.vset().vadd("myset", vector=e, element="elem") + emb_elem_not_existing = await d_client.vset().vemb("myset", "not_existing") + assert emb_elem_not_existing is None + + +@skip_if_server_version_lt("7.9.0") +async def test_vlinks(d_client): + elements_count = 100 + vector_dim = 800 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + await d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=8, + ) + + element_links_all_layers = await d_client.vset().vlinks("myset", "elem1") + assert len(element_links_all_layers) >= 1 + for neighbours_list_for_layer in element_links_all_layers: + assert isinstance(neighbours_list_for_layer, list) + for neighbour in neighbours_list_for_layer: + assert isinstance(neighbour, str) + + elem_links_all_layers_with_scores = await d_client.vset().vlinks( + "myset", "elem1", with_scores=True + ) + assert len(elem_links_all_layers_with_scores) >= 1 + for neighbours_dict_for_layer in elem_links_all_layers_with_scores: + assert isinstance(neighbours_dict_for_layer, dict) + for neighbour_key, score_value in neighbours_dict_for_layer.items(): + assert isinstance(neighbour_key, str) + assert isinstance(score_value, float) + + float_array = [0.75, 0.25, 0.5, 0.1, 0.9] + await d_client.vset().vadd("myset_one_elem_only", float_array, "elem1") + elem_no_neighbours_with_scores = await d_client.vset().vlinks( + "myset_one_elem_only", "elem1", with_scores=True + ) + assert len(elem_no_neighbours_with_scores) >= 1 + for neighbours_dict_for_layer in elem_no_neighbours_with_scores: + assert isinstance(neighbours_dict_for_layer, dict) + assert len(neighbours_dict_for_layer) == 0 + + 
elem_no_neighbours_no_scores = await d_client.vset().vlinks( + "myset_one_elem_only", "elem1" + ) + assert len(elem_no_neighbours_no_scores) >= 1 + for neighbours_list_for_layer in elem_no_neighbours_no_scores: + assert isinstance(neighbours_list_for_layer, list) + assert len(neighbours_list_for_layer) == 0 + + unexisting_element_links = await d_client.vset().vlinks("myset", "unexisting_elem") + assert unexisting_element_links is None + + unexisting_vset_links = await d_client.vset().vlinks("myset_unexisting", "elem1") + assert unexisting_vset_links is None + + unexisting_element_links = await d_client.vset().vlinks( + "myset", "unexisting_elem", with_scores=True + ) + assert unexisting_element_links is None + + unexisting_vset_links = await d_client.vset().vlinks( + "myset_unexisting", "elem1", with_scores=True + ) + assert unexisting_vset_links is None + + +@skip_if_server_version_lt("7.9.0") +async def test_vinfo(d_client): + elements_count = 100 + vector_dim = 800 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + await d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=8, + quantization=QuantizationOptions.BIN, + ) + + vset_info = await d_client.vset().vinfo("myset") + assert vset_info["quant-type"] == "bin" + assert vset_info["vector-dim"] == vector_dim + assert vset_info["size"] == elements_count + assert vset_info["max-level"] > 0 + assert vset_info["hnsw-max-node-uid"] == elements_count + + unexisting_vset_info = await d_client.vset().vinfo("myset_unexisting") + assert unexisting_vset_info is None + + +@skip_if_server_version_lt("7.9.0") +async def test_vset_vget_attributes(d_client): + float_array = [1, 4.32, 0.11] + attributes = {"key1": "value1", "key2": "value2"} + + # validate vgetattrs when no attributes are set with vadd + resp = await d_client.vset().vadd("myset", float_array, "elem1") + assert resp == 1 + + attrs = await d_client.vset().vgetattr("myset", "elem1") + assert 
attrs is None + + # validate vgetattrs when attributes are set with vadd + resp = await d_client.vset().vadd( + "myset_with_attrs", float_array, "elem1", attributes=attributes + ) + assert resp == 1 + + attrs = await d_client.vset().vgetattr("myset_with_attrs", "elem1") + assert attrs == attributes + + # Set attributes and get attributes + resp = await d_client.vset().vsetattr("myset", "elem1", attributes) + assert resp == 1 + attr_saved = await d_client.vset().vgetattr("myset", "elem1") + assert attr_saved == attributes + + # Set attributes to None + resp = await d_client.vset().vsetattr("myset", "elem1", None) + assert resp == 1 + attr_saved = await d_client.vset().vgetattr("myset", "elem1") + assert attr_saved is None + + # Set attributes to empty dict + resp = await d_client.vset().vsetattr("myset", "elem1", {}) + assert resp == 1 + attr_saved = await d_client.vset().vgetattr("myset", "elem1") + assert attr_saved is None + + # Set attributes provided as string + resp = await d_client.vset().vsetattr("myset", "elem1", json.dumps(attributes)) + assert resp == 1 + attr_saved = await d_client.vset().vgetattr("myset", "elem1") + assert attr_saved == attributes + + # Set attributes to unexisting element + resp = await d_client.vset().vsetattr("myset", "elem2", attributes) + assert resp == 0 + attr_saved = await d_client.vset().vgetattr("myset", "elem2") + assert attr_saved is None + + # Set attributes to unexisting vset + resp = await d_client.vset().vsetattr("myset_unexisting", "elem1", attributes) + assert resp == 0 + attr_saved = await d_client.vset().vgetattr("myset_unexisting", "elem1") + assert attr_saved is None + + +@skip_if_server_version_lt("7.9.0") +async def test_vrandmember(d_client): + elements = ["elem1", "elem2", "elem3"] + for elem in elements: + float_array = [random.uniform(0, 10) for x in range(1, 8)] + await d_client.vset().vadd("myset", float_array, element=elem) + + random_member = await d_client.vset().vrandmember("myset") + assert 
random_member in elements + + members_list = await d_client.vset().vrandmember("myset", count=2) + assert len(members_list) == 2 + assert all(member in elements for member in members_list) + + # Test with count greater than the number of elements + members_list = await d_client.vset().vrandmember("myset", count=10) + assert len(members_list) == len(elements) + assert all(member in elements for member in members_list) + + # Test with negative count + members_list = await d_client.vset().vrandmember("myset", count=-2) + assert len(members_list) == 2 + assert all(member in elements for member in members_list) + + # Test with count equal to the number of elements + members_list = await d_client.vset().vrandmember("myset", count=len(elements)) + assert len(members_list) == len(elements) + assert all(member in elements for member in members_list) + + # Test with count equal to 0 + members_list = await d_client.vset().vrandmember("myset", count=0) + assert members_list == [] + + # Test with count equal to 1 + members_list = await d_client.vset().vrandmember("myset", count=1) + assert len(members_list) == 1 + assert members_list[0] in elements + + # Test with count equal to -1 + members_list = await d_client.vset().vrandmember("myset", count=-1) + assert len(members_list) == 1 + assert members_list[0] in elements + + # Test with unexisting vset & without count + members_list = await d_client.vset().vrandmember("myset_unexisting") + assert members_list is None + + # Test with unexisting vset & count + members_list = await d_client.vset().vrandmember("myset_unexisting", count=5) + assert members_list == [] + + +@skip_if_server_version_lt("7.9.0") +async def test_vset_commands_without_decoding_responces(client): + # test vadd + elements = ["elem1", "elem2", "elem3"] + for elem in elements: + float_array = [random.uniform(0, 10) for x in range(0, 8)] + resp = await client.vset().vadd("myset", float_array, element=elem) + assert resp == 1 + + # test vemb + emb = await 
client.vset().vemb("myset", "elem1") + assert len(emb) == 8 + assert isinstance(emb, list) + assert all(isinstance(x, float) for x in emb) + + emb_raw = await client.vset().vemb("myset", "elem1", raw=True) + assert emb_raw["quantization"] == b"int8" + assert isinstance(emb_raw["raw"], bytes) + assert isinstance(emb_raw["l2"], float) + assert isinstance(emb_raw["range"], float) + + # test vsim + vsim = await client.vset().vsim("myset", input="elem1") + assert len(vsim) == 3 + assert isinstance(vsim, list) + assert isinstance(vsim[0], bytes) + + # test vsim with scores + vsim_with_scores = await client.vset().vsim( + "myset", input="elem1", with_scores=True + ) + assert len(vsim_with_scores) == 3 + assert isinstance(vsim_with_scores, dict) + assert isinstance(vsim_with_scores[b"elem1"], float) + + # test vlinks - no scores + element_links_all_layers = await client.vset().vlinks("myset", "elem1") + assert len(element_links_all_layers) >= 1 + for neighbours_list_for_layer in element_links_all_layers: + assert isinstance(neighbours_list_for_layer, list) + for neighbour in neighbours_list_for_layer: + assert isinstance(neighbour, bytes) + # test vlinks with scores + elem_links_all_layers_with_scores = await client.vset().vlinks( + "myset", "elem1", with_scores=True + ) + assert len(elem_links_all_layers_with_scores) >= 1 + for neighbours_dict_for_layer in elem_links_all_layers_with_scores: + assert isinstance(neighbours_dict_for_layer, dict) + for neighbour_key, score_value in neighbours_dict_for_layer.items(): + assert isinstance(neighbour_key, bytes) + assert isinstance(score_value, float) + + # test vinfo + vset_info = await client.vset().vinfo("myset") + assert vset_info[b"quant-type"] == b"int8" + assert vset_info[b"vector-dim"] == 8 + assert vset_info[b"size"] == len(elements) + assert vset_info[b"max-level"] >= 0 + assert vset_info[b"hnsw-max-node-uid"] == len(elements) + + # test vgetattr + attributes = {"key1": "value1", "key2": "value2"} + await 
client.vset().vsetattr("myset", "elem1", attributes) + attrs = await client.vset().vgetattr("myset", "elem1") + assert attrs == attributes + + # test vrandmember + random_member = await client.vset().vrandmember("myset") + assert isinstance(random_member, bytes) + assert random_member.decode("utf-8") in elements + + members_list = await client.vset().vrandmember("myset", count=2) + assert len(members_list) == 2 + assert all(member.decode("utf-8") in elements for member in members_list) + + +def _to_fp32_blob_array(float_array): + """ + Convert a list of floats to a byte array in fp32 format. + """ + # Convert the list of floats to a NumPy array with dtype np.float32 + arr = np.array(float_array, dtype=np.float32) + # Convert the NumPy array to a byte array + byte_array = arr.tobytes() + return byte_array + + +def _validate_quantization(original, quantized, tolerance=0.1): + original = np.array(original, dtype=np.float32) + quantized = np.array(quantized, dtype=np.float32) + + max_diff = np.max(np.abs(original - quantized)) + if max_diff > tolerance: + return False + else: + return True diff --git a/tests/test_vsets.py b/tests/test_vsets.py new file mode 100644 index 0000000000..ab4194657b --- /dev/null +++ b/tests/test_vsets.py @@ -0,0 +1,856 @@ +import json +import random +import numpy as np +import pytest +import redis +from redis.commands.vectorset.commands import QuantizationOptions + +from .conftest import ( + _get_client, + skip_if_server_version_lt, +) + + +@pytest.fixture +def d_client(request): + r = _get_client(redis.Redis, request, decode_responses=True) + + r.flushdb() + return r + + +@pytest.fixture +def client(request): + r = _get_client(redis.Redis, request, decode_responses=False) + + r.flushdb() + return r + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_with_values(d_client): + float_array = [1, 4.32, 0.11] + resp = d_client.vset().vadd("myset", float_array, "elem1") + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem1") + 
assert _validate_quantization(float_array, emb, tolerance=0.1) + + with pytest.raises(redis.DataError): + d_client.vset().vadd("myset_invalid_data", None, "elem1") + + with pytest.raises(redis.DataError): + d_client.vset().vadd("myset_invalid_data", [12, 45], None, reduce_dim=3) + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_with_vector(d_client): + float_array = [1, 4.32, 0.11] + # Convert the list of floats to a byte array in fp32 format + byte_array = _to_fp32_blob_array(float_array) + resp = d_client.vset().vadd("myset", byte_array, "elem1") + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem1") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_reduced_dim(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9] + resp = d_client.vset().vadd("myset", float_array, "elem1", reduce_dim=3) + assert resp == 1 + + dim = d_client.vset().vdim("myset") + assert dim == 3 + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_cas(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9] + resp = d_client.vset().vadd("myset", vector=float_array, element="elem1", cas=True) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem1") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_no_quant(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9] + resp = d_client.vset().vadd( + "myset", + vector=float_array, + element="elem1", + quantization=QuantizationOptions.NOQUANT, + ) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem1") + assert _validate_quantization(float_array, emb, tolerance=0.0) + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_bin_quant(d_client): + float_array = [1, 4.32, 0.0, 0.05, -2.9] + resp = d_client.vset().vadd( + "myset", + vector=float_array, + element="elem1", + quantization=QuantizationOptions.BIN, + ) + assert resp == 1 + + emb = 
d_client.vset().vemb("myset", "elem1") + expected_array = [1, 1, -1, 1, -1] + assert _validate_quantization(expected_array, emb, tolerance=0.0) + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_q8_quant(d_client): + float_array = [1, 4.32, 10.0, -21, -2.9] + resp = d_client.vset().vadd( + "myset", + vector=float_array, + element="elem1", + quantization=QuantizationOptions.BIN, + ) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem1") + expected_array = [1, 1, 1, -1, -1] + assert _validate_quantization(expected_array, emb, tolerance=0.0) + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_ef(d_client): + d_client.vset().vadd("myset", vector=[5, 55, 65, -20, 30], element="elem1") + d_client.vset().vadd("myset", vector=[-40, -40.32, 10.0, -4, 2.9], element="elem2") + + float_array = [1, 4.32, 10.0, -21, -2.9] + resp = d_client.vset().vadd("myset", float_array, "elem3", ef=1) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem3") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + sim = d_client.vset().vsim("myset", input="elem3", with_scores=True) + assert len(sim) == 3 + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_with_attr(d_client): + float_array = [1, 4.32, 10.0, -21, -2.9] + attrs_dict = {"key1": "value1", "key2": "value2"} + resp = d_client.vset().vadd( + "myset", + vector=float_array, + element="elem3", + attributes=attrs_dict, + ) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem3") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + attr_saved = d_client.vset().vgetattr("myset", "elem3") + assert attr_saved == attrs_dict + + resp = d_client.vset().vadd( + "myset", + vector=float_array, + element="elem4", + attributes={}, + ) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem4") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + attr_saved = d_client.vset().vgetattr("myset", "elem4") + assert attr_saved is None + + resp = 
d_client.vset().vadd( + "myset", + vector=float_array, + element="elem5", + attributes=json.dumps(attrs_dict), + ) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem5") + assert _validate_quantization(float_array, emb, tolerance=0.1) + + attr_saved = d_client.vset().vgetattr("myset", "elem5") + assert attr_saved == attrs_dict + + +@skip_if_server_version_lt("7.9.0") +def test_add_elem_with_numlinks(d_client): + elements_count = 100 + vector_dim = 10 + for i in range(elements_count): + float_array = [random.randint(0, 10) for x in range(vector_dim)] + d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=8, + ) + + float_array = [1, 4.32, 0.11, 0.5, 0.9, 0.1, 0.2, 0.3, 0.4, 0.5] + resp = d_client.vset().vadd("myset", float_array, "elem_numlinks", numlinks=8) + assert resp == 1 + + emb = d_client.vset().vemb("myset", "elem_numlinks") + assert _validate_quantization(float_array, emb, tolerance=0.5) + + numlinks_all_layers = d_client.vset().vlinks("myset", "elem_numlinks") + for neighbours_list_for_layer in numlinks_all_layers: + assert len(neighbours_list_for_layer) <= 8 + + +@skip_if_server_version_lt("7.9.0") +def test_vsim_count(d_client): + elements_count = 30 + vector_dim = 800 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=64, + ) + + vsim = d_client.vset().vsim("myset", input="elem1") + assert len(vsim) == 10 + assert isinstance(vsim, list) + assert isinstance(vsim[0], str) + + vsim = d_client.vset().vsim("myset", input="elem1", count=5) + assert len(vsim) == 5 + assert isinstance(vsim, list) + assert isinstance(vsim[0], str) + + vsim = d_client.vset().vsim("myset", input="elem1", count=50) + assert len(vsim) == 30 + assert isinstance(vsim, list) + assert isinstance(vsim[0], str) + + vsim = d_client.vset().vsim("myset", input="elem1", count=15) + assert len(vsim) == 15 + assert isinstance(vsim, list) + 
assert isinstance(vsim[0], str) + + +@skip_if_server_version_lt("7.9.0") +def test_vsim_with_scores(d_client): + elements_count = 20 + vector_dim = 50 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=64, + ) + + vsim = d_client.vset().vsim("myset", input="elem1", with_scores=True) + assert len(vsim) == 10 + assert isinstance(vsim, dict) + assert isinstance(vsim["elem1"], float) + assert 0 <= vsim["elem1"] <= 1 + + +@skip_if_server_version_lt("7.9.0") +def test_vsim_with_different_vector_input_types(d_client): + elements_count = 10 + vector_dim = 5 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + attributes = {"index": i, "elem_name": f"elem_{i}"} + d_client.vset().vadd( + "myset", + float_array, + f"elem_{i}", + numlinks=4, + attributes=attributes, + ) + sim = d_client.vset().vsim("myset", input="elem_1") + assert len(sim) == 10 + assert isinstance(sim, list) + + float_array = [1, 4.32, 0.0, 0.05, -2.9] + sim_to_float_array = d_client.vset().vsim("myset", input=float_array) + assert len(sim_to_float_array) == 10 + assert isinstance(sim_to_float_array, list) + + fp32_vector = _to_fp32_blob_array(float_array) + sim_to_fp32_vector = d_client.vset().vsim("myset", input=fp32_vector) + assert len(sim_to_fp32_vector) == 10 + assert isinstance(sim_to_fp32_vector, list) + assert sim_to_float_array == sim_to_fp32_vector + + with pytest.raises(redis.DataError): + d_client.vset().vsim("myset", input=None) + + +@skip_if_server_version_lt("7.9.0") +def test_vsim_unexisting(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9] + d_client.vset().vadd("myset", vector=float_array, element="elem1", cas=True) + + with pytest.raises(redis.ResponseError): + d_client.vset().vsim("myset", input="elem_not_existing") + + sim = d_client.vset().vsim("myset_not_existing", input="elem1") + assert sim == [] + + 
+@skip_if_server_version_lt("7.9.0") +def test_vsim_with_filter(d_client): + elements_count = 30 + vector_dim = 800 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + attributes = {"index": i, "elem_name": f"elem_{i}"} + d_client.vset().vadd( + "myset", + float_array, + f"elem_{i}", + numlinks=4, + attributes=attributes, + ) + sim = d_client.vset().vsim("myset", input="elem_1", filter=".index > 10") + assert len(sim) == 10 + assert isinstance(sim, list) + for elem in sim: + assert int(elem.split("_")[1]) > 10 + + sim = d_client.vset().vsim( + "myset", + input="elem_1", + filter=".index > 10 and .index < 15 and .elem_name in ['elem_12', 'elem_17']", + ) + assert len(sim) == 1 + assert isinstance(sim, list) + assert sim[0] == "elem_12" + + sim = d_client.vset().vsim( + "myset", + input="elem_1", + filter=".index > 25 and .elem_name in ['elem_12', 'elem_17', 'elem_19']", + ef=100, + ) + assert len(sim) == 0 + assert isinstance(sim, list) + + sim = d_client.vset().vsim( + "myset", + input="elem_1", + filter=".index > 28 and .elem_name in ['elem_12', 'elem_17', 'elem_29']", + filter_ef=1, + ) + assert len(sim) == 0 + assert isinstance(sim, list) + + sim = d_client.vset().vsim( + "myset", + input="elem_1", + filter=".index > 28 and .elem_name in ['elem_12', 'elem_17', 'elem_29']", + filter_ef=20, + ) + assert len(sim) == 1 + assert isinstance(sim, list) + + +@skip_if_server_version_lt("7.9.0") +def test_vsim_truth_no_thread_enabled(d_client): + elements_count = 5000 + vector_dim = 30 + for i in range(1, elements_count + 1): + float_array = [random.uniform(10 * i, 1000 * i) for x in range(vector_dim)] + d_client.vset().vadd("myset", float_array, f"elem_{i}") + + d_client.vset().vadd("myset", [-22 for _ in range(vector_dim)], "elem_man_2") + + sim_without_truth = d_client.vset().vsim( + "myset", input="elem_man_2", with_scores=True + ) + sim_truth = d_client.vset().vsim( + "myset", input="elem_man_2", with_scores=True, 
truth=True + ) + + assert len(sim_without_truth) == 10 + assert len(sim_truth) == 10 + + assert isinstance(sim_without_truth, dict) + assert isinstance(sim_truth, dict) + + results_scores = list( + zip( + [v for _, v in sim_truth.items()], [v for _, v in sim_without_truth.items()] + ) + ) + + found_better_match = False + for index, (score_with_truth, score_without_truth) in enumerate(results_scores): + if score_with_truth < score_without_truth: + assert False, ( + "Score with truth [{score_with_truth}] < score without truth [{score_without_truth}]" + ) + elif score_with_truth > score_without_truth: + found_better_match = True + + assert found_better_match + + sim_no_thread = d_client.vset().vsim( + "myset", input="elem_man_2", with_scores=True, no_thread=True + ) + + assert len(sim_no_thread) == 10 + assert isinstance(sim_no_thread, dict) + + +@skip_if_server_version_lt("7.9.0") +def test_vdim(d_client): + float_array = [1, 4.32, 0.11, 0.5, 0.9, 0.1, 0.2] + d_client.vset().vadd("myset", float_array, "elem1") + + dim = d_client.vset().vdim("myset") + assert dim == len(float_array) + + d_client.vset().vadd("myset_reduced", float_array, "elem1", reduce_dim=4) + reduced_dim = d_client.vset().vdim("myset_reduced") + assert reduced_dim == 4 + + with pytest.raises(redis.ResponseError): + d_client.vset().vdim("myset_unexisting") + + +@skip_if_server_version_lt("7.9.0") +def test_vcard(d_client): + n = 20 + for i in range(n): + float_array = [random.uniform(0, 10) for x in range(1, 8)] + d_client.vset().vadd("myset", float_array, f"elem{i}") + + card = d_client.vset().vcard("myset") + assert card == n + + with pytest.raises(redis.ResponseError): + d_client.vset().vdim("myset_unexisting") + + +@skip_if_server_version_lt("7.9.0") +def test_vrem(d_client): + n = 3 + for i in range(n): + float_array = [random.uniform(0, 10) for x in range(1, 8)] + d_client.vset().vadd("myset", float_array, f"elem{i}") + + resp = d_client.vset().vrem("myset", "elem2") + assert resp == 1 + + card 
= d_client.vset().vcard("myset") + assert card == n - 1 + + resp = d_client.vset().vrem("myset", "elem2") + assert resp == 0 + + card = d_client.vset().vcard("myset") + assert card == n - 1 + + resp = d_client.vset().vrem("myset_unexisting", "elem1") + assert resp == 0 + + +@skip_if_server_version_lt("7.9.0") +def test_vemb_bin_quantization(d_client): + e = [1, 4.32, 0.0, 0.05, -2.9] + d_client.vset().vadd( + "myset", + e, + "elem", + quantization=QuantizationOptions.BIN, + ) + emb_no_quant = d_client.vset().vemb("myset", "elem") + assert emb_no_quant == [1, 1, -1, 1, -1] + + emb_no_quant_raw = d_client.vset().vemb("myset", "elem", raw=True) + assert emb_no_quant_raw["quantization"] == "bin" + assert isinstance(emb_no_quant_raw["raw"], bytes) + assert isinstance(emb_no_quant_raw["l2"], float) + assert "range" not in emb_no_quant_raw + + +@skip_if_server_version_lt("7.9.0") +def test_vemb_q8_quantization(d_client): + e = [1, 10.32, 0.0, 2.05, -12.5] + d_client.vset().vadd("myset", e, "elem", quantization=QuantizationOptions.Q8) + + emb_q8_quant = d_client.vset().vemb("myset", "elem") + assert _validate_quantization(e, emb_q8_quant, tolerance=0.1) + + emb_q8_quant_raw = d_client.vset().vemb("myset", "elem", raw=True) + assert emb_q8_quant_raw["quantization"] == "int8" + assert isinstance(emb_q8_quant_raw["raw"], bytes) + assert isinstance(emb_q8_quant_raw["l2"], float) + assert isinstance(emb_q8_quant_raw["range"], float) + + +@skip_if_server_version_lt("7.9.0") +def test_vemb_no_quantization(d_client): + e = [1, 10.32, 0.0, 2.05, -12.5] + d_client.vset().vadd("myset", e, "elem", quantization=QuantizationOptions.NOQUANT) + + emb_no_quant = d_client.vset().vemb("myset", "elem") + assert _validate_quantization(e, emb_no_quant, tolerance=0.1) + + emb_no_quant_raw = d_client.vset().vemb("myset", "elem", raw=True) + assert emb_no_quant_raw["quantization"] == "f32" + assert isinstance(emb_no_quant_raw["raw"], bytes) + assert isinstance(emb_no_quant_raw["l2"], float) + 
assert "range" not in emb_no_quant_raw + + +@skip_if_server_version_lt("7.9.0") +def test_vemb_default_quantization(d_client): + e = [1, 5.32, 0.0, 0.25, -5] + d_client.vset().vadd("myset", vector=e, element="elem") + + emb_default_quant = d_client.vset().vemb("myset", "elem") + assert _validate_quantization(e, emb_default_quant, tolerance=0.1) + + emb_default_quant_raw = d_client.vset().vemb("myset", "elem", raw=True) + assert emb_default_quant_raw["quantization"] == "int8" + assert isinstance(emb_default_quant_raw["raw"], bytes) + assert isinstance(emb_default_quant_raw["l2"], float) + assert isinstance(emb_default_quant_raw["range"], float) + + +@skip_if_server_version_lt("7.9.0") +def test_vemb_fp32_quantization(d_client): + float_array_fp32 = [1, 4.32, 0.11] + # Convert the list of floats to a byte array in fp32 format + byte_array = _to_fp32_blob_array(float_array_fp32) + d_client.vset().vadd("myset", byte_array, "elem") + + emb_fp32_quant = d_client.vset().vemb("myset", "elem") + assert _validate_quantization(float_array_fp32, emb_fp32_quant, tolerance=0.1) + + emb_fp32_quant_raw = d_client.vset().vemb("myset", "elem", raw=True) + assert emb_fp32_quant_raw["quantization"] == "int8" + assert isinstance(emb_fp32_quant_raw["raw"], bytes) + assert isinstance(emb_fp32_quant_raw["l2"], float) + assert isinstance(emb_fp32_quant_raw["range"], float) + + +@skip_if_server_version_lt("7.9.0") +def test_vemb_unexisting(d_client): + emb_not_existing = d_client.vset().vemb("not_existing", "elem") + assert emb_not_existing is None + + e = [1, 5.32, 0.0, 0.25, -5] + d_client.vset().vadd("myset", vector=e, element="elem") + emb_elem_not_existing = d_client.vset().vemb("myset", "not_existing") + assert emb_elem_not_existing is None + + +@skip_if_server_version_lt("7.9.0") +def test_vlinks(d_client): + elements_count = 100 + vector_dim = 800 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + d_client.vset().vadd( + "myset", + 
float_array, + f"elem{i}", + numlinks=8, + ) + + element_links_all_layers = d_client.vset().vlinks("myset", "elem1") + assert len(element_links_all_layers) >= 1 + for neighbours_list_for_layer in element_links_all_layers: + assert isinstance(neighbours_list_for_layer, list) + for neighbour in neighbours_list_for_layer: + assert isinstance(neighbour, str) + + elem_links_all_layers_with_scores = d_client.vset().vlinks( + "myset", "elem1", with_scores=True + ) + assert len(elem_links_all_layers_with_scores) >= 1 + for neighbours_dict_for_layer in elem_links_all_layers_with_scores: + assert isinstance(neighbours_dict_for_layer, dict) + for neighbour_key, score_value in neighbours_dict_for_layer.items(): + assert isinstance(neighbour_key, str) + assert isinstance(score_value, float) + + float_array = [0.75, 0.25, 0.5, 0.1, 0.9] + d_client.vset().vadd("myset_one_elem_only", float_array, "elem1") + elem_no_neighbours_with_scores = d_client.vset().vlinks( + "myset_one_elem_only", "elem1", with_scores=True + ) + assert len(elem_no_neighbours_with_scores) >= 1 + for neighbours_dict_for_layer in elem_no_neighbours_with_scores: + assert isinstance(neighbours_dict_for_layer, dict) + assert len(neighbours_dict_for_layer) == 0 + + elem_no_neighbours_no_scores = d_client.vset().vlinks( + "myset_one_elem_only", "elem1" + ) + assert len(elem_no_neighbours_no_scores) >= 1 + for neighbours_list_for_layer in elem_no_neighbours_no_scores: + assert isinstance(neighbours_list_for_layer, list) + assert len(neighbours_list_for_layer) == 0 + + unexisting_element_links = d_client.vset().vlinks("myset", "unexisting_elem") + assert unexisting_element_links is None + + unexisting_vset_links = d_client.vset().vlinks("myset_unexisting", "elem1") + assert unexisting_vset_links is None + + unexisting_element_links = d_client.vset().vlinks( + "myset", "unexisting_elem", with_scores=True + ) + assert unexisting_element_links is None + + unexisting_vset_links = d_client.vset().vlinks( + 
"myset_unexisting", "elem1", with_scores=True + ) + assert unexisting_vset_links is None + + +@skip_if_server_version_lt("7.9.0") +def test_vinfo(d_client): + elements_count = 100 + vector_dim = 800 + for i in range(elements_count): + float_array = [random.uniform(0, 10) for x in range(vector_dim)] + d_client.vset().vadd( + "myset", + float_array, + f"elem{i}", + numlinks=8, + quantization=QuantizationOptions.BIN, + ) + + vset_info = d_client.vset().vinfo("myset") + assert vset_info["quant-type"] == "bin" + assert vset_info["vector-dim"] == vector_dim + assert vset_info["size"] == elements_count + assert vset_info["max-level"] > 0 + assert vset_info["hnsw-max-node-uid"] == elements_count + + unexisting_vset_info = d_client.vset().vinfo("myset_unexisting") + assert unexisting_vset_info is None + + +@skip_if_server_version_lt("7.9.0") +def test_vset_vget_attributes(d_client): + float_array = [1, 4.32, 0.11] + attributes = {"key1": "value1", "key2": "value2"} + + # validate vgetattrs when no attributes are set with vadd + resp = d_client.vset().vadd("myset", float_array, "elem1") + assert resp == 1 + + attrs = d_client.vset().vgetattr("myset", "elem1") + assert attrs is None + + # validate vgetattrs when attributes are set with vadd + resp = d_client.vset().vadd( + "myset_with_attrs", float_array, "elem1", attributes=attributes + ) + assert resp == 1 + + attrs = d_client.vset().vgetattr("myset_with_attrs", "elem1") + assert attrs == attributes + + # Set attributes and get attributes + resp = d_client.vset().vsetattr("myset", "elem1", attributes) + assert resp == 1 + attr_saved = d_client.vset().vgetattr("myset", "elem1") + assert attr_saved == attributes + + # Set attributes to None + resp = d_client.vset().vsetattr("myset", "elem1", None) + assert resp == 1 + attr_saved = d_client.vset().vgetattr("myset", "elem1") + assert attr_saved is None + + # Set attributes to empty dict + resp = d_client.vset().vsetattr("myset", "elem1", {}) + assert resp == 1 + attr_saved = 
d_client.vset().vgetattr("myset", "elem1") + assert attr_saved is None + + # Set attributes provided as string + resp = d_client.vset().vsetattr("myset", "elem1", json.dumps(attributes)) + assert resp == 1 + attr_saved = d_client.vset().vgetattr("myset", "elem1") + assert attr_saved == attributes + + # Set attributes to unexisting element + resp = d_client.vset().vsetattr("myset", "elem2", attributes) + assert resp == 0 + attr_saved = d_client.vset().vgetattr("myset", "elem2") + assert attr_saved is None + + # Set attributes to unexisting vset + resp = d_client.vset().vsetattr("myset_unexisting", "elem1", attributes) + assert resp == 0 + attr_saved = d_client.vset().vgetattr("myset_unexisting", "elem1") + assert attr_saved is None + + +@skip_if_server_version_lt("7.9.0") +def test_vrandmember(d_client): + elements = ["elem1", "elem2", "elem3"] + for elem in elements: + float_array = [random.uniform(0, 10) for x in range(1, 8)] + d_client.vset().vadd("myset", float_array, element=elem) + + random_member = d_client.vset().vrandmember("myset") + assert random_member in elements + + members_list = d_client.vset().vrandmember("myset", count=2) + assert len(members_list) == 2 + assert all(member in elements for member in members_list) + + # Test with count greater than the number of elements + members_list = d_client.vset().vrandmember("myset", count=10) + assert len(members_list) == len(elements) + assert all(member in elements for member in members_list) + + # Test with negative count + members_list = d_client.vset().vrandmember("myset", count=-2) + assert len(members_list) == 2 + assert all(member in elements for member in members_list) + + # Test with count equal to the number of elements + members_list = d_client.vset().vrandmember("myset", count=len(elements)) + assert len(members_list) == len(elements) + assert all(member in elements for member in members_list) + + # Test with count equal to 0 + members_list = d_client.vset().vrandmember("myset", count=0) + assert 
@skip_if_server_version_lt("7.9.0")
def test_vset_commands_without_decoding_responces(client):
    """Run the vector-set commands on ``client`` and check the reply types.

    Element names and string-valued fields are asserted to be ``bytes``;
    embedding components and scores are asserted to be ``float``.
    """
    # vadd: every insertion of a new element reports 1
    elements = ["elem1", "elem2", "elem3"]
    for name in elements:
        vector = [random.uniform(0, 10) for _ in range(0, 8)]
        assert client.vset().vadd("myset", vector, element=name) == 1

    # vemb: plain form is a list of floats with one entry per dimension
    embedding = client.vset().vemb("myset", "elem1")
    assert len(embedding) == 8
    assert isinstance(embedding, list)
    assert all(isinstance(component, float) for component in embedding)

    # vemb raw: dict keyed by str, quantization name stays bytes
    raw_embedding = client.vset().vemb("myset", "elem1", raw=True)
    assert raw_embedding["quantization"] == b"int8"
    for field, expected_type in (("raw", bytes), ("l2", float), ("range", float)):
        assert isinstance(raw_embedding[field], expected_type)

    # vsim: list of element names as bytes
    similar = client.vset().vsim("myset", input="elem1")
    assert len(similar) == 3
    assert isinstance(similar, list)
    assert isinstance(similar[0], bytes)

    # vsim with scores: dict mapping bytes names to float scores
    similar_scored = client.vset().vsim("myset", input="elem1", with_scores=True)
    assert len(similar_scored) == 3
    assert isinstance(similar_scored, dict)
    assert isinstance(similar_scored[b"elem1"], float)

    # vlinks without scores: one list of bytes neighbours per layer
    layers = client.vset().vlinks("myset", "elem1")
    assert len(layers) >= 1
    for layer in layers:
        assert isinstance(layer, list)
        assert all(isinstance(neighbour, bytes) for neighbour in layer)

    # vlinks with scores: one {bytes neighbour: float score} dict per layer
    scored_layers = client.vset().vlinks("myset", "elem1", with_scores=True)
    assert len(scored_layers) >= 1
    for scored_layer in scored_layers:
        assert isinstance(scored_layer, dict)
        for neighbour, score in scored_layer.items():
            assert isinstance(neighbour, bytes)
            assert isinstance(score, float)

    # vinfo: keys are bytes
    info = client.vset().vinfo("myset")
    assert info[b"quant-type"] == b"int8"
    assert info[b"vector-dim"] == 8
    assert info[b"size"] == len(elements)
    assert info[b"max-level"] >= 0
    assert info[b"hnsw-max-node-uid"] == len(elements)

    # vsetattr / vgetattr: attributes round-trip as a dict
    attributes = {"key1": "value1", "key2": "value2"}
    client.vset().vsetattr("myset", "elem1", attributes)
    stored_attributes = client.vset().vgetattr("myset", "elem1")
    assert stored_attributes == attributes

    # vrandmember: members are bytes and must be decoded before comparing
    single_member = client.vset().vrandmember("myset")
    assert isinstance(single_member, bytes)
    assert single_member.decode("utf-8") in elements

    two_members = client.vset().vrandmember("myset", count=2)
    assert len(two_members) == 2
    assert all(raw_name.decode("utf-8") in elements for raw_name in two_members)
isinstance(neighbours_list_for_layer, list) + for neighbour in neighbours_list_for_layer: + assert isinstance(neighbour, bytes) + # test vlinks with scores + elem_links_all_layers_with_scores = client.vset().vlinks( + "myset", "elem1", with_scores=True + ) + assert len(elem_links_all_layers_with_scores) >= 1 + for neighbours_dict_for_layer in elem_links_all_layers_with_scores: + assert isinstance(neighbours_dict_for_layer, dict) + for neighbour_key, score_value in neighbours_dict_for_layer.items(): + assert isinstance(neighbour_key, bytes) + assert isinstance(score_value, float) + + # test vinfo + vset_info = client.vset().vinfo("myset") + assert vset_info[b"quant-type"] == b"int8" + assert vset_info[b"vector-dim"] == 8 + assert vset_info[b"size"] == len(elements) + assert vset_info[b"max-level"] >= 0 + assert vset_info[b"hnsw-max-node-uid"] == len(elements) + + # test vgetattr + attributes = {"key1": "value1", "key2": "value2"} + client.vset().vsetattr("myset", "elem1", attributes) + attrs = client.vset().vgetattr("myset", "elem1") + assert attrs == attributes + + # test vrandmember + random_member = client.vset().vrandmember("myset") + assert isinstance(random_member, bytes) + assert random_member.decode("utf-8") in elements + + members_list = client.vset().vrandmember("myset", count=2) + assert len(members_list) == 2 + assert all(member.decode("utf-8") in elements for member in members_list) + + +def _to_fp32_blob_array(float_array): + """ + Convert a list of floats to a byte array in fp32 format. 
+ """ + # Convert the list of floats to a NumPy array with dtype np.float32 + arr = np.array(float_array, dtype=np.float32) + # Convert the NumPy array to a byte array + byte_array = arr.tobytes() + return byte_array + + +def _validate_quantization(original, quantized, tolerance=0.1): + original = np.array(original, dtype=np.float32) + quantized = np.array(quantized, dtype=np.float32) + + max_diff = np.max(np.abs(original - quantized)) + if max_diff > tolerance: + return False + else: + return True