
Commit 58d785b

Worker (#83)
* base llm
* ;)
* base llm
* llm
* more ideas around queues
* examples
* cleanup
* napper
* querent workers
* update the querent core to handle shutdown signals
* cleanup
* worker
* handle signaling
1 parent 1350729 commit 58d785b

File tree

5 files changed: +67 -26 lines changed

README.md (-1)

@@ -4,7 +4,6 @@
 # Querent
 
 **Querent: Unleash the Power of Data and Graph Neural Networks**
-
 *Unlock Insights, Scale Asynchronously, and Forge a Knowledge-Driven Future*
 
 **Welcome to Querent!** We're not just another data framework; we're the future of knowledge discovery and insight generation. Querent is your agile and dynamic companion for collecting, processing, and harnessing data's transformative potential. Whether you're crafting knowledge graphs, training cutting-edge language models, or diving deep into data-driven insights, Querent has your back.

querent/llm/base_llm.py (+16 -11)

@@ -12,11 +12,10 @@ def __init__(
         output_queue: QuerentQueue[Any],
         num_workers: int = 1,
     ):
-        self.input_queue = input_queue  # ingested tokens coming file by file
-        self.output_queue = (
-            output_queue  # any type of output, need to think about various LLM outputs
-        )
+        self.input_queue = input_queue
+        self.output_queue = output_queue
         self.num_workers = num_workers
+        self.workers = []
 
     @abstractmethod
     async def process_tokens(self, data: IngestedTokens) -> Any:
@@ -49,10 +48,16 @@ async def start_workers(self, num_workers: int = 1):
         return self.workers
 
     async def stop_workers(self):
-        # Signal the workers to stop by putting None into the input queue
-        await self.input_queue.close()
-        # Wait for the workers to finish processing
-        await asyncio.gather(*self.workers)
-        # Close the output queue
-        # TODO this will change depending on many output queues we have
-        await self.output_queue.close()
+        try:
+            # Signal the workers to stop by putting None into the input queue
+            for _ in range(self.num_workers):
+                await self.input_queue.put(None)
+            # Wait for the workers to finish processing
+            await asyncio.gather(*self.workers)
+        except asyncio.CancelledError:
+            pass
+        except Exception as e:
+            print(f"Stop workers error: {e}")
+        finally:
+            # Close the output queue
+            await self.output_queue.close()
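
The reworked stop_workers pushes one None sentinel per worker instead of closing the input queue outright, so each worker loop must treat None as a shutdown signal. The worker body is not part of this diff; below is a minimal sketch of a loop matching that protocol, assuming QuerentQueue exposes an asyncio.Queue-like get/put API:

    import asyncio
    from typing import Any, Awaitable, Callable

    async def worker_loop(
        input_queue: asyncio.Queue,
        output_queue: asyncio.Queue,
        process: Callable[[Any], Awaitable[Any]],  # e.g. BaseLLM.process_tokens
    ) -> None:
        # Drain the input queue until the None sentinel arrives.
        while True:
            data = await input_queue.get()
            if data is None:  # sentinel pushed once per worker by stop_workers()
                break
            await output_queue.put(await process(data))

Because stop_workers enqueues exactly num_workers sentinels, every worker consumes one and exits, which is what lets the asyncio.gather call in stop_workers return.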

querent/llm/transformers/gpt2_llm_v1.py (+22 -11)

@@ -1,28 +1,39 @@
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
+from querent.common.types.ingested_tokens import IngestedTokens
 from querent.common.types.querent_queue import QuerentQueue
 from querent.llm.base_llm import BaseLLM
 
 
 class GPT2LLM(BaseLLM):
     def __init__(
         self,
-        input_queue: QuerentQueue[str],
-        output_queue: QuerentQueue[str],
+        input_queue: QuerentQueue[IngestedTokens],
+        output_queue: QuerentQueue[IngestedTokens],
         model_name="gpt2",
+        num_workers=1,
     ):
-        super().__init__(input_queue, output_queue)
+        super().__init__(input_queue, output_queue, num_workers=num_workers)
         self.model_name = model_name
-        self.model = GPT2LMHeadModel.from_pretrained(model_name)
-        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
 
-    async def process_tokens(self, data: str) -> str:
+    async def process_tokens(self, data: IngestedTokens) -> str:
         try:
-            input_text = data  # Assuming data is a string
-            input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
-            output = self.model.generate(
-                input_ids, max_length=50, num_return_sequences=1, no_repeat_ngram_size=2
+            # get the input text from the data which is a list of str
+            input_text_list = data.data
+
+            # concatenate the input text into a single string
+            input_text = " ".join(input_text_list)
+
+            model = GPT2LMHeadModel.from_pretrained(self.model_name)
+            tokenizer = GPT2Tokenizer.from_pretrained(self.model_name)
+
+            input_ids = tokenizer.encode(input_text, return_tensors="pt")
+            output = model.generate(
+                input_ids,
+                max_length=50,
+                num_return_sequences=1,
+                no_repeat_ngram_size=2,
             )
-            generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
+            generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
             return generated_text
         except Exception as e:
             # Log the error and return an informative error message
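
Note that this change moves GPT2LMHeadModel.from_pretrained and GPT2Tokenizer.from_pretrained out of __init__ and into process_tokens, so the weights are re-loaded on every call, which is costly even when they are already cached on disk. A common alternative, sketched here as a suggestion rather than as part of the commit, is to load lazily once on first use:

    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    class LazyGPT2:
        """Sketch: load model and tokenizer once, on first use."""

        def __init__(self, model_name: str = "gpt2"):
            self.model_name = model_name
            self._model = None
            self._tokenizer = None

        def generate_text(self, text: str) -> str:
            if self._model is None:  # first call pays the load cost; later calls reuse
                self._model = GPT2LMHeadModel.from_pretrained(self.model_name)
                self._tokenizer = GPT2Tokenizer.from_pretrained(self.model_name)
            input_ids = self._tokenizer.encode(text, return_tensors="pt")
            output = self._model.generate(
                input_ids, max_length=50, num_return_sequences=1, no_repeat_ngram_size=2
            )
            return self._tokenizer.decode(output[0], skip_special_tokens=True)

Also worth noting: model.generate is CPU-bound and blocks the event loop inside an async process_tokens; offloading it with run_in_executor (or asyncio.to_thread on Python 3.9+) is one way to keep the workers responsive.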

querent/napper/querent.py (+7 -3)

@@ -1,11 +1,11 @@
 import asyncio
 import logging
-import signal
 from typing import List
 from querent.common.types.querent_queue import QuerentQueue
 from querent.llm.base_llm import BaseLLM
 from querent.napper.resource_manager import ResourceManager
 from querent.napper.auto_scaler import AutoScaler
+from signaling import SignalHandler  # Import the SignalHandler class from signaling.py
 
 # Set up logging
 logging.basicConfig(
@@ -20,26 +20,30 @@ def __init__(
         self,
         querenters: List[BaseLLM],
         num_workers: int,
-        max_workers: int,
         resource_manager: ResourceManager,
         auto_scale_threshold: int = 10,
     ):
         self.num_workers = num_workers
-        self.max_workers = max_workers
         self.resource_manager = resource_manager
         self.querenters = querenters
         self.auto_scale_threshold = auto_scale_threshold
         self.auto_scaler = AutoScaler(
             self.resource_manager, querenters, threshold=self.auto_scale_threshold
         )
 
+        # Create an instance of SignalHandler and pass the Querent instance
+        self.signal_handler = SignalHandler(self)
+
     async def start(self):
         try:
             logger.info("Starting Querent")
 
             # Start the auto-scaler
             asyncio.create_task(self.auto_scaler.run())
 
+            # Start handling signals
+            asyncio.create_task(self.signal_handler.handle_signals())
+
         except Exception as e:
             logger.error(f"An error occurred during Querent execution: {e}")
             await self.graceful_shutdown()
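
graceful_shutdown is called here but its body is not part of this commit. Given BaseLLM.stop_workers above, a plausible sketch (hypothetical; the actual method lives elsewhere in the repo and may differ) is to stop each querenter's workers in turn:

    async def graceful_shutdown(self):
        # Hypothetical sketch, not part of this diff.
        logger.info("Shutting down Querent gracefully")
        for querenter in self.querenters:
            # stop_workers enqueues a None sentinel per worker, gathers the
            # worker tasks, then closes the output queue (see base_llm.py above).
            await querenter.stop_workers()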

querent/napper/signaling.py (+22, new file)

@@ -0,0 +1,22 @@
+import asyncio
+import signal
+
+
+class SignalHandler:
+    def __init__(self, querent):
+        self.querent = querent
+
+    async def handle_signals(self):
+        for sig in [signal.SIGINT, signal.SIGTERM]:
+            loop = asyncio.get_event_loop()
+            loop.add_signal_handler(sig, self.handle_signal)
+
+    async def handle_signal(self):
+        try:
+            print("Received shutdown signal. Initiating graceful shutdown...")
+            await self.querent.graceful_shutdown()
+        except Exception as e:
+            print(f"Error during graceful shutdown: {str(e)}")
+        finally:
+            print("Querent stopped")
+            exit(0)
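
One caveat with this new file: loop.add_signal_handler invokes its callback synchronously, but handle_signal is a coroutine function, so calling it only creates a coroutine object that is never awaited and the shutdown body never runs. A sketch of one way to bridge the gap, assuming the intent is to schedule graceful_shutdown on the running loop:

    import asyncio
    import signal

    def install_signal_handlers(loop: asyncio.AbstractEventLoop, querent) -> None:
        # add_signal_handler expects a plain callable, so wrap the coroutine
        # in a lambda that schedules it as a task instead of calling it directly.
        for sig in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(
                sig, lambda: asyncio.create_task(querent.graceful_shutdown())
            )

Note also that add_signal_handler is only available on Unix event loops; a Windows deployment would need a signal.signal fallback.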
