diff --git a/dspy/utils/parallelizer.py b/dspy/utils/parallelizer.py index 5b5544d3d3..1d76368e61 100644 --- a/dspy/utils/parallelizer.py +++ b/dspy/utils/parallelizer.py @@ -156,6 +156,7 @@ def all_done(): index, outcome = f.result() except Exception: pass + else: if outcome != job_cancelled and results[index] is None: # Check if this is an exception diff --git a/tests/utils/test_parallelizer.py b/tests/utils/test_parallelizer.py index 128614ffc8..82989c9d31 100644 --- a/tests/utils/test_parallelizer.py +++ b/tests/utils/test_parallelizer.py @@ -1,8 +1,10 @@ import time import pytest +import threading from dspy.utils.parallelizer import ParallelExecutor +from dspy.dsp.utils.settings import thread_local_overrides def test_worker_threads_independence(): @@ -83,3 +85,77 @@ def task(item): assert str(executor.exceptions_map[2]) == "test error for 3" assert isinstance(executor.exceptions_map[4], RuntimeError) assert str(executor.exceptions_map[4]) == "test error for 5" + + +def test_thread_local_overrides_with_usage_tracker(): + + class MockUsageTracker: + def __init__(self): + self.tracked_items = [] + + def track(self, value): + self.tracked_items.append(value) + + parent_thread_usage_tracker = MockUsageTracker() + parent_thread_overrides = {"usage_tracker": parent_thread_usage_tracker, "some_setting": "parent_value"} + + override_token = thread_local_overrides.set(parent_thread_overrides) + + try: + worker_thread_ids = set() + worker_thread_ids_lock = threading.Lock() + + # Track all usage tracker instances seen (may be same instance reused across tasks in same thread) + all_usage_tracker_instances = [] + usage_tracker_instances_lock = threading.Lock() + + def task(item): + + current_thread_id = threading.get_ident() + + with worker_thread_ids_lock: + worker_thread_ids.add(current_thread_id) + + current_thread_overrides = thread_local_overrides.get() + + # Verify overrides were copied to worker thread + assert current_thread_overrides.get("some_setting") == "parent_value" + + worker_thread_usage_tracker = current_thread_overrides.get("usage_tracker") + + assert worker_thread_usage_tracker is not None + assert isinstance(worker_thread_usage_tracker, MockUsageTracker) + + # Collect all tracker instances (same thread will get same instance) + with usage_tracker_instances_lock: + if worker_thread_usage_tracker not in all_usage_tracker_instances: + all_usage_tracker_instances.append(worker_thread_usage_tracker) + + worker_thread_usage_tracker.track(item) + + return item * 2 + + input_data = [1, 2, 3, 4, 5] + executor = ParallelExecutor(num_threads=3) + results = executor.execute(task, input_data) + + assert results == [2, 4, 6, 8, 10] + + # Verify that worker threads got their own deep copied usage trackers + # Even if only one thread was used, it should have a different instance than parent + assert len(all_usage_tracker_instances) >= 1, "At least one worker usage tracker should exist" + + for worker_usage_tracker in all_usage_tracker_instances: + assert worker_usage_tracker is not parent_thread_usage_tracker, ( + "Worker thread usage tracker should be deep copy, not same instance as parent" + ) + + assert len(parent_thread_usage_tracker.tracked_items) == 0, ( + "Parent usage tracker should not be modified by worker threads" + ) + + total_tracked_items_count = sum(len(tracker.tracked_items) for tracker in all_usage_tracker_instances) + assert total_tracked_items_count == len(input_data), "All items should be tracked across worker threads" + + finally: + thread_local_overrides.reset(override_token)