diff --git a/.gitignore b/.gitignore
index ece1f9e97..c2979fbe7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -24,4 +24,5 @@
 image.png
 scratch/
 *-test.sh
-*.hdf5
\ No newline at end of file
+*.hdf5
+*.jsonl
\ No newline at end of file
diff --git a/integration/test_batch_v4.py b/integration/test_batch_v4.py
index 4ec6a6de4..2c9726f2f 100644
--- a/integration/test_batch_v4.py
+++ b/integration/test_batch_v4.py
@@ -596,8 +596,8 @@ def batch_insert(batch: BatchClient) -> None:
     with concurrent.futures.ThreadPoolExecutor() as executor:
         with client.batch.dynamic() as batch:
             futures = [executor.submit(batch_insert, batch) for _ in range(nr_threads)]
-        for future in concurrent.futures.as_completed(futures):
-            future.result()
+            for future in concurrent.futures.as_completed(futures):
+                future.result()
 
     objs = client.collections.get(name).query.fetch_objects(limit=nr_objects * nr_threads).objects
     assert len(objs) == nr_objects * nr_threads
@@ -687,9 +687,13 @@ def test_batching_error_logs(
         for obj in [{"name": i} for i in range(100)]:
             batch.add_object(properties=obj, collection=name)
     assert (
-        "Failed to send 100 objects in a batch of 100. Please inspect client.batch.failed_objects or collection.batch.failed_objects for the failed objects."
-        in caplog.text
-    )
+        ("Failed to send" in caplog.text)
+        and ("objects in a batch of" in caplog.text)
+        and (
+            "Please inspect client.batch.failed_objects or collection.batch.failed_objects for the failed objects."
+            in caplog.text
+        )
+    )  # number of objects sent per batch is not fixed for less than 100 objects
 
 
 def test_references_with_to_uuids(client_factory: ClientFactory) -> None:
diff --git a/profiling/test_sphere.py b/profiling/test_sphere.py
index d49b2a7d3..a0614a498 100644
--- a/profiling/test_sphere.py
+++ b/profiling/test_sphere.py
@@ -9,7 +9,7 @@
 
 
 def test_sphere(collection_factory: CollectionFactory) -> None:
-    sphere_file = get_file_path("sphere.100k.jsonl")
+    sphere_file = get_file_path("sphere.1m.jsonl")
 
     collection = collection_factory(
         properties=[
@@ -26,7 +26,7 @@ def test_sphere(collection_factory: CollectionFactory) -> None:
     )
 
     start = time.time()
-    import_objects = 50000
+    import_objects = 1000000
     with collection.batch.dynamic() as batch:
         with open(sphere_file) as jsonl_file:
             for i, jsonl in enumerate(jsonl_file):
@@ -45,7 +45,7 @@ def test_sphere(collection_factory: CollectionFactory) -> None:
                     vector=json_parsed["vector"],
                 )
                 if i % 1000 == 0:
-                    print(f"Imported {i} objects")
+                    print(f"Imported {len(collection)} objects")
     assert len(collection.batch.failed_objects) == 0
     assert len(collection) == import_objects
     print(f"Imported {import_objects} objects in {time.time() - start}")
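Note on the test changes above: the `as_completed` wait moves inside the `client.batch.dynamic()` context, so every worker thread finishes adding objects before the batch flushes on exit, and the log assertion is loosened because the number of objects per sent batch is no longer fixed at 100. A minimal sketch of the threaded pattern the test exercises, assuming a connected `client` and an existing collection `name`:

    import concurrent.futures

    def batch_insert(batch) -> None:
        # each worker adds its objects through the shared batch context
        for i in range(1000):
            batch.add_object(properties={"name": str(i)}, collection=name)

    with concurrent.futures.ThreadPoolExecutor() as executor:
        with client.batch.dynamic() as batch:
            futures = [executor.submit(batch_insert, batch) for _ in range(4)]
            # wait inside the context so all adds land before the final flush
            for future in concurrent.futures.as_completed(futures):
                future.result()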
diff --git a/weaviate/collections/batch/base.py b/weaviate/collections/batch/base.py
index 207c8926a..72c7b16de 100644
--- a/weaviate/collections/batch/base.py
+++ b/weaviate/collections/batch/base.py
@@ -1,3 +1,4 @@
+import asyncio
 import math
 import threading
 import time
@@ -161,11 +162,12 @@ def __init__(
         batch_mode: _BatchMode,
         event_loop: _EventLoop,
         vectorizer_batching: bool,
-        objects_: Optional[ObjectsBatchRequest] = None,
+        objects: Optional[ObjectsBatchRequest] = None,
         references: Optional[ReferencesBatchRequest] = None,
     ) -> None:
-        self.__batch_objects = objects_ or ObjectsBatchRequest()
+        self.__batch_objects = objects or ObjectsBatchRequest()
         self.__batch_references = references or ReferencesBatchRequest()
+        self.__connection = connection
         self.__consistency_level: Optional[ConsistencyLevel] = consistency_level
         self.__vectorizer_batching = vectorizer_batching
@@ -174,15 +176,12 @@ def __init__(
         self.__batch_rest = _BatchREST(connection, self.__consistency_level)
 
         # lookup table for objects that are currently being processed - is used to not send references from objects that have not been added yet
-        self.__uuid_lookup_lock = threading.Lock()
         self.__uuid_lookup: Set[str] = set()
 
         # we do not want that users can access the results directly as they are not thread-safe
         self.__results_for_wrapper_backup = results
         self.__results_for_wrapper = _BatchDataWrapper()
-        self.__results_lock = threading.Lock()
-        self.__cluster = _ClusterBatch(self.__connection)
 
         self.__batching_mode: _BatchMode = batch_mode
@@ -221,7 +220,6 @@ def __init__(
         self.__recommended_num_refs: int = 50
 
         self.__active_requests = 0
-        self.__active_requests_lock = threading.Lock()
 
         # dynamic batching
         self.__time_last_scale_up: float = 0
@@ -233,9 +231,21 @@ def __init__(
        # do 62 secs to give us some buffer to the "per-minute" calculation
         self.__fix_rate_batching_base_time = 62
 
+        self.__loop.run_until_complete(self.__make_asyncio_locks)
+
         self.__bg_thread = self.__start_bg_threads()
         self.__bg_thread_exception: Optional[Exception] = None
 
+    async def __make_asyncio_locks(self) -> None:
+        """Create the locks in the context of the running event loop so that internal `asyncio.get_event_loop()` calls work."""
+        self.__active_requests_lock = asyncio.Lock()
+        self.__uuid_lookup_lock = asyncio.Lock()
+        self.__results_lock = asyncio.Lock()
+
+    async def __release_asyncio_lock(self, lock: asyncio.Lock) -> None:
+        """Release the lock in the context of the running event loop so that internal `asyncio.get_event_loop()` calls work."""
+        return lock.release()
+
     @property
     def number_errors(self) -> int:
         """Return the number of errors in the batch."""
@@ -292,16 +302,35 @@ def __batch_send(self) -> None:
                 self.__time_stamp_last_request = time.time()
                 self._batch_send = True
-                self.__active_requests_lock.acquire()
+                self.__loop.run_until_complete(self.__active_requests_lock.acquire)
                 self.__active_requests += 1
-                self.__active_requests_lock.release()
+                self.__loop.run_until_complete(
+                    self.__release_asyncio_lock, self.__active_requests_lock
+                )
+
+                start = time.time()
+                while (len_o := len(self.__batch_objects)) < self.__recommended_num_objects and (
+                    len_r := len(self.__batch_references)
+                ) < self.__recommended_num_refs:
+                    # wait for more objects to be added up to the recommended number
+                    time.sleep(0.01)
+                    if (
+                        self.__shut_background_thread_down is not None
+                        and self.__shut_background_thread_down.is_set()
+                    ):
+                        # shutdown was requested, exit the loop
+                        break
+                    if time.time() - start >= 1 and (
+                        len_o == len(self.__batch_objects) or len_r == len(self.__batch_references)
+                    ):
+                        # no new objects were added in the last second, exit the loop
+                        break
 
                 objs = self.__batch_objects.pop_items(self.__recommended_num_objects)
-                self.__uuid_lookup_lock.acquire()
                 refs = self.__batch_references.pop_items(
-                    self.__recommended_num_refs, uuid_lookup=self.__uuid_lookup
+                    self.__recommended_num_refs,
+                    uuid_lookup=self.__uuid_lookup,
                 )
-                self.__uuid_lookup_lock.release()
                 # do not block the thread - the results are written to a central (locked) list and we want to have multiple concurrent batch-requests
                 self.__loop.schedule(
                     self.__send_batch,
@@ -349,6 +378,7 @@ def batch_send_wrapper() -> None:
             try:
                 self.__batch_send()
             except Exception as e:
+                logger.error(e)
                 self.__bg_thread_exception = e
 
         demonBatchSend = threading.Thread(
@@ -357,6 +387,7 @@ def batch_send_wrapper() -> None:
             name="BgBatchScheduler",
         )
         demonBatchSend.start()
+        return demonBatchSend
 
     def __dynamic_batching(self) -> None:
@@ -459,10 +490,24 @@ async def __send_batch(
                 response_obj = await self.__batch_grpc.objects(
                     objects=objs, timeout=DEFAULT_REQUEST_TIMEOUT
                 )
+                if response_obj.has_errors:
+                    logger.error(
+                        {
+                            "message": f"Failed to send {len(response_obj.errors)} objects in a batch of {len(objs)}",
+                            "errors": {err.message for err in response_obj.errors.values()},
+                        }
+                    )
             except Exception as e:
                 errors_obj = {
-                    idx: ErrorObject(message=repr(e), object_=obj) for idx, obj in enumerate(objs)
+                    idx: ErrorObject(message=repr(e), object_=BatchObject._from_internal(obj))
+                    for idx, obj in enumerate(objs)
                 }
+                logger.error(
+                    {
+                        "message": f"Failed to send all objects in a batch of {len(objs)}",
+                        "error": repr(e),
+                    }
+                )
                 response_obj = BatchObjectReturn(
                     _all_responses=list(errors_obj.values()),
                     elapsed_seconds=time.time() - start,
@@ -509,7 +554,9 @@ async def __send_batch(
                     )
                 readd_objects = [
-                    err.object_ for i, err in response_obj.errors.items() if i in readded_objects
+                    err.object_._to_internal()
+                    for i, err in response_obj.errors.items()
+                    if i in readded_objects
                 ]
                 readded_uuids = {obj.uuid for obj in readd_objects}
@@ -541,8 +588,8 @@ async def __send_batch(
                     )
                 else:
                     # sleep a bit to recover from the rate limit in other cases
-                    time.sleep(2**highest_retry_count)
-            self.__uuid_lookup_lock.acquire()
+                    await asyncio.sleep(2**highest_retry_count)
+            await self.__uuid_lookup_lock.acquire()
             self.__uuid_lookup.difference_update(
                 obj.uuid for obj in objs if obj.uuid not in readded_uuids
             )
@@ -561,7 +608,7 @@ async def __send_batch(
                         "message": "There have been more than 30 failed object batches. Further errors will not be logged.",
                     }
                 )
-            self.__results_lock.acquire()
+            await self.__results_lock.acquire()
             self.__results_for_wrapper.results.objs += response_obj
             self.__results_for_wrapper.failed_objects.extend(response_obj.errors.values())
             self.__results_lock.release()
@@ -573,7 +620,9 @@ async def __send_batch(
                 response_ref = await self.__batch_rest.references(references=refs)
             except Exception as e:
                 errors_ref = {
-                    idx: ErrorReference(message=repr(e), reference=ref)
+                    idx: ErrorReference(
+                        message=repr(e), reference=BatchReference._from_internal(ref)
+                    )
                     for idx, ref in enumerate(refs)
                 }
                 response_ref = BatchReferenceReturn(
@@ -595,12 +644,12 @@ async def __send_batch(
                         "message": "There have been more than 30 failed reference batches. Further errors will not be logged.",
                     }
                 )
-            self.__results_lock.acquire()
+            await self.__results_lock.acquire()
             self.__results_for_wrapper.results.refs += response_ref
             self.__results_for_wrapper.failed_references.extend(response_ref.errors.values())
             self.__results_lock.release()
 
-        self.__active_requests_lock.acquire()
+        await self.__active_requests_lock.acquire()
         self.__active_requests -= 1
         self.__active_requests_lock.release()
@@ -641,9 +690,7 @@ def _add_object(
             )
         except ValidationError as e:
             raise WeaviateBatchValidationError(repr(e))
-        self.__uuid_lookup_lock.acquire()
         self.__uuid_lookup.add(str(batch_object.uuid))
-        self.__uuid_lookup_lock.release()
        self.__batch_objects.add(batch_object._to_internal())
 
         # block if queue gets too long or weaviate is overloaded - reading files is faster them sending them so we do
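Note on the lock changes above: `threading.Lock` becomes `asyncio.Lock`, and the locks are created via `run_until_complete` so that they are constructed inside the client's background event loop. A minimal sketch of that pattern with plain `asyncio` (the loop and helper names here are illustrative, not the client's `_EventLoop` internals):

    import asyncio
    import threading

    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    async def make_lock() -> asyncio.Lock:
        # constructed while the target loop runs, so the lock binds to that loop
        return asyncio.Lock()

    lock = asyncio.run_coroutine_threadsafe(make_lock(), loop).result()

    async def guarded() -> None:
        async with lock:  # acquired and released on the owning loop
            await asyncio.sleep(0)

    # synchronous code hands coroutines to the loop, as __batch_send does
    asyncio.run_coroutine_threadsafe(guarded(), loop).result()
    loop.call_soon_threadsafe(loop.stop)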
diff --git a/weaviate/collections/batch/grpc_batch_objects.py b/weaviate/collections/batch/grpc_batch_objects.py
index cf4b554bd..907135bdc 100644
--- a/weaviate/collections/batch/grpc_batch_objects.py
+++ b/weaviate/collections/batch/grpc_batch_objects.py
@@ -10,6 +10,7 @@
 from weaviate.collections.classes.batch import (
     ErrorObject,
     _BatchObject,
+    BatchObject,
     BatchObjectReturn,
 )
 from weaviate.collections.classes.config import ConsistencyLevel
@@ -117,7 +118,9 @@ async def objects(
         for idx, weav_obj in enumerate(weaviate_objs):
             obj = objects[idx]
             if idx in errors:
-                error = ErrorObject(errors[idx], obj, original_uuid=obj.uuid)
+                error = ErrorObject(
+                    errors[idx], BatchObject._from_internal(obj), original_uuid=obj.uuid
+                )
                 return_errors[obj.index] = error
                 all_responses[idx] = error
             else:
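With the conversion above, `ErrorObject` now carries the public `BatchObject` (via `_from_internal`) rather than the internal `_BatchObject`, so the failure lists expose the full model to callers. A hedged usage sketch, assuming a connected `client` whose batch has just finished:

    for err in client.batch.failed_objects:
        failed = err.object_  # a public BatchObject after this change
        print(err.message, failed.collection, failed.uuid, failed.retry_count)
        # failed.properties / failed.vector are intact for a manual re-add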
diff --git a/weaviate/collections/batch/rest.py b/weaviate/collections/batch/rest.py
index 1f69efbd7..051bc0db7 100644
--- a/weaviate/collections/batch/rest.py
+++ b/weaviate/collections/batch/rest.py
@@ -2,6 +2,7 @@
 
 from weaviate.collections.classes.batch import (
     ErrorReference,
+    BatchReference,
     _BatchReference,
     BatchReferenceReturn,
 )
@@ -45,7 +46,7 @@ async def references(self, references: List[_BatchReference]) -> BatchReferenceReturn:
         errors = {
             idx: ErrorReference(
                 message=entry["result"]["errors"]["error"][0]["message"],
-                reference=references[idx],
+                reference=BatchReference._from_internal(references[idx]),
             )
             for idx, entry in enumerate(payload)
             if entry["result"]["status"] == "FAILED"
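The shape the REST path expects from the batch-references response can be read off the dict comprehension above; an illustrative payload (values made up) and the error map it yields:

    payload = [
        {"result": {"status": "SUCCESS"}},
        {
            "result": {
                "status": "FAILED",
                "errors": {"error": [{"message": "property 'author' does not exist"}]},
            }
        },
    ]
    errors = {
        idx: entry["result"]["errors"]["error"][0]["message"]
        for idx, entry in enumerate(payload)
        if entry["result"]["status"] == "FAILED"
    }
    assert errors == {1: "property 'author' does not exist"}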
diff --git a/weaviate/collections/classes/batch.py b/weaviate/collections/classes/batch.py
index cc5e361f2..8f91518b0 100644
--- a/weaviate/collections/classes/batch.py
+++ b/weaviate/collections/classes/batch.py
@@ -31,7 +31,7 @@ class _BatchReference:
     to: str
     tenant: Optional[str]
     from_uuid: str
-    to_uuid: Optional[str] = None
+    to_uuid: Union[str, None]
 
 
 class BatchObject(BaseModel):
@@ -49,6 +49,7 @@ class BatchObject(BaseModel):
     vector: Optional[VECTORS] = Field(default=None)
     tenant: Optional[str] = Field(default=None)
     index: int
+    retry_count: int = 0
 
     def __init__(self, **data: Any) -> None:
         v = data.get("vector")
@@ -76,6 +77,19 @@ def _to_internal(self) -> _BatchObject:
             index=self.index,
         )
 
+    @classmethod
+    def _from_internal(cls, obj: _BatchObject) -> "BatchObject":
+        return BatchObject(
+            collection=obj.collection,
+            vector=obj.vector,
+            uuid=uuid_package.UUID(obj.uuid),
+            properties=obj.properties,
+            tenant=obj.tenant,
+            references=obj.references,
+            index=obj.index,
+            retry_count=obj.retry_count,
+        )
+
     @field_validator("collection")
     def _validate_collection(cls, v: str) -> str:
         return _capitalize_first_letter(v)
@@ -136,13 +150,32 @@ def _to_internal(self) -> _BatchReference:
             tenant=self.tenant,
         )
 
+    @classmethod
+    def _from_internal(cls, ref: _BatchReference) -> "BatchReference":
+        from_ = ref.from_.split("weaviate://")[1].split("/")
+        to = ref.to.split("weaviate://")[1].split("/")
+        if len(to) == 2:
+            to_object_collection = to[1]
+        elif len(to) == 1:
+            to_object_collection = None
+        else:
+            raise ValueError(f"Invalid reference 'to' value in _BatchReference object {ref}")
+        return BatchReference(
+            from_object_collection=from_[1],
+            from_object_uuid=ref.from_uuid,
+            from_property_name=from_[-1],
+            to_object_uuid=ref.to_uuid if ref.to_uuid is not None else uuid_package.UUID(to[-1]),
+            to_object_collection=to_object_collection,
+            tenant=ref.tenant,
+        )
+
 
 @dataclass
 class ErrorObject:
     """This class contains the error information for a single object in a batch operation."""
 
     message: str
-    object_: _BatchObject
+    object_: BatchObject
     original_uuid: Optional[UUID] = None
@@ -151,7 +184,7 @@ class ErrorReference:
     """This class contains the error information for a single reference in a batch operation."""
 
     message: str
-    reference: _BatchReference
+    reference: BatchReference
 
 
 @dataclass
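For reference, the splitting mechanics that `BatchReference._from_internal` relies on, with an illustrative beacon (the exact segment layout is an assumption read off the indices used above; a two-segment `to` beacon carries a target collection, a single-segment one does not):

    beacon = "weaviate://localhost/Author/2d4d3a85-1c2b-4b0e-9c33-6a7b8c9d0e1f/wroteBooks"
    from_ = beacon.split("weaviate://")[1].split("/")
    # -> ["localhost", "Author", "2d4d3a85-1c2b-4b0e-9c33-6a7b8c9d0e1f", "wroteBooks"]
    assert from_[1] == "Author"       # from_object_collection
    assert from_[-1] == "wroteBooks"  # from_property_name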