Skip to content

Commit f4b89e6

Browse files
committed
Add more batching tests, export wrapping types from more logical locations
1 parent 975ac03 commit f4b89e6

File tree

6 files changed

+34
-8
lines changed

6 files changed

+34
-8
lines changed

integration/test_batch_v4.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import concurrent.futures
12
import uuid
23
from dataclasses import dataclass
34
from typing import Generator, List, Optional, Protocol, Tuple, Callable
@@ -6,6 +7,7 @@
67
from _pytest.fixtures import SubRequest
78

89
import weaviate
10+
from weaviate import BatchClient, ClientBatchingContextManager
911
from integration.conftest import _sanitize_collection_name
1012
from weaviate.collections.classes.batch import Shard
1113
from weaviate.collections.classes.config import (
@@ -20,7 +22,6 @@
2022
ReferenceToMulti,
2123
)
2224
from weaviate.collections.classes.tenants import Tenant
23-
from weaviate.outputs.batch import ClientBatchingContextManager
2425
from weaviate.types import UUID, VECTORS
2526

2627
UUID1 = uuid.UUID("806827e0-2b31-43ca-9269-24fa95a221f9")
@@ -575,6 +576,29 @@ def test_add_one_object_and_a_self_reference(
575576
assert obj.references["test"].objects[0].uuid == uuid
576577

577578

579+
def test_multi_threaded_batching(
580+
client_factory: ClientFactory,
581+
) -> None:
582+
client, name = client_factory()
583+
nr_objects = 1000
584+
nr_threads = 10
585+
586+
def batch_insert(batch: BatchClient) -> None:
587+
for i in range(nr_objects):
588+
batch.add_object(
589+
collection=name,
590+
properties={"name": "test" + str(i)},
591+
)
592+
593+
with concurrent.futures.ThreadPoolExecutor() as executor:
594+
with client.batch.dynamic() as batch:
595+
futures = [executor.submit(batch_insert, batch) for _ in range(nr_threads)]
596+
for future in concurrent.futures.as_completed(futures):
597+
future.result()
598+
objs = client.collections.get(name).query.fetch_objects(limit=nr_objects * nr_threads).objects
599+
assert len(objs) == nr_objects * nr_threads
600+
601+
578602
def test_error_reset(client_factory: ClientFactory) -> None:
579603
client, name = client_factory()
580604
with client.batch.dynamic() as batch:

weaviate/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
__version__ = "unknown version"
1313

1414
from .client import Client, WeaviateAsyncClient, WeaviateClient
15+
from .collections.batch.client import BatchClient, ClientBatchingContextManager
1516
from .connect.helpers import (
1617
connect_to_custom,
1718
connect_to_embedded,
@@ -45,6 +46,8 @@
4546
from .warnings import _Warnings
4647

4748
__all__ = [
49+
"BatchClient",
50+
"ClientBatchingContextManager",
4851
"Client",
4952
"WeaviateClient",
5053
"WeaviateAsyncClient",

weaviate/collections/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from weaviate.collections.batch.collection import BatchCollection, CollectionBatchingContextManager
12
from weaviate.collections.collection import Collection, CollectionAsync
23

3-
__all__ = ["Collection", "CollectionAsync"]
4+
__all__ = ["BatchCollection", "Collection", "CollectionAsync", "CollectionBatchingContextManager"]

weaviate/collections/batch/client.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def add_reference(
115115
)
116116

117117

118-
ClientBatchingContextManager = _ContextManagerWrapper[_BatchClient]
118+
BatchClient = _BatchClient
119+
ClientBatchingContextManager = _ContextManagerWrapper[BatchClient]
119120

120121

121122
class _BatchClientWrapper(_BatchWrapper):

weaviate/collections/batch/collection.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ def add_reference(
114114
)
115115

116116

117-
CollectionBatchingContextManager = _ContextManagerWrapper[_BatchCollection[Properties]]
117+
BatchCollection = _BatchCollection[Properties]
118+
CollectionBatchingContextManager = _ContextManagerWrapper[BatchCollection[Properties]]
118119

119120

120121
class _BatchCollectionWrapper(Generic[Properties], _BatchWrapper):

weaviate/outputs/batch.py

-4
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,11 @@
55
ErrorObject,
66
ErrorReference,
77
)
8-
from weaviate.collections.batch.client import ClientBatchingContextManager
9-
from weaviate.collections.batch.collection import CollectionBatchingContextManager
108

119
__all__ = [
1210
"BatchObjectReturn",
1311
"BatchReferenceReturn",
1412
"BatchResult",
1513
"ErrorObject",
1614
"ErrorReference",
17-
"ClientBatchingContextManager",
18-
"CollectionBatchingContextManager",
1915
]

0 commit comments

Comments
 (0)