Merge pull request #1270 from weaviate/fix-deadlocks-in-batching
Replace `threading.Lock` with `asyncio.Lock` when batching to avoid deadlocks
tsmith023 authored Jan 10, 2025
2 parents 18838b1 + ce970b4 commit 676dba2
Showing 7 changed files with 124 additions and 35 deletions.
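
Why the deadlocks happened: the batching background thread acquired `threading.Lock`s synchronously while coroutines scheduled on the client's internal event loop needed the same locks, so a blocking `acquire()` could wait forever on a coroutine that the blocked loop could never run. Below is a minimal, self-contained sketch of the pattern this PR adopts instead. It is not the client's actual code; the loop setup and the names `make_locks` and `guarded_increment` are illustrative assumptions.

import asyncio
import threading

# A single background event loop owns every asyncio.Lock, mirroring the
# client's internal _EventLoop helper (illustrative setup, not the real one).
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

locks: dict[str, asyncio.Lock] = {}

async def make_locks() -> None:
    # Create locks on the running loop so asyncio's internal event-loop
    # lookup binds them to the right loop (required before Python 3.10).
    locks["active_requests"] = asyncio.Lock()

asyncio.run_coroutine_threadsafe(make_locks(), loop).result()

counter = 0

async def guarded_increment() -> None:
    global counter
    # Awaiting the lock suspends this coroutine instead of blocking the loop
    # thread, so other scheduled coroutines can still run and release the
    # lock; a synchronous threading.Lock.acquire() here could deadlock.
    async with locks["active_requests"]:
        counter += 1

# Synchronous threads funnel all lock-guarded work through the one loop.
futures = [asyncio.run_coroutine_threadsafe(guarded_increment(), loop) for _ in range(10)]
for f in futures:
    f.result()
assert counter == 10

The diff in `weaviate/collections/batch/base.py` applies this pattern: `__make_asyncio_locks` creates the three locks on the client's background loop, and synchronous call sites such as `__batch_send` go through `self.__loop.run_until_complete(...)` instead of calling `threading.Lock.acquire()` directly.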
3 changes: 2 additions & 1 deletion .gitignore
@@ -24,4 +24,5 @@ image.png
 scratch/
 
 *-test.sh
-*.hdf5
+*.hdf5
+*.jsonl
14 changes: 9 additions & 5 deletions integration/test_batch_v4.py
@@ -596,8 +596,8 @@ def batch_insert(batch: BatchClient) -> None:
     with concurrent.futures.ThreadPoolExecutor() as executor:
         with client.batch.dynamic() as batch:
             futures = [executor.submit(batch_insert, batch) for _ in range(nr_threads)]
-            for future in concurrent.futures.as_completed(futures):
-                future.result()
+    for future in concurrent.futures.as_completed(futures):
+        future.result()
     objs = client.collections.get(name).query.fetch_objects(limit=nr_objects * nr_threads).objects
     assert len(objs) == nr_objects * nr_threads

@@ -687,9 +687,13 @@ def test_batching_error_logs(
     for obj in [{"name": i} for i in range(100)]:
         batch.add_object(properties=obj, collection=name)
     assert (
-        "Failed to send 100 objects in a batch of 100. Please inspect client.batch.failed_objects or collection.batch.failed_objects for the failed objects."
-        in caplog.text
-    )
+        ("Failed to send" in caplog.text)
+        and ("objects in a batch of" in caplog.text)
+        and (
+            "Please inspect client.batch.failed_objects or collection.batch.failed_objects for the failed objects."
+            in caplog.text
+        )
+    )  # number of objects sent per batch is not fixed for less than 100 objects
 
 
 def test_references_with_to_uuids(client_factory: ClientFactory) -> None:
6 changes: 3 additions & 3 deletions profiling/test_sphere.py
@@ -9,7 +9,7 @@
 
 
 def test_sphere(collection_factory: CollectionFactory) -> None:
-    sphere_file = get_file_path("sphere.100k.jsonl")
+    sphere_file = get_file_path("sphere.1m.jsonl")
 
     collection = collection_factory(
         properties=[
@@ -26,7 +26,7 @@ def test_sphere(collection_factory: CollectionFactory) -> None:
     )
     start = time.time()
 
-    import_objects = 50000
+    import_objects = 1000000
     with collection.batch.dynamic() as batch:
         with open(sphere_file) as jsonl_file:
             for i, jsonl in enumerate(jsonl_file):
@@ -45,7 +45,7 @@ def test_sphere(collection_factory: CollectionFactory) -> None:
                     vector=json_parsed["vector"],
                 )
                 if i % 1000 == 0:
-                    print(f"Imported {i} objects")
+                    print(f"Imported {len(collection)} objects")
     assert len(collection.batch.failed_objects) == 0
     assert len(collection) == import_objects
     print(f"Imported {import_objects} objects in {time.time() - start}")
89 changes: 68 additions & 21 deletions weaviate/collections/batch/base.py
@@ -1,3 +1,4 @@
+import asyncio
 import math
 import threading
 import time
@@ -161,11 +162,12 @@ def __init__(
         batch_mode: _BatchMode,
         event_loop: _EventLoop,
         vectorizer_batching: bool,
-        objects_: Optional[ObjectsBatchRequest] = None,
+        objects: Optional[ObjectsBatchRequest] = None,
         references: Optional[ReferencesBatchRequest] = None,
     ) -> None:
-        self.__batch_objects = objects_ or ObjectsBatchRequest()
+        self.__batch_objects = objects or ObjectsBatchRequest()
         self.__batch_references = references or ReferencesBatchRequest()
+
         self.__connection = connection
         self.__consistency_level: Optional[ConsistencyLevel] = consistency_level
         self.__vectorizer_batching = vectorizer_batching
@@ -174,15 +176,12 @@ def __init__(
         self.__batch_rest = _BatchREST(connection, self.__consistency_level)
 
         # lookup table for objects that are currently being processed - is used to not send references from objects that have not been added yet
-        self.__uuid_lookup_lock = threading.Lock()
         self.__uuid_lookup: Set[str] = set()
 
         # we do not want that users can access the results directly as they are not thread-safe
         self.__results_for_wrapper_backup = results
         self.__results_for_wrapper = _BatchDataWrapper()
 
-        self.__results_lock = threading.Lock()
-
         self.__cluster = _ClusterBatch(self.__connection)
 
         self.__batching_mode: _BatchMode = batch_mode
@@ -221,7 +220,6 @@ def __init__(
         self.__recommended_num_refs: int = 50
 
         self.__active_requests = 0
-        self.__active_requests_lock = threading.Lock()
 
         # dynamic batching
         self.__time_last_scale_up: float = 0
@@ -233,9 +231,21 @@
         # do 62 secs to give us some buffer to the "per-minute" calculation
         self.__fix_rate_batching_base_time = 62
 
+        self.__loop.run_until_complete(self.__make_asyncio_locks)
+
         self.__bg_thread = self.__start_bg_threads()
         self.__bg_thread_exception: Optional[Exception] = None
 
+    async def __make_asyncio_locks(self) -> None:
+        """Create the locks in the context of the running event loop so that internal `asyncio.get_event_loop()` calls work."""
+        self.__active_requests_lock = asyncio.Lock()
+        self.__uuid_lookup_lock = asyncio.Lock()
+        self.__results_lock = asyncio.Lock()
+
+    async def __release_asyncio_lock(self, lock: asyncio.Lock) -> None:
+        """Release the lock in the context of the running event loop so that internal `asyncio.get_event_loop()` calls work."""
+        return lock.release()
+
     @property
     def number_errors(self) -> int:
         """Return the number of errors in the batch."""
@@ -292,16 +302,35 @@ def __batch_send(self) -> None:
             self.__time_stamp_last_request = time.time()
 
             self._batch_send = True
-            self.__active_requests_lock.acquire()
+            self.__loop.run_until_complete(self.__active_requests_lock.acquire)
             self.__active_requests += 1
-            self.__active_requests_lock.release()
+            self.__loop.run_until_complete(
+                self.__release_asyncio_lock, self.__active_requests_lock
+            )
 
+            start = time.time()
+            while (len_o := len(self.__batch_objects)) < self.__recommended_num_objects and (
+                len_r := len(self.__batch_references)
+            ) < self.__recommended_num_refs:
+                # wait for more objects to be added up to the recommended number
+                time.sleep(0.01)
+                if (
+                    self.__shut_background_thread_down is not None
+                    and self.__shut_background_thread_down.is_set()
+                ):
+                    # shutdown was requested, exit the loop
+                    break
+                if time.time() - start >= 1 and (
+                    len_o == len(self.__batch_objects) or len_r == len(self.__batch_references)
+                ):
+                    # no new objects were added in the last second, exit the loop
+                    break
+
             objs = self.__batch_objects.pop_items(self.__recommended_num_objects)
-            self.__uuid_lookup_lock.acquire()
             refs = self.__batch_references.pop_items(
-                self.__recommended_num_refs, uuid_lookup=self.__uuid_lookup
+                self.__recommended_num_refs,
+                uuid_lookup=self.__uuid_lookup,
             )
-            self.__uuid_lookup_lock.release()
             # do not block the thread - the results are written to a central (locked) list and we want to have multiple concurrent batch-requests
             self.__loop.schedule(
                 self.__send_batch,
@@ -349,6 +378,7 @@ def batch_send_wrapper() -> None:
             try:
                 self.__batch_send()
             except Exception as e:
+                logger.error(e)
                 self.__bg_thread_exception = e
 
         demonBatchSend = threading.Thread(
@@ -357,6 +387,7 @@ def batch_send_wrapper() -> None:
             name="BgBatchScheduler",
         )
         demonBatchSend.start()
+
         return demonBatchSend

def __dynamic_batching(self) -> None:
Expand Down Expand Up @@ -459,10 +490,24 @@ async def __send_batch(
response_obj = await self.__batch_grpc.objects(
objects=objs, timeout=DEFAULT_REQUEST_TIMEOUT
)
if response_obj.has_errors:
logger.error(
{
"message": f"Failed to send {len(response_obj.errors)} in a batch of {len(objs)}",
"errors": {err.message for err in response_obj.errors.values()},
}
)
except Exception as e:
errors_obj = {
idx: ErrorObject(message=repr(e), object_=obj) for idx, obj in enumerate(objs)
idx: ErrorObject(message=repr(e), object_=BatchObject._from_internal(obj))
for idx, obj in enumerate(objs)
}
logger.error(
{
"message": f"Failed to send all objects in a batch of {len(objs)}",
"error": repr(e),
}
)
response_obj = BatchObjectReturn(
_all_responses=list(errors_obj.values()),
elapsed_seconds=time.time() - start,
@@ -509,7 +554,9 @@ async def __send_batch(
                 )
 
                 readd_objects = [
-                    err.object_ for i, err in response_obj.errors.items() if i in readded_objects
+                    err.object_._to_internal()
+                    for i, err in response_obj.errors.items()
+                    if i in readded_objects
                 ]
                 readded_uuids = {obj.uuid for obj in readd_objects}

@@ -541,8 +588,8 @@
                     )
                 else:
                     # sleep a bit to recover from the rate limit in other cases
-                    time.sleep(2**highest_retry_count)
-            self.__uuid_lookup_lock.acquire()
+                    await asyncio.sleep(2**highest_retry_count)
+            await self.__uuid_lookup_lock.acquire()
             self.__uuid_lookup.difference_update(
                 obj.uuid for obj in objs if obj.uuid not in readded_uuids
             )
@@ -561,7 +608,7 @@
                         "message": "There have been more than 30 failed object batches. Further errors will not be logged.",
                     }
                 )
-            self.__results_lock.acquire()
+            await self.__results_lock.acquire()
             self.__results_for_wrapper.results.objs += response_obj
             self.__results_for_wrapper.failed_objects.extend(response_obj.errors.values())
             self.__results_lock.release()
@@ -573,7 +620,9 @@
                 response_ref = await self.__batch_rest.references(references=refs)
             except Exception as e:
                 errors_ref = {
-                    idx: ErrorReference(message=repr(e), reference=ref)
+                    idx: ErrorReference(
+                        message=repr(e), reference=BatchReference._from_internal(ref)
+                    )
                     for idx, ref in enumerate(refs)
                 }
                 response_ref = BatchReferenceReturn(
@@ -595,12 +644,12 @@
                         "message": "There have been more than 30 failed reference batches. Further errors will not be logged.",
                     }
                 )
-            self.__results_lock.acquire()
+            await self.__results_lock.acquire()
             self.__results_for_wrapper.results.refs += response_ref
             self.__results_for_wrapper.failed_references.extend(response_ref.errors.values())
             self.__results_lock.release()
 
-        self.__active_requests_lock.acquire()
+        await self.__active_requests_lock.acquire()
         self.__active_requests -= 1
         self.__active_requests_lock.release()

@@ -641,9 +690,7 @@ def _add_object(
             )
         except ValidationError as e:
             raise WeaviateBatchValidationError(repr(e))
-        self.__uuid_lookup_lock.acquire()
         self.__uuid_lookup.add(str(batch_object.uuid))
-        self.__uuid_lookup_lock.release()
         self.__batch_objects.add(batch_object._to_internal())
 
         # block if queue gets too long or weaviate is overloaded - reading files is faster them sending them so we do
5 changes: 4 additions & 1 deletion weaviate/collections/batch/grpc_batch_objects.py
@@ -10,6 +10,7 @@
 from weaviate.collections.classes.batch import (
     ErrorObject,
     _BatchObject,
+    BatchObject,
     BatchObjectReturn,
 )
 from weaviate.collections.classes.config import ConsistencyLevel
@@ -117,7 +118,9 @@ async def objects(
         for idx, weav_obj in enumerate(weaviate_objs):
             obj = objects[idx]
             if idx in errors:
-                error = ErrorObject(errors[idx], obj, original_uuid=obj.uuid)
+                error = ErrorObject(
+                    errors[idx], BatchObject._from_internal(obj), original_uuid=obj.uuid
+                )
                 return_errors[obj.index] = error
                 all_responses[idx] = error
             else:
3 changes: 2 additions & 1 deletion weaviate/collections/batch/rest.py
@@ -2,6 +2,7 @@
 
 from weaviate.collections.classes.batch import (
     ErrorReference,
+    BatchReference,
     _BatchReference,
     BatchReferenceReturn,
 )
@@ -45,7 +46,7 @@ async def references(self, references: List[_BatchReference]) -> BatchReferenceReturn:
         errors = {
             idx: ErrorReference(
                 message=entry["result"]["errors"]["error"][0]["message"],
-                reference=references[idx],
+                reference=BatchReference._from_internal(references[idx]),
             )
             for idx, entry in enumerate(payload)
             if entry["result"]["status"] == "FAILED"
39 changes: 36 additions & 3 deletions weaviate/collections/classes/batch.py
@@ -31,7 +31,7 @@ class _BatchReference:
     to: str
     tenant: Optional[str]
     from_uuid: str
-    to_uuid: Optional[str] = None
+    to_uuid: Union[str, None]
 
 
 class BatchObject(BaseModel):
@@ -49,6 +49,7 @@ class BatchObject(BaseModel):
     vector: Optional[VECTORS] = Field(default=None)
     tenant: Optional[str] = Field(default=None)
     index: int
+    retry_count: int = 0
 
     def __init__(self, **data: Any) -> None:
         v = data.get("vector")
@@ -76,6 +77,19 @@ def _to_internal(self) -> _BatchObject:
             index=self.index,
         )
 
+    @classmethod
+    def _from_internal(cls, obj: _BatchObject) -> "BatchObject":
+        return BatchObject(
+            collection=obj.collection,
+            vector=obj.vector,
+            uuid=uuid_package.UUID(obj.uuid),
+            properties=obj.properties,
+            tenant=obj.tenant,
+            references=obj.references,
+            index=obj.index,
+            retry_count=obj.retry_count,
+        )
+
     @field_validator("collection")
     def _validate_collection(cls, v: str) -> str:
         return _capitalize_first_letter(v)
@@ -136,13 +150,32 @@ def _to_internal(self) -> _BatchReference:
             tenant=self.tenant,
         )
 
+    @classmethod
+    def _from_internal(cls, ref: _BatchReference) -> "BatchReference":
+        from_ = ref.from_.split("weaviate://")[1].split("/")
+        to = ref.to.split("weaviate://")[1].split("/")
+        if len(to) == 2:
+            to_object_collection = to[1]
+        elif len(to) == 1:
+            to_object_collection = None
+        else:
+            raise ValueError(f"Invalid reference 'to' value in _BatchReference object {ref}")
+        return BatchReference(
+            from_object_collection=from_[1],
+            from_object_uuid=ref.from_uuid,
+            from_property_name=ref.from_[-1],
+            to_object_uuid=ref.to_uuid if ref.to_uuid is not None else uuid_package.UUID(to[-1]),
+            to_object_collection=to_object_collection,
+            tenant=ref.tenant,
+        )
+
 
 @dataclass
 class ErrorObject:
     """This class contains the error information for a single object in a batch operation."""
 
     message: str
-    object_: _BatchObject
+    object_: BatchObject
     original_uuid: Optional[UUID] = None


@@ -151,7 +184,7 @@ class ErrorReference:
     """This class contains the error information for a single reference in a batch operation."""
 
     message: str
-    reference: _BatchReference
+    reference: BatchReference
 
 
 @dataclass
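
Note for consumers of `failed_objects` and `failed_references`: this commit changes `ErrorObject.object_` and `ErrorReference.reference` to hold the public `BatchObject` and `BatchReference` models instead of the internal `_BatchObject` and `_BatchReference` types. A rough retry sketch under that assumption; the connection helper, the `Article` collection, and the `rows` payload are illustrative, not part of the commit:

import weaviate

client = weaviate.connect_to_local()
rows = [{"title": f"doc-{i}"} for i in range(3)]  # illustrative payload

with client.batch.dynamic() as batch:
    for row in rows:
        batch.add_object(properties=row, collection="Article")

# failed_objects entries now expose the public BatchObject model, so a
# retry can re-insert them directly instead of reaching into internals.
for err in client.batch.failed_objects:
    obj = err.object_
    client.collections.get(obj.collection).data.insert(
        properties=obj.properties,
        uuid=obj.uuid,
        vector=obj.vector,
    )

client.close()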
