diff --git a/tests/test_collections_sharded.py b/tests/test_collections_sharded.py index 90dc678..462d09f 100644 --- a/tests/test_collections_sharded.py +++ b/tests/test_collections_sharded.py @@ -1,4 +1,6 @@ # pylint: disable=missing-docstring,redefined-outer-name +import threading +import time from datetime import datetime import pytest @@ -85,3 +87,28 @@ def test_create_collection_with_collation_with_shard_key_index_prefix( ) t.compare_all_sharded() + +@pytest.mark.parametrize("phase", [Runner.Phase.APPLY, Runner.Phase.CLONE]) +def test_shard_key_update_duplicate_key_error(t: Testing, phase: Runner.Phase): + """ + Test to reproduce pcsm duplicate key error when handling shard key updates + """ + db_name = "test_db" + collection_name = "test_collection" + coll = t.source[db_name][collection_name] + t.source.admin.command("shardCollection", f"{db_name}.{collection_name}", key={"key_id": 1}) + coll.insert_one({"key_id": 0, "name": "item_0", "value": "value_0"}) + def perform_shard_key_updates(): + num_updates = 20 + for i in range(1, num_updates + 1): + key_id = 100 + i + new_key_id = 5000 + i + coll.insert_one({"key_id": key_id, "name": f"test_doc_{i}", "value": f"value_{key_id}"}) + coll.update_one({"key_id": key_id}, {"$set": {"key_id": new_key_id, "shard_key_updated": True}}) + time.sleep(0.05) + update_thread = threading.Thread(target=perform_shard_key_updates) + update_thread.start() + with t.run(phase): + update_thread.join() + time.sleep(5) + t.compare_all_sharded()