
Commit 053d8ef

feat: Polling for dataset registration and query execution to avoid timeouts (#178)
This switches dataset registration, query execution, and checkpoint restoration from blocking calls to client-side polling (busy waiting), so that long-running operations no longer hit client timeouts. Fixes #92.
1 parent c985ad5 · commit 053d8ef

16 files changed (+327, -140 lines)
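For callers, the practical change is that registration and execution no longer block: register_dataset returns a bool that must be checked, and execute_query (or restore_checkpoint) must be followed by wait_for_execution before results are streamed. Below is a minimal sketch of the updated calling pattern, assuming the client, dataset, and query classes used in the example scripts are already imported; the identifiers and helper name are placeholders, not values from this commit.

def register_and_query(client, dataset_path, parsing_func):
    # Hypothetical helper for illustration; dataset name, parser name, and
    # job id below are placeholders.
    if not client.register_dataset(
        "example_dataset",
        dataset_path,
        JSONLDataset,
        parsing_func,
        "TEST_PARSER",
    ):
        raise RuntimeError("Error while registering dataset!")

    job_id = "example_job"
    query = Query.for_job(job_id).select(("language", "==", "JavaScript"))
    qea = QueryExecutionArgs(mixture=ArbitraryMixture(chunk_size=42))

    # execute_query only starts execution; wait_for_execution polls until the
    # query has finished (or a timeout is hit) before results are streamed.
    if not client.execute_query(query, qea):
        raise RuntimeError("Could not start query execution!")
    if not client.wait_for_execution(job_id):
        raise RuntimeError("Query execution did not finish successfully!")

    rsa = ResultStreamingArgs(job_id=job_id)
    return list(client.stream_results(rsa))

The same wait_for_execution call also follows restore_checkpoint in the checkpointing tests below.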

examples/client_local_example.py (+4, -2)

@@ -98,13 +98,14 @@ def setup_local_client(directory: Path):
     client.register_metadata_parser("TEST_PARSER", TestMetadataParser)
 
     # Registering the dataset with the client.
-    client.register_dataset(
+    if not client.register_dataset(
         "local_integrationtest_dataset",
         directory / "testd.jsonl",
         JSONLDataset,
         parsing_func,
         "TEST_PARSER",
-    )
+    ):
+        raise RuntimeError("Error while registering dataset!")
 
     return client
 

@@ -120,6 +121,7 @@ def run_query(client: MixteraClient, chunk_size: int):
     mixture = ArbitraryMixture(chunk_size=chunk_size)
     qea = QueryExecutionArgs(mixture=mixture)
     client.execute_query(query, qea)
+    client.wait_for_execution(job_id)
 
     rsa = ResultStreamingArgs(job_id=job_id)
     result_samples = list(client.stream_results(rsa))

examples/client_server_example.py (+4, -2)

@@ -105,6 +105,7 @@ def run_query(client: MixteraClient, chunk_size: int, tunnel: bool):
     mixture = ArbitraryMixture(chunk_size=chunk_size)
     qea = QueryExecutionArgs(mixture=mixture)
     client.execute_query(query, qea)
+    client.wait_for_execution(job_id)
 
     rsa = ResultStreamingArgs(job_id=job_id)
     result_samples = list(client.stream_results(rsa))

@@ -147,13 +148,14 @@ def main(server_host: str, server_port: int):
     client.register_metadata_parser("TEST_PARSER", TestMetadataParser)
 
     # Registering the dataset with the client.
-    client.register_dataset(
+    if not client.register_dataset(
         "server_integrationtest_dataset",
         server_dir / "testd.jsonl",
         JSONLDataset,
         parsing_func,
         "TEST_PARSER",
-    )
+    ):
+        raise RuntimeError("Error while registering dataset.")
 
     # Run queries on server
     chunk_size = 42

integrationtests/checkpointing/test_local_checkpointing.py (+2)

@@ -33,6 +33,7 @@ def run_test_arbitrarymixture(client: MixteraClient):
         num_workers=0,
     )
     client.execute_query(query, query_execution_args)
+    client.wait_for_execution(job_id)
     result_streaming_args = ResultStreamingArgs(job_id=job_id)
     logger.info("Executed query.")
     # Get one chunk for each worker

@@ -64,6 +65,7 @@ def run_test_arbitrarymixture(client: MixteraClient):
     logger.info(f"Got all chunks.")
 
     client.restore_checkpoint(job_id, checkpoint_id)
+    client.wait_for_execution(job_id)
 
     logger.info("Restored checkpoint.")
 
integrationtests/checkpointing/test_server_checkpointing.py (+2)

@@ -34,6 +34,7 @@ def run_test_arbitrarymixture_server(client: ServerStub, dp_groups, nodes_per_gr
         num_workers=num_workers,
     )
     client.execute_query(query, query_execution_args)
+    client.wait_for_execution(job_id)
     logger.info(
         f"Executed query for job {job_id} with dp_groups={dp_groups}, nodes_per_group={nodes_per_group}, num_workers={num_workers}"
     )

@@ -110,6 +111,7 @@ def run_test_arbitrarymixture_server(client: ServerStub, dp_groups, nodes_per_gr
 
     # Restore from checkpoint
     client.restore_checkpoint(job_id, checkpoint_id)
+    client.wait_for_execution(job_id)
     logger.info("Restored from checkpoint.")
 
     # Obtain chunks after restoring from checkpoint

integrationtests/local/test_local.py (+11)

@@ -44,6 +44,7 @@ def test_filter_javascript(
     )
     query = Query.for_job(result_streaming_args.job_id).select(("language", "==", "JavaScript"))
     client.execute_query(query, query_exec_args)
+    client.wait_for_execution(result_streaming_args.job_id)
     result_samples = []
     for sample in client.stream_results(result_streaming_args):
         result_samples.append(sample)

@@ -65,6 +66,7 @@ def test_filter_html(
     )
     query = Query.for_job(result_streaming_args.job_id).select(("language", "==", "HTML"))
     client.execute_query(query, query_exec_args)
+    client.wait_for_execution(result_streaming_args.job_id)
     result_samples = []
 
     for sample in client.stream_results(result_streaming_args):

@@ -91,6 +93,7 @@ def test_filter_both(
         .select(("language", "==", "JavaScript"))
     )
     client.execute_query(query, query_exec_args)
+    client.wait_for_execution(result_streaming_args.job_id)
     result_samples = []
 
     for sample in client.stream_results(result_streaming_args):

@@ -113,6 +116,7 @@ def test_filter_license(
     )
     query = Query.for_job(result_streaming_args.job_id).select(("license", "==", "CC"))
     client.execute_query(query, query_exec_args)
+    client.wait_for_execution(result_streaming_args.job_id)
     result_samples = []
 
     for sample in client.stream_results(result_streaming_args):

@@ -135,6 +139,7 @@ def test_filter_unknown_license(
     )
     query = Query.for_job(result_streaming_args.job_id).select(("license", "==", "All rights reserved."))
     client.execute_query(query, query_exec_args)
+    client.wait_for_execution(result_streaming_args.job_id)
     assert len(list(client.stream_results(result_streaming_args))) == 0, "Got results back for expected empty results."
 
 

@@ -150,6 +155,7 @@ def test_filter_license_and_html(
         Query.for_job(result_streaming_args.job_id).select(("language", "==", "HTML")).select(("license", "==", "CC"))
     )
     client.execute_query(query, query_exec_args)
+    client.wait_for_execution(result_streaming_args.job_id)
     result_samples = []
 
     for sample in client.stream_results(result_streaming_args):

@@ -186,6 +192,7 @@ def test_reproducibility(
     )
     query_exec_args.mixture = mixture
     client.execute_query(query, query_exec_args)
+    client.wait_for_execution(result_streaming_args.job_id)
     result_samples = []
 
     for sample in client.stream_results(result_streaming_args):

@@ -216,6 +223,8 @@ def test_mixture_schedule(client: MixteraClient):
     query_execution_args = QueryExecutionArgs(mixture=mixture_schedule)
     result_streaming_args = ResultStreamingArgs(job_id)
     assert client.execute_query(query, query_execution_args)
+    assert client.wait_for_execution(job_id)
+
     logger.info(f"Executed query for job {job_id} for mixture schedule.")
 
     result_samples = []

@@ -269,6 +278,8 @@ def test_dynamic_mixture(client: MixteraClient):
     result_streaming_args = ResultStreamingArgs(job_id)
 
     assert client.execute_query(query, query_execution_args)
+    assert client.wait_for_execution(job_id)
+
     logger.info(f"Executed query for job {job_id} for dynamic mixture.")
 
     result_iter = client.stream_results(result_streaming_args)

integrationtests/server/test_server.py (+10, -2)

@@ -50,6 +50,7 @@ def test_filter_javascript(
     )
     query = Query.for_job(result_streaming_args.job_id).select(("language", "==", "JavaScript"))
     assert client.execute_query(query, query_exec_args)
+    assert client.wait_for_execution(result_streaming_args.job_id)
     result_samples = []
 
     for sample in client.stream_results(result_streaming_args):

@@ -73,6 +74,7 @@ def test_filter_html(
     )
     query = Query.for_job(result_streaming_args.job_id).select(("language", "==", "HTML"))
     assert client.execute_query(query, query_exec_args)
+    assert client.wait_for_execution(result_streaming_args.job_id)
     result_samples = []
 
     for sample in client.stream_results(result_streaming_args):

@@ -100,6 +102,7 @@ def test_filter_both(
         .select(("language", "==", "JavaScript"))
     )
     assert client.execute_query(query, query_exec_args)
+    assert client.wait_for_execution(result_streaming_args.job_id)
     result_samples = []
 
     for sample in client.stream_results(result_streaming_args):

@@ -148,6 +151,7 @@ def test_filter_unknown_license(
     )
     query = Query.for_job(result_streaming_args.job_id).select(("license", "==", "All rights reserved."))
     assert client.execute_query(query, query_exec_args)
+    assert client.wait_for_execution(result_streaming_args.job_id)
     assert len(list(client.stream_results(result_streaming_args))) == 0, "Got results back for expected empty results."
 
 

@@ -164,6 +168,7 @@ def test_filter_license_and_html(
         Query.for_job(result_streaming_args.job_id).select(("language", "==", "HTML")).select(("license", "==", "CC"))
     )
     assert client.execute_query(query, query_exec_args)
+    assert client.wait_for_execution(result_streaming_args.job_id)
     result_samples = []
 
     for sample in client.stream_results(result_streaming_args):

@@ -190,15 +195,16 @@ def test_reproducibility(
         f"6_{query_exec_args.mixture.chunk_size}_{query_exec_args.dp_groups}"
         + f"_{query_exec_args.nodes_per_group}_{query_exec_args.num_workers}_{result_streaming_args.chunk_reading_degree_of_parallelism}"
         + f"_{result_streaming_args.chunk_reading_window_size}_{result_streaming_args.chunk_reading_mixture_type}"
-        + f"_reproducibility_{i}"
+        + f"_{result_streaming_args.tunnel_via_server}_reproducibility_{i}"
     )
     query = (
         Query.for_job(result_streaming_args.job_id)
         .select(("language", "==", "HTML"))
         .select(("language", "==", "JavaScript"))
     )
     query_exec_args.mixture = mixture
-    client.execute_query(query, query_exec_args)
+    assert client.execute_query(query, query_exec_args)
+    assert client.wait_for_execution(result_streaming_args.job_id)
     result_samples = []
 
     for sample in client.stream_results(result_streaming_args):

@@ -229,6 +235,7 @@ def test_mixture_schedule(client: ServerStub):
     query_execution_args = QueryExecutionArgs(mixture=mixture_schedule)
     result_streaming_args = ResultStreamingArgs(job_id)
     assert client.execute_query(query, query_execution_args)
+    assert client.wait_for_execution(job_id)
     logger.info(f"Executed query for job {job_id} for mixture schedule.")
 
     result_samples = []

@@ -282,6 +289,7 @@ def test_dynamic_mixture(client: MixteraClient):
     result_streaming_args = ResultStreamingArgs(job_id)
 
     assert client.execute_query(query, query_execution_args)
+    assert client.wait_for_execution(job_id)
     logger.info(f"Executed query for job {job_id} for dynamic mixture.")
 
     result_iter = client.stream_results(result_streaming_args)

mixtera/core/client/local/local_stub.py (+4, -1)

@@ -49,7 +49,6 @@ def register_dataset(
     ) -> bool:
         if isinstance(loc, Path):
             loc = str(loc)
-
         return self._mdc.register_dataset(identifier, loc, dtype, parsing_func, metadata_parser_identifier)
 
     def register_metadata_parser(

@@ -95,6 +94,10 @@ def execute_query(self, query: Query, args: QueryExecutionArgs) -> bool:
             query, args.mixture, args.dp_groups, args.nodes_per_group, args.num_workers, cache_path
         )
 
+    def wait_for_execution(self, job_id: str) -> bool:
+        logger.info(f"Waiting for execution of {job_id}.")
+        return wait_for_key_in_dict(self._training_query_map, job_id, 30)
+
     def is_remote(self) -> bool:
         return False
 
mixtera/core/client/mixtera_client.py (+16, -2)

@@ -250,6 +250,20 @@ def execute_query(self, query: Query, args: QueryExecutionArgs) -> bool:
 
         raise NotImplementedError()
 
+    @abstractmethod
+    def wait_for_execution(self, job_id: str) -> bool:
+        """
+        Waits until the query has finished executing.
+
+        Args:
+            job_id (str): The job id of the query
+
+        Returns:
+            bool indicating success
+        """
+
+        raise NotImplementedError()
+
     def stream_results(self, args: ResultStreamingArgs) -> Generator[tuple[int, int, Sample], None, None]:
         """
         Given a job ID, returns the QueryResult object from which the result chunks can be obtained.

@@ -265,7 +279,7 @@ def stream_results(self, args: ResultStreamingArgs) -> Generator[tuple[int, int,
             with self.current_mixture_id_val.get_lock():
                 new_id = max(result_chunk.mixture_id, self.current_mixture_id_val.get_obj().value)
                 self.current_mixture_id_val.get_obj().value = new_id
-                logger.debug(f"Set current mixture ID to {new_id}")
+                # logger.debug(f"Set current mixture ID to {new_id}")
 
             result_chunk.configure_result_streaming(
                 client=self,

@@ -275,7 +289,7 @@ def stream_results(self, args: ResultStreamingArgs) -> Generator[tuple[int, int,
 
         with self.current_mixture_id_val.get_lock():
             self.current_mixture_id_val.get_obj().value = -1
-            logger.debug("Reset current mixture ID to -1.")
+            # logger.debug("Reset current mixture ID to -1.")
 
     @abstractmethod
     def _stream_result_chunks(

mixtera/core/client/server/server_stub.py (+20, -1)

@@ -1,3 +1,4 @@
+import time
 from pathlib import Path
 from typing import Any, Callable, Generator, Type
 

@@ -69,7 +70,7 @@ def execute_query(self, query: Query, args: QueryExecutionArgs) -> bool:
             logger.error("Could not register query at server!")
             return False
 
-        logger.info(f"Registered query for job {query.job_id} at server!")
+        logger.info(f"Started query registration for job {query.job_id} at server!")
 
         return True
 

@@ -132,3 +133,21 @@ def checkpoint_completed(self, job_id: str, chkpnt_id: str, on_disk: bool) -> bo
 
     def restore_checkpoint(self, job_id: str, chkpnt_id: str) -> None:
         return self.server_connection.restore_checkpoint(job_id, chkpnt_id)
+
+    def wait_for_execution(self, job_id: str) -> bool:
+        logger.info("Waiting for query execution at server to finish.")
+        status = self.server_connection.check_query_exec_status(job_id)
+
+        timeout_minutes = 30
+        curr_time = 0
+        while status == 0 and curr_time <= timeout_minutes * 60:
+            time.sleep(1)
+            status = self.server_connection.check_query_exec_status(job_id)
+            curr_time += 1
+
+        if status != 1:
+            logger.error(f"Query execution failed with status {status}.")
+            return False
+
+        logger.info("Query execution finished.")
+        return True
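The polling loop above implies a small status contract for check_query_exec_status: 0 while the query is still executing, 1 once it finished successfully, and any other value on failure. A hypothetical encoding of that contract (the enum name and member names are illustrative, not taken from the Mixtera codebase):

from enum import IntEnum


class QueryExecStatus(IntEnum):
    # Only the integer values are implied by wait_for_execution above;
    # the names here are placeholders.
    EXECUTING = 0  # keep polling
    FINISHED = 1   # wait_for_execution returns True
    FAILED = 2     # any value other than 0 or 1 makes wait_for_execution return False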
