|
| 1 | +import concurrent.futures |
1 | 2 | import uuid
|
2 | 3 | from dataclasses import dataclass
|
3 | 4 | from typing import Generator, List, Optional, Protocol, Tuple, Callable
|
|
6 | 7 | from _pytest.fixtures import SubRequest
|
7 | 8 |
|
8 | 9 | import weaviate
|
| 10 | +from weaviate import BatchClient, ClientBatchingContextManager |
9 | 11 | from integration.conftest import _sanitize_collection_name
|
10 | 12 | from weaviate.collections.classes.batch import Shard
|
11 | 13 | from weaviate.collections.classes.config import (
|
|
20 | 22 | ReferenceToMulti,
|
21 | 23 | )
|
22 | 24 | from weaviate.collections.classes.tenants import Tenant
|
23 |
| -from weaviate.outputs.batch import ClientBatchingContextManager |
24 | 25 | from weaviate.types import UUID, VECTORS
|
25 | 26 |
|
26 | 27 | UUID1 = uuid.UUID("806827e0-2b31-43ca-9269-24fa95a221f9")
|
@@ -575,6 +576,29 @@ def test_add_one_object_and_a_self_reference(
|
575 | 576 | assert obj.references["test"].objects[0].uuid == uuid
|
576 | 577 |
|
577 | 578 |
|
| 579 | +def test_multi_threaded_batching( |
| 580 | + client_factory: ClientFactory, |
| 581 | +) -> None: |
| 582 | + client, name = client_factory() |
| 583 | + nr_objects = 1000 |
| 584 | + nr_threads = 10 |
| 585 | + |
| 586 | + def batch_insert(batch: BatchClient) -> None: |
| 587 | + for i in range(nr_objects): |
| 588 | + batch.add_object( |
| 589 | + collection=name, |
| 590 | + properties={"name": "test" + str(i)}, |
| 591 | + ) |
| 592 | + |
| 593 | + with concurrent.futures.ThreadPoolExecutor() as executor: |
| 594 | + with client.batch.dynamic() as batch: |
| 595 | + futures = [executor.submit(batch_insert, batch) for _ in range(nr_threads)] |
| 596 | + for future in concurrent.futures.as_completed(futures): |
| 597 | + future.result() |
| 598 | + objs = client.collections.get(name).query.fetch_objects(limit=nr_objects * nr_threads).objects |
| 599 | + assert len(objs) == nr_objects * nr_threads |
| 600 | + |
| 601 | + |
578 | 602 | def test_error_reset(client_factory: ClientFactory) -> None:
|
579 | 603 | client, name = client_factory()
|
580 | 604 | with client.batch.dynamic() as batch:
|
|
0 commit comments