Skip to content

Commit

Permalink
Merge pull request #1040 from weaviate/dynamic_index_dirk
Browse files Browse the repository at this point in the history
Dynamic index dirk
  • Loading branch information
dirkkul authored May 7, 2024
2 parents 302c887 + 6dbf504 commit bc6c080
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 52 deletions.
33 changes: 32 additions & 1 deletion integration/test_collection_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
_CollectionConfig,
_CollectionConfigSimple,
_PQConfig,
_VectorIndexConfigDynamic,
_VectorIndexConfigFlat,
_VectorIndexConfigHNSW,
Configure,
Reconfigure,
Expand Down Expand Up @@ -463,6 +465,7 @@ def test_update_flat(collection_factory: CollectionFactory) -> None:
config = collection.config.get()
assert config.vector_index_type == VectorIndexType.FLAT
assert config.vector_index_config is not None
assert isinstance(config.vector_index_config, _VectorIndexConfigFlat)
assert config.vector_index_config.vector_cache_max_objects == 5
assert isinstance(config.vector_index_config.quantizer, _BQConfig)
assert config.vector_index_config.quantizer.rescore_limit == 10
Expand All @@ -476,6 +479,7 @@ def test_update_flat(collection_factory: CollectionFactory) -> None:
config = collection.config.get()
assert config.vector_index_type == VectorIndexType.FLAT
assert config.vector_index_config is not None
assert isinstance(config.vector_index_config, _VectorIndexConfigFlat)
assert config.vector_index_config.vector_cache_max_objects == 10
assert isinstance(config.vector_index_config.quantizer, _BQConfig)
assert config.vector_index_config.quantizer.rescore_limit == 20
Expand Down Expand Up @@ -570,6 +574,7 @@ def test_config_vector_index_flat_and_quantizer_bq(collection_factory: Collectio
conf = collection.config.get()
assert conf.vector_index_type == VectorIndexType.FLAT
assert conf.vector_index_config is not None
assert isinstance(conf.vector_index_config, _VectorIndexConfigFlat)
assert conf.vector_index_config.vector_cache_max_objects == 234
assert isinstance(conf.vector_index_config.quantizer, _BQConfig)
assert conf.vector_index_config.quantizer.rescore_limit == 456
Expand All @@ -587,8 +592,8 @@ def test_config_vector_index_hnsw_and_quantizer_pq(collection_factory: Collectio
conf = collection.config.get()
assert conf.vector_index_type == VectorIndexType.HNSW
assert conf.vector_index_config is not None
assert conf.vector_index_config.vector_cache_max_objects == 234
assert isinstance(conf.vector_index_config, _VectorIndexConfigHNSW)
assert conf.vector_index_config.vector_cache_max_objects == 234
assert conf.vector_index_config.ef_construction == 789
assert isinstance(conf.vector_index_config.quantizer, _PQConfig)
assert conf.vector_index_config.quantizer.segments == 456
Expand Down Expand Up @@ -742,3 +747,29 @@ def test_config_skip_vector_index(collection_factory: CollectionFactory) -> None
assert config.vector_index_config.quantizer is None
assert config.vector_index_config.skip is True
assert config.vector_index_config.vector_cache_max_objects == 1000000000000


def test_dynamic_collection(collection_factory: CollectionFactory) -> None:
    """Create a collection with a dynamic vector index and check the config that is read back.

    The test is skipped when the connected Weaviate server is older than 1.25.0.
    """
    # A throwaway collection is created only to reach the connection and read the server version.
    version_probe = collection_factory("dummy")
    server_version = version_probe._connection._weaviate_version
    if server_version.is_lower_than(1, 25, 0):
        pytest.skip("Dynamic index is not supported in Weaviate versions lower than 1.25.0")

    # Nested HNSW / flat settings use distinctive values so they can be recognised after the
    # round trip through the server.
    hnsw_settings = Configure.VectorIndex.hnsw(
        cleanup_interval_seconds=123, flat_search_cutoff=1234
    )
    flat_settings = Configure.VectorIndex.flat(vector_cache_max_objects=7643)

    # NOTE(review): the explicit ports presumably select a dedicated test server instance —
    # confirm against the `collection_factory` fixture.
    collection = collection_factory(
        vector_index_config=Configure.VectorIndex.dynamic(
            distance_metric=VectorDistances.COSINE,
            threshold=1000,
            hnsw=hnsw_settings,
            flat=flat_settings,
        ),
        ports=(8090, 50061),
    )

    index_config = collection.config.get().vector_index_config
    assert isinstance(index_config, _VectorIndexConfigDynamic)
    assert index_config.distance_metric == VectorDistances.COSINE
    assert index_config.threshold == 1000

    hnsw_config = index_config.hnsw
    assert isinstance(hnsw_config, _VectorIndexConfigHNSW)
    assert hnsw_config.cleanup_interval_seconds == 123
    assert hnsw_config.flat_search_cutoff == 1234

    flat_config = index_config.flat
    assert isinstance(flat_config, _VectorIndexConfigFlat)
    assert flat_config.vector_cache_max_objects == 7643
49 changes: 47 additions & 2 deletions weaviate/collections/classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_VectorIndexConfigFlatCreate,
_VectorIndexConfigHNSWUpdate,
_VectorIndexConfigFlatUpdate,
_VectorIndexConfigDynamicCreate,
_VectorIndexConfigSkipCreate,
_VectorIndexConfigUpdate,
VectorIndexType as VectorIndexTypeAlias,
Expand Down Expand Up @@ -1119,6 +1120,21 @@ def vector_index_type() -> str:
VectorIndexConfigFlat = _VectorIndexConfigFlat


# Read-side representation of a collection's "dynamic" vector index configuration.
# It extends `_VectorIndexConfig` with the nested HNSW and flat sub-configurations
# and the `threshold` value of the dynamic index.
@dataclass
class _VectorIndexConfigDynamic(_VectorIndexConfig):
# Distance metric of the dynamic index.
distance_metric: VectorDistances
# Nested HNSW configuration, if any.
hnsw: Optional[VectorIndexConfigHNSW]
# Nested flat configuration, if any.
flat: Optional[VectorIndexConfigFlat]
# NOTE(review): presumably the object count at which the index switches from flat to
# HNSW — confirm against the Weaviate documentation. May be None.
threshold: Optional[int]

# Returns the string value of `VectorIndexType.DYNAMIC`.
@staticmethod
def vector_index_type() -> str:
return VectorIndexType.DYNAMIC.value


# Public alias without the leading underscore, mirroring `VectorIndexConfigFlat`.
VectorIndexConfigDynamic = _VectorIndexConfigDynamic


@dataclass
class _GenerativeConfig(_ConfigBase):
generative: GenerativeSearches
Expand Down Expand Up @@ -1162,7 +1178,9 @@ def to_dict(self) -> Dict[str, Any]:
@dataclass
class _NamedVectorConfig(_ConfigBase):
vectorizer: _NamedVectorizerConfig
vector_index_config: Union[VectorIndexConfigHNSW, VectorIndexConfigFlat]
vector_index_config: Union[
VectorIndexConfigHNSW, VectorIndexConfigFlat, VectorIndexConfigDynamic
]

def to_dict(self) -> Dict:
ret_dict = super().to_dict()
Expand All @@ -1185,7 +1203,9 @@ class _CollectionConfig(_ConfigBase):
replication_config: ReplicationConfig
reranker_config: Optional[RerankerConfig]
sharding_config: Optional[ShardingConfig]
vector_index_config: Union[VectorIndexConfigHNSW, VectorIndexConfigFlat, None]
vector_index_config: Union[
VectorIndexConfigHNSW, VectorIndexConfigFlat, VectorIndexConfigDynamic, None
]
vector_index_type: Optional[VectorIndexType]
vectorizer_config: Optional[VectorizerConfig]
vectorizer: Optional[Vectorizers]
Expand Down Expand Up @@ -1622,6 +1642,31 @@ def flat(
quantizer=quantizer,
)

@staticmethod
def dynamic(
    distance_metric: Optional[VectorDistances] = None,
    threshold: Optional[int] = None,
    hnsw: Optional[_VectorIndexConfigHNSWCreate] = None,
    flat: Optional[_VectorIndexConfigFlatCreate] = None,
    vector_cache_max_objects: Optional[int] = None,
    quantizer: Optional[_BQConfigCreate] = None,
) -> _VectorIndexConfigDynamicCreate:
    """Build the `_VectorIndexConfigDynamicCreate` object that describes a DYNAMIC vector index.

    Pass the result as the `vector_index_config` argument of `collections.create()`.

    Arguments:
        `distance_metric`
            Forwarded to the create model as `distance`.
        `threshold`
            Forwarded unchanged as `threshold`.
        `hnsw`
            Nested HNSW configuration, e.g. the result of `Configure.VectorIndex.hnsw()`.
        `flat`
            Nested flat configuration, e.g. the result of `Configure.VectorIndex.flat()`.
        `vector_cache_max_objects`
            Forwarded to the create model as `vectorCacheMaxObjects`.
        `quantizer`
            Forwarded unchanged as `quantizer`.

    See [the docs](https://weaviate.io/developers/weaviate/configuration/indexes#how-to-configure-hnsw) for a more detailed view!
    """  # noqa: D417 (missing argument descriptions in the docstring)
    # NOTE(review): `_VectorIndexConfigDynamicCreate` itself only declares `threshold`, `hnsw`
    # and `flat`; whether `vectorCacheMaxObjects` is accepted depends on its base class —
    # confirm that the value is not silently dropped.
    dynamic_config = _VectorIndexConfigDynamicCreate(
        distance=distance_metric,
        threshold=threshold,
        hnsw=hnsw,
        flat=flat,
        vectorCacheMaxObjects=vector_cache_max_objects,
        quantizer=quantizer,
    )
    return dynamic_config


class Configure:
"""Use this factory class to generate the correct object for use when using the `collections.create()` method. E.g., `.multi_tenancy()` will return a `MultiTenancyConfigCreate` object to be used in the `multi_tenancy_config` argument.
Expand Down
117 changes: 70 additions & 47 deletions weaviate/collections/classes/config_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
_NamedVectorizerConfig,
_PQConfig,
_VectorIndexConfigFlat,
_VectorIndexConfigDynamic,
_InvertedIndexConfig,
_BM25Config,
_StopwordsConfig,
Expand Down Expand Up @@ -98,56 +99,76 @@ def __get_vector_index_type(schema: Dict[str, Any]) -> Optional[VectorIndexType]
return None


def __get_vector_index_config(
schema: Dict[str, Any]
) -> Union[_VectorIndexConfigHNSW, _VectorIndexConfigFlat, None]:
if "vectorIndexConfig" not in schema:
return None
def __get_quantizer_config(config: Dict[str, Any]) -> Optional[Union[_PQConfig, _BQConfig]]:
quantizer: Optional[Union[_PQConfig, _BQConfig]] = None
if "bq" in schema["vectorIndexConfig"] and schema["vectorIndexConfig"]["bq"]["enabled"]:
if "bq" in config and config["bq"]["enabled"]:
# values are not present for bq+hnsw
quantizer = _BQConfig(
cache=schema["vectorIndexConfig"]["bq"].get("cache"),
rescore_limit=schema["vectorIndexConfig"]["bq"].get("rescoreLimit"),
cache=config["bq"].get("cache"),
rescore_limit=config["bq"].get("rescoreLimit"),
)
elif "pq" in schema["vectorIndexConfig"] and schema["vectorIndexConfig"]["pq"].get("enabled"):
elif "pq" in config and config["pq"].get("enabled"):
quantizer = _PQConfig(
internal_bit_compression=schema["vectorIndexConfig"]["pq"].get("bitCompression"),
segments=schema["vectorIndexConfig"]["pq"].get("segments"),
centroids=schema["vectorIndexConfig"]["pq"].get("centroids"),
training_limit=schema["vectorIndexConfig"]["pq"].get("trainingLimit"),
internal_bit_compression=config["pq"].get("bitCompression"),
segments=config["pq"].get("segments"),
centroids=config["pq"].get("centroids"),
training_limit=config["pq"].get("trainingLimit"),
encoder=_PQEncoderConfig(
type_=PQEncoderType(
schema["vectorIndexConfig"]["pq"].get("encoder", {}).get("type")
),
type_=PQEncoderType(config["pq"].get("encoder", {}).get("type")),
distribution=PQEncoderDistribution(
schema["vectorIndexConfig"]["pq"].get("encoder", {}).get("distribution")
config["pq"].get("encoder", {}).get("distribution")
),
),
)
return quantizer


def __get_hnsw_config(config: Dict[str, Any]) -> _VectorIndexConfigHNSW:
    """Convert a raw HNSW vector index config dict into a `_VectorIndexConfigHNSW`.

    Arguments:
        `config`
            The HNSW settings as a dict keyed by the camelCase names used in the schema.

    Returns:
        The parsed `_VectorIndexConfigHNSW`, including the quantizer derived from `config`.

    Raises:
        `KeyError`
            If any key other than `distance` is missing; `distance` is read with `.get()`
            and handed to `VectorDistances(...)` as-is.
    """
    # The quantizer is resolved first, before any of the plain fields are read.
    parsed_quantizer = __get_quantizer_config(config)
    hnsw_config = _VectorIndexConfigHNSW(
        cleanup_interval_seconds=config["cleanupIntervalSeconds"],
        distance_metric=VectorDistances(config.get("distance")),
        dynamic_ef_min=config["dynamicEfMin"],
        dynamic_ef_max=config["dynamicEfMax"],
        dynamic_ef_factor=config["dynamicEfFactor"],
        ef=config["ef"],
        ef_construction=config["efConstruction"],
        flat_search_cutoff=config["flatSearchCutoff"],
        max_connections=config["maxConnections"],
        quantizer=parsed_quantizer,
        skip=config["skip"],
        vector_cache_max_objects=config["vectorCacheMaxObjects"],
    )
    return hnsw_config


def __get_flat_config(config: Dict[str, Any]) -> _VectorIndexConfigFlat:
    """Convert a raw flat vector index config dict into a `_VectorIndexConfigFlat`.

    Arguments:
        `config`
            The flat index settings as a dict keyed by the camelCase names used in the schema.

    Returns:
        The parsed `_VectorIndexConfigFlat`, including the quantizer derived from `config`.

    Raises:
        `KeyError`
            If `distance` or `vectorCacheMaxObjects` is missing from `config`.
    """
    # The quantizer is resolved first, before any of the plain fields are read.
    parsed_quantizer = __get_quantizer_config(config)
    flat_config = _VectorIndexConfigFlat(
        distance_metric=VectorDistances(config["distance"]),
        quantizer=parsed_quantizer,
        vector_cache_max_objects=config["vectorCacheMaxObjects"],
    )
    return flat_config


def __get_vector_index_config(
schema: Dict[str, Any]
) -> Union[_VectorIndexConfigHNSW, _VectorIndexConfigFlat, _VectorIndexConfigDynamic, None]:
if "vectorIndexConfig" not in schema:
return None
if schema["vectorIndexType"] == "hnsw":
return _VectorIndexConfigHNSW(
cleanup_interval_seconds=schema["vectorIndexConfig"].get("cleanupIntervalSeconds"),
distance_metric=VectorDistances(schema["vectorIndexConfig"].get("distance")),
dynamic_ef_min=schema["vectorIndexConfig"]["dynamicEfMin"],
dynamic_ef_max=schema["vectorIndexConfig"]["dynamicEfMax"],
dynamic_ef_factor=schema["vectorIndexConfig"]["dynamicEfFactor"],
ef=schema["vectorIndexConfig"]["ef"],
ef_construction=schema["vectorIndexConfig"]["efConstruction"],
flat_search_cutoff=schema["vectorIndexConfig"]["flatSearchCutoff"],
max_connections=schema["vectorIndexConfig"]["maxConnections"],
quantizer=quantizer,
skip=schema["vectorIndexConfig"]["skip"],
vector_cache_max_objects=schema["vectorIndexConfig"]["vectorCacheMaxObjects"],
)
else:
assert schema["vectorIndexType"] == "flat"
return _VectorIndexConfigFlat(
return __get_hnsw_config(schema["vectorIndexConfig"])
elif schema["vectorIndexType"] == "flat":
return __get_flat_config(schema["vectorIndexConfig"])
elif schema["vectorIndexType"] == "dynamic":
return _VectorIndexConfigDynamic(
distance_metric=VectorDistances(schema["vectorIndexConfig"]["distance"]),
quantizer=quantizer,
vector_cache_max_objects=schema["vectorIndexConfig"]["vectorCacheMaxObjects"],
threshold=schema["vectorIndexConfig"].get("threshold"),
quantizer=None,
hnsw=__get_hnsw_config(schema["vectorIndexConfig"]["hnsw"]),
flat=__get_flat_config(schema["vectorIndexConfig"]["flat"]),
)
else:
return None


def __get_vector_config(
Expand Down Expand Up @@ -230,17 +251,19 @@ def _collection_config_from_json(schema: Dict[str, Any]) -> _CollectionConfig:
references=_references_from_config(schema) if schema.get("properties") is not None else [],
replication_config=_ReplicationConfig(factor=schema["replicationConfig"]["factor"]),
reranker_config=__get_rerank_config(schema),
sharding_config=None
if schema.get("multiTenancyConfig", {}).get("enabled", False)
else _ShardingConfig(
virtual_per_physical=schema["shardingConfig"]["virtualPerPhysical"],
desired_count=schema["shardingConfig"]["desiredCount"],
actual_count=schema["shardingConfig"]["actualCount"],
desired_virtual_count=schema["shardingConfig"]["desiredVirtualCount"],
actual_virtual_count=schema["shardingConfig"]["actualVirtualCount"],
key=schema["shardingConfig"]["key"],
strategy=schema["shardingConfig"]["strategy"],
function=schema["shardingConfig"]["function"],
sharding_config=(
None
if schema.get("multiTenancyConfig", {}).get("enabled", False)
else _ShardingConfig(
virtual_per_physical=schema["shardingConfig"]["virtualPerPhysical"],
desired_count=schema["shardingConfig"]["desiredCount"],
actual_count=schema["shardingConfig"]["actualCount"],
desired_virtual_count=schema["shardingConfig"]["desiredVirtualCount"],
actual_virtual_count=schema["shardingConfig"]["actualVirtualCount"],
key=schema["shardingConfig"]["key"],
strategy=schema["shardingConfig"]["strategy"],
function=schema["shardingConfig"]["function"],
)
),
vector_index_config=__get_vector_index_config(schema),
vector_index_type=__get_vector_index_type(schema),
Expand Down
7 changes: 6 additions & 1 deletion weaviate/collections/classes/config_named_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
_VectorIndexConfigCreate,
_VectorIndexConfigHNSWUpdate,
_VectorIndexConfigFlatUpdate,
_VectorIndexConfigDynamicUpdate,
_VectorIndexConfigUpdate,
VectorIndexType,
)
Expand Down Expand Up @@ -858,7 +859,11 @@ class _NamedVectorsUpdate:
def update(
name: str,
*,
vector_index_config: Union[_VectorIndexConfigHNSWUpdate, _VectorIndexConfigFlatUpdate],
vector_index_config: Union[
_VectorIndexConfigHNSWUpdate,
_VectorIndexConfigFlatUpdate,
_VectorIndexConfigDynamicUpdate,
],
) -> _NamedVectorConfigUpdate:
"""Update the vector index configuration of a named vector.
Expand Down
17 changes: 17 additions & 0 deletions weaviate/collections/classes/config_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class VectorIndexType(str, Enum):

HNSW = "hnsw"
FLAT = "flat"
DYNAMIC = "dynamic"


class _VectorIndexConfigCreate(_ConfigCreateModel):
Expand Down Expand Up @@ -101,3 +102,19 @@ class _VectorIndexConfigFlatUpdate(_VectorIndexConfigUpdate):
@staticmethod
def vector_index_type() -> VectorIndexType:
return VectorIndexType.FLAT


# Create-side model for a "dynamic" vector index; produced by the `dynamic()` factory.
# Only the fields specific to the dynamic index are declared here, everything else
# comes from `_VectorIndexConfigCreate`.
class _VectorIndexConfigDynamicCreate(_VectorIndexConfigCreate):
# NOTE(review): presumably the object count at which the index switches from flat to
# HNSW — confirm against the Weaviate documentation.
threshold: Optional[int]
# Nested HNSW create-configuration; may be None.
hnsw: Optional[_VectorIndexConfigHNSWCreate]
# Nested flat create-configuration; may be None.
flat: Optional[_VectorIndexConfigFlatCreate]

# Identifies this configuration as belonging to the DYNAMIC index type.
@staticmethod
def vector_index_type() -> VectorIndexType:
return VectorIndexType.DYNAMIC


# Update-side model for a "dynamic" vector index. It declares no fields of its own;
# it only tags the inherited `_VectorIndexConfigUpdate` with the DYNAMIC index type.
class _VectorIndexConfigDynamicUpdate(_VectorIndexConfigUpdate):
# Identifies this configuration as belonging to the DYNAMIC index type.
@staticmethod
def vector_index_type() -> VectorIndexType:
return VectorIndexType.DYNAMIC
9 changes: 8 additions & 1 deletion weaviate/collections/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

from weaviate.connect.v4 import _ExpectedStatusCodes

from weaviate.collections.classes.config_vector_index import _VectorIndexConfigDynamicUpdate


class _ConfigBase:
def __init__(self, connection: ConnectionV4, name: str, tenant: Optional[str]) -> None:
Expand Down Expand Up @@ -88,12 +90,17 @@ def update(
inverted_index_config: Optional[_InvertedIndexConfigUpdate] = None,
replication_config: Optional[_ReplicationConfigUpdate] = None,
vector_index_config: Optional[
Union[_VectorIndexConfigHNSWUpdate, _VectorIndexConfigFlatUpdate]
Union[
_VectorIndexConfigHNSWUpdate,
_VectorIndexConfigFlatUpdate,
_VectorIndexConfigDynamicUpdate,
]
] = None,
vectorizer_config: Optional[
Union[
_VectorIndexConfigHNSWUpdate,
_VectorIndexConfigFlatUpdate,
_VectorIndexConfigDynamicUpdate,
List[_NamedVectorConfigUpdate],
]
] = None,
Expand Down

0 comments on commit bc6c080

Please sign in to comment.