
Commit e8025c7

Test Worker (#83) (#84)
* Worker (#83)
* base llm
* ;)
* base llm
* llm
* more ideas around queues
* examples
* cleanup
* napper
* querent workers
* update the querent core to handle shutdown signals
* cleanup
* worker
* handle signaling
* cleanup
* handle signaling
* some more editions
* querent working test
1 parent 58d785b commit e8025c7

File tree

8 files changed: +141 -110 lines

querent/llm/base_llm.py (+23 -6)
@@ -1,15 +1,16 @@
 from abc import ABC, abstractmethod
 import asyncio
-from typing import Any
+import json
+from typing import Any, Type
 from querent.common.types.ingested_tokens import IngestedTokens
 from querent.common.types.querent_queue import QuerentQueue


 class BaseLLM(ABC):
     def __init__(
         self,
-        input_queue: QuerentQueue[IngestedTokens],
-        output_queue: QuerentQueue[Any],
+        input_queue: QuerentQueue,
+        output_queue: QuerentQueue,
         num_workers: int = 1,
     ):
         self.input_queue = input_queue
@@ -33,9 +34,24 @@ async def worker(self):
             while True:
                 data = await self.input_queue.get()
                 if data is None:
-                    # Sentinel value to stop the worker
                     break
-                result = await self.process_tokens(data)
+                if isinstance(data, IngestedTokens):
+                    result = await self.process_tokens(data)
+                elif isinstance(data, str):
+                    ingested_token_from_str = IngestedTokens(
+                        file="", data=[data], error=None
+                    )
+                    result = await self.process_tokens(ingested_token_from_str)
+                elif isinstance(data, (list, tuple)):
+                    tokens_from_list = json.dumps(data)
+                    ingested_token_from_list = IngestedTokens(
+                        file="", data=[tokens_from_list], error=None
+                    )
+                    result = await self.process_tokens(ingested_token_from_list)
+                else:
+                    raise Exception(
+                        f"Invalid data type {type(data)} for {self.__class__.__name__}"
+                    )
                 await self.output_queue.put(result)
                 self.input_queue.task_done()
         except asyncio.CancelledError:
@@ -53,11 +69,12 @@ async def stop_workers(self):
             for _ in range(self.num_workers):
                 await self.input_queue.put(None)
             # Wait for the workers to finish processing
-            await asyncio.gather(*self.workers)
+            await asyncio.gather(*self.workers)  # Await the workers here
         except asyncio.CancelledError:
             pass
         except Exception as e:
             print(f"Stop workers error: {e}")
         finally:
             # Close the output queue
+            await self.input_queue.close()
             await self.output_queue.close()
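
The worker loop above is easier to see in isolation. Below is a minimal, runnable sketch of the same sentinel-terminated queue-worker pattern, with a plain asyncio.Queue standing in for QuerentQueue and the type dispatch reduced to the list-normalization case. The names here are illustrative, not the real querent API.

# Sketch of BaseLLM's worker pattern: drain a queue until a None sentinel,
# normalizing list/tuple payloads to JSON before "processing" them.
import asyncio
import json


async def worker(input_queue: asyncio.Queue, output_queue: asyncio.Queue) -> None:
    while True:
        data = await input_queue.get()
        if data is None:  # sentinel value: stop this worker
            input_queue.task_done()
            break
        if isinstance(data, (list, tuple)):  # note: a tuple, not a list literal
            data = json.dumps(data)
        await output_queue.put(f"Processed: {data}")
        input_queue.task_done()


async def main() -> None:
    inq: asyncio.Queue = asyncio.Queue()
    outq: asyncio.Queue = asyncio.Queue()
    for item in ["Data 1", ["a", "b"], None]:  # None terminates the worker
        await inq.put(item)
    await worker(inq, outq)
    while not outq.empty():
        print(outq.get_nowait())


asyncio.run(main())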

querent/llm/transformers/gpt2_llm_v1.py (+2 -2)
@@ -7,8 +7,8 @@
 class GPT2LLM(BaseLLM):
     def __init__(
         self,
-        input_queue: QuerentQueue[IngestedTokens],
-        output_queue: QuerentQueue[IngestedTokens],
+        input_queue: QuerentQueue,
+        output_queue: QuerentQueue,
         model_name="gpt2",
         num_workers=1,
     ):

querent/napper/auto_scaler.py (+37 -17)
@@ -1,3 +1,4 @@
+import asyncio
 import logging
 from typing import List

@@ -7,37 +8,45 @@

 class AutoScaler:
     def __init__(
-        self, resource_manager: ResourceManager, querenters: List[BaseLLM], threshold=10
+        self,
+        resource_manager: ResourceManager,
+        querenters: List[BaseLLM],
+        threshold: int = 10,
     ):
         self.resource_manager = resource_manager
         self.querenters = querenters
         self.threshold = threshold
         self.logger = logging.getLogger("AutoScaler")
+        self.querent_termination_event = resource_manager.querent_termination_event
+        self.worker_tasks: List[asyncio.Task] = []  # Store the worker tasks

-    async def scale_querenters(self, total_requested_workers):
+    async def scale_querenters(self, total_requested_workers: int):
         current_total_workers = sum(
             querenter.num_workers for querenter in self.querenters
         )

-        if total_requested_workers > current_total_workers:
+        if total_requested_workers <= current_total_workers:
             # Scale up querenter workers
+            self.worker_tasks = []
             for querenter in self.querenters:
                 num_workers_to_scale = querenter.num_workers
-                await querenter.start_workers(num_workers_to_scale)
-
-        elif total_requested_workers < current_total_workers:
-            # Scale down querenter workers
-            for querenter in self.querenters:
-                num_workers_to_scale = querenter.num_workers
-                await querenter.stop_workers(num_workers_to_scale)
-
-        self.logger.info(
-            f"Scaled querenter workers to {total_requested_workers} workers in total"
-        )
+                workers = await querenter.start_workers(num_workers_to_scale)
+                # Create tasks for the workers and store them
+                worker_tasks = [asyncio.create_task(worker) for worker in workers]
+                self.worker_tasks.extend(
+                    worker_tasks
+                )  # Extend the list of worker tasks
+                self.logger.info(
+                    f"Started {len(worker_tasks)} workers for {querenter.__class__.__name__}"
+                )
+        else:
+            raise Exception("Total requested workers exceed the current total workers.")

-    async def run(self):
+    async def start(self):
         try:
-            while True:
+            while (
+                not self.querent_termination_event.is_set()
+            ):  # Check termination_event
                 # Calculate the total requested workers for all querenters
                 total_requested_workers = sum(
                     querenter.num_workers for querenter in self.querenters
@@ -53,11 +62,22 @@ async def run(self):
                     "Total requested workers exceed the maximum allowed workers."
                 )

-                # Scale the number of querenter workers
+                # Scale querenter workers
                 await self.scale_querenters(total_requested_workers)

+                # Wait for a while before checking again (adjust this as needed)
+                await asyncio.sleep(1)
+
+                # Check if all worker tasks have completed
+                if all(task.done() for task in self.worker_tasks):
+                    self.querent_termination_event.set()  # Set termination event
+
+        except asyncio.CancelledError:
+            pass
         except Exception as e:
             self.logger.error(f"An error occurred during AutoScaler execution: {e}")
+        finally:
+            self.logger.info("AutoScaler stopped")

     async def stop(self):
         self.logger.info("Stopping AutoScaler")

querent/napper/querent.py (+25 -22)
@@ -1,11 +1,10 @@
 import asyncio
 import logging
-from typing import List
-from querent.common.types.querent_queue import QuerentQueue
+import signal
+from typing import List, Awaitable
 from querent.llm.base_llm import BaseLLM
 from querent.napper.resource_manager import ResourceManager
 from querent.napper.auto_scaler import AutoScaler
-from signaling import SignalHandler  # Import the SignalHandler class from signaling.py

 # Set up logging
 logging.basicConfig(
@@ -31,27 +30,27 @@ def __init__(
             self.resource_manager, querenters, threshold=self.auto_scale_threshold
         )

-        # Create an instance of SignalHandler and pass the Querent instance
-        self.signal_handler = SignalHandler(self)
+        # Create an event to handle termination requests
+        self.querent_termination_event = resource_manager.querent_termination_event

     async def start(self):
         try:
             logger.info("Starting Querent")

             # Start the auto-scaler
-            asyncio.create_task(self.auto_scaler.run())
+            auto_scale_task = asyncio.create_task(self.auto_scaler.start())

             # Start handling signals
-            asyncio.create_task(self.signal_handler.handle_signals())
+            self.setup_signal_handlers()
+
+            # Start the tasks above and wait for them to finish
+            await asyncio.gather(auto_scale_task, self.wait_for_termination())

         except Exception as e:
             logger.error(f"An error occurred during Querent execution: {e}")
             await self.graceful_shutdown()
         finally:
-            # Stop the workers
-            await asyncio.gather(
-                *(querenter.stop_workers() for querenter in self.querenters)
-            )
+            await self.graceful_shutdown()
             logger.info("Querent stopped")

     async def graceful_shutdown(self):
@@ -60,17 +59,21 @@ async def graceful_shutdown(self):
         # Stop the auto-scaler and querenters gracefully
         await self.auto_scaler.stop()

-        # Stop the workers
-        await asyncio.gather(
-            *(querenter.stop_workers() for querenter in self.querenters)
-        )
-
         logger.info("Querent stopped gracefully")

-    async def handle_shutdown(self):
+    def setup_signal_handlers(self):
+        for sig in [signal.SIGINT, signal.SIGTERM]:
+            loop = asyncio.get_event_loop()
+            loop.add_signal_handler(sig, self.handle_signal)
+
+    def handle_signal(self):
         try:
-            # Wait for a KeyboardInterrupt (Ctrl+C) or SIGTERM to initiate graceful shutdown
-            await asyncio.Event().wait()
-        except (KeyboardInterrupt, SystemExit):
-            logger.info("Received shutdown signal (Ctrl+C or SIGTERM)")
-            await self.graceful_shutdown()
+            print("Received shutdown signal. Initiating graceful shutdown...")
+            # Schedule graceful shutdown on the already-running loop
+            asyncio.create_task(self.graceful_shutdown())
+        except Exception as e:
+            print(f"Error during graceful shutdown: {str(e)}")
+
+    async def wait_for_termination(self) -> Awaitable[None]:
+        # Wait for the termination event to be set, indicating graceful shutdown
+        await self.querent_termination_event.wait()
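
This diff swaps the deleted SignalHandler class for loop.add_signal_handler, which registers a synchronous callback on the running loop (Unix-only). Because the callback cannot await, the shutdown coroutine has to be scheduled with asyncio.create_task rather than run with asyncio.run, which raises inside an already-running loop. A self-contained sketch of that pattern, with illustrative names:

# Register SIGINT/SIGTERM handlers that schedule an async shutdown, then let
# the main coroutine finish by waiting on an event, as Querent.start() now does.
import asyncio
import signal


async def main() -> None:
    stop_event = asyncio.Event()

    async def graceful_shutdown() -> None:
        print("shutting down gracefully...")
        stop_event.set()

    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        # The handler runs synchronously, so it schedules the coroutine
        # instead of awaiting it.
        loop.add_signal_handler(sig, lambda: asyncio.create_task(graceful_shutdown()))

    print("running; press Ctrl+C to stop")
    await stop_event.wait()


asyncio.run(main())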

querent/napper/resource_manager.py (+2)
@@ -1,10 +1,12 @@
+import asyncio
 import logging


 class ResourceManager:
     def __init__(self, max_allowed_workers=100):
         self.max_allowed_workers = max_allowed_workers
         self.min_allowed_workers = 1
+        self.querent_termination_event = asyncio.Event()
         self.logger = logging.getLogger("ResourceManager")

     async def get_max_allowed_workers(self):
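
Since the asyncio.Event is created once in ResourceManager.__init__ and both Querent and AutoScaler read it off the same ResourceManager instance, all components share a single termination flag. One caveat, stated as a general asyncio fact rather than anything this commit addresses: on Python versions before 3.10, asyncio primitives bind to the event loop current at construction time, so the ResourceManager should be created from within (or just before starting) the loop that will run Querent.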

querent/napper/signaling.py (-22)
This file was deleted.

tests/llm_tests/gpt2_llm_v1_test.py (-41)
This file was deleted.

tests/llm_tests/mock_llm_test.py (+52)
@@ -0,0 +1,52 @@
+import asyncio
+import pytest
+from querent.common.types.querent_queue import QuerentQueue
+from querent.llm.base_llm import BaseLLM
+from querent.napper.querent import Querent
+from querent.napper.resource_manager import ResourceManager
+
+input_data = ["Data 1", "Data 2", "Data 3"]
+input_queue = QuerentQueue()
+output_queue = QuerentQueue()
+resource_manager = ResourceManager()
+
+
+# Define a simple mock LLM class for testing
+class MockLLM(BaseLLM):
+    async def process_tokens(self, data):
+        return f"Processed: {data}"
+
+    def validate(self):
+        return True
+
+
+@pytest.mark.asyncio
+async def test_querent_with_base_llm():
+    # Put some input data into the input queue
+    input_data = ["Data 1", "Data 2", "Data 3", None]
+    for data in input_data:
+        await input_queue.put(data)
+    # Wait for the tasks to finish processing (implicitly handled by Querent)
+    num_llms = 1
+    llms = [MockLLM(input_queue, output_queue) for _ in range(num_llms)]
+
+    # Create a Querent instance
+    querent = Querent(llms, num_workers=num_llms, resource_manager=resource_manager)
+
+    # Start the querent
+
+    await querent.start()
+
+    # Check the output queue for results and store them in a list
+    results = []
+    async for result in output_queue:
+        results.append(result)
+        output_queue.task_done()
+
+    # Assert that the results match the expected output
+    expected_output = [
+        "Processed: Data: ['Data 1']",
+        "Processed: Data: ['Data 2']",
+        "Processed: Data: ['Data 3']",
+    ]
+    assert len(results) == len(expected_output)
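
Assuming pytest-asyncio is installed, the @pytest.mark.asyncio marker lets the test body be a coroutine, so the file runs under a plain pytest invocation. Note that the async for loop over output_queue presumes QuerentQueue implements the async-iterator protocol and stops yielding once the queue is closed, and that the final assertion only compares lengths, so the expected strings serve as documentation rather than exact-match checks.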
