|
9 | 9 | import csp
|
10 | 10 | from csp import ts
|
11 | 11 |
|
12 |
| -# csp.set_print_full_exception_stack(True) |
13 |
| - |
14 | 12 | if os.environ.get("CSP_TEST_WEBSOCKET"):
|
15 | 13 | import tornado.ioloop
|
16 | 14 | import tornado.web
|
|
23 | 21 | RawTextMessageMapper,
|
24 | 22 | Status,
|
25 | 23 | WebsocketAdapterManager,
|
| 24 | + WebsocketStatus, |
26 | 25 | )
|
27 | 26 |
|
28 | 27 | class EchoWebsocketHandler(tornado.websocket.WebSocketHandler):
|
29 | 28 | def on_message(self, msg):
|
30 | 29 | return self.write_message(msg)
|
31 | 30 |
|
32 |
| - |
33 |
| -@contextmanager |
34 |
| -def create_tornado_server(port: int): |
35 |
| - """Base context manager for creating a Tornado server in a thread""" |
36 |
| - ready_event = threading.Event() |
37 |
| - io_loop = None |
38 |
| - app = None |
39 |
| - io_thread = None |
40 |
| - |
41 |
| - def run_io_loop(): |
42 |
| - nonlocal io_loop, app |
43 |
| - io_loop = tornado.ioloop.IOLoop() |
44 |
| - io_loop.make_current() |
45 |
| - app = tornado.web.Application([(r"/", EchoWebsocketHandler)]) |
46 |
| - app.listen(port) |
47 |
| - ready_event.set() |
48 |
| - io_loop.start() |
49 |
| - |
50 |
| - io_thread = threading.Thread(target=run_io_loop) |
51 |
| - io_thread.start() |
52 |
| - ready_event.wait() |
53 |
| - |
54 |
| - try: |
55 |
| - yield io_loop, app, io_thread |
56 |
| - finally: |
57 |
| - io_loop.add_callback(io_loop.stop) |
58 |
| - if io_thread: |
59 |
| - io_thread.join(timeout=5) |
60 |
| - if io_thread.is_alive(): |
61 |
| - raise RuntimeError("IOLoop failed to stop") |
62 |
| - |
63 |
| - |
64 |
| -@contextmanager |
65 |
| -def tornado_server(port: int = 8001): |
66 |
| - """Simplified context manager that uses the base implementation""" |
67 |
| - with create_tornado_server(port) as (_io_loop, _app, _io_thread): |
68 |
| - yield |
| 31 | + @contextmanager |
| 32 | + def create_tornado_server(port: int): |
| 33 | + """Base context manager for creating a Tornado server in a thread""" |
| 34 | + ready_event = threading.Event() |
| 35 | + io_loop = None |
| 36 | + app = None |
| 37 | + io_thread = None |
| 38 | + |
| 39 | + def run_io_loop(): |
| 40 | + nonlocal io_loop, app |
| 41 | + io_loop = tornado.ioloop.IOLoop() |
| 42 | + io_loop.make_current() |
| 43 | + app = tornado.web.Application([(r"/", EchoWebsocketHandler)]) |
| 44 | + app.listen(port) |
| 45 | + ready_event.set() |
| 46 | + io_loop.start() |
| 47 | + |
| 48 | + io_thread = threading.Thread(target=run_io_loop) |
| 49 | + io_thread.start() |
| 50 | + ready_event.wait() |
| 51 | + |
| 52 | + try: |
| 53 | + yield io_loop, app, io_thread |
| 54 | + finally: |
| 55 | + io_loop.add_callback(io_loop.stop) |
| 56 | + if io_thread: |
| 57 | + io_thread.join(timeout=5) |
| 58 | + if io_thread.is_alive(): |
| 59 | + raise RuntimeError("IOLoop failed to stop") |
| 60 | + |
| 61 | + @contextmanager |
| 62 | + def tornado_server(port: int = 8001): |
| 63 | + """Simplified context manager that uses the base implementation""" |
| 64 | + with create_tornado_server(port) as (_io_loop, _app, _io_thread): |
| 65 | + yield |
69 | 66 |
|
70 | 67 |
|
| 68 | +@pytest.mark.skipif(os.environ.get("CSP_TEST_WEBSOCKET") is None, reason="'CSP_TEST_WEBSOCKET' env variable is not set") |
71 | 69 | class TestWebsocket:
|
72 | 70 | @pytest.fixture(scope="class", autouse=True)
|
73 | 71 | def setup_tornado(self, request):
|
@@ -263,34 +261,52 @@ def g():
|
263 | 261 | csp.run(g, starttime=datetime.now(pytz.UTC), endtime=timedelta(seconds=1), realtime=True)
|
264 | 262 |
|
265 | 263 | def test_dynamic_multiple_subscribers(self):
|
266 |
| - with tornado_server(): |
| 264 | + @csp.node |
| 265 | + def send_on_status(status: ts[Status], uri: str, val: str) -> ts[str]: |
| 266 | + if csp.ticked(status): |
| 267 | + if uri in status.msg and status.status_code == WebsocketStatus.ACTIVE.value: |
| 268 | + return val |
267 | 269 |
|
268 |
| - @csp.graph |
269 |
| - def g(): |
270 |
| - ws = WebsocketAdapterManager(dynamic=True) |
271 |
| - conn_request1 = csp.const( |
272 |
| - ConnectionRequest(uri="ws://localhost:8000/", on_connect_payload="hey world from 8000") |
273 |
| - ) |
274 |
| - conn_request2 = csp.const( |
275 |
| - ConnectionRequest(uri="ws://localhost:8001/", on_connect_payload="hey world from 8001") |
276 |
| - ) |
277 |
| - recv = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request1) |
278 |
| - recv2 = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request2) |
279 |
| - |
280 |
| - csp.add_graph_output("recv", recv) |
281 |
| - csp.add_graph_output("recv2", recv2) |
282 |
| - |
283 |
| - merged = csp.flatten([recv, recv2]) |
284 |
| - stop = csp.filter(csp.count(merged) == 2, merged) |
285 |
| - csp.stop_engine(stop) |
286 |
| - |
287 |
| - msgs = csp.run(g, starttime=datetime.now(pytz.UTC), endtime=timedelta(seconds=5), realtime=True) |
288 |
| - assert len(msgs["recv"]) == 1 |
289 |
| - assert msgs["recv"][0][1].msg == "hey world from 8000" |
290 |
| - assert msgs["recv"][0][1].uri == "ws://localhost:8000/" |
291 |
| - assert len(msgs["recv2"]) == 1 |
292 |
| - assert msgs["recv2"][0][1].msg == "hey world from 8001" |
293 |
| - assert msgs["recv2"][0][1].uri == "ws://localhost:8001/" |
| 270 | + with tornado_server(): |
| 271 | + # We do this to only spawn the tornado server once for both options |
| 272 | + for use_on_connect_payload in [True, False]: |
| 273 | + |
| 274 | + @csp.graph |
| 275 | + def g(): |
| 276 | + ws = WebsocketAdapterManager(dynamic=True) |
| 277 | + if use_on_connect_payload: |
| 278 | + conn_request1 = csp.const( |
| 279 | + ConnectionRequest(uri="ws://localhost:8000/", on_connect_payload="hey world from 8000") |
| 280 | + ) |
| 281 | + conn_request2 = csp.const( |
| 282 | + ConnectionRequest(uri="ws://localhost:8001/", on_connect_payload="hey world from 8001") |
| 283 | + ) |
| 284 | + else: |
| 285 | + conn_request1 = csp.const(ConnectionRequest(uri="ws://localhost:8000/")) |
| 286 | + conn_request2 = csp.const(ConnectionRequest(uri="ws://localhost:8001/")) |
| 287 | + status = ws.status() |
| 288 | + to_send = send_on_status(status, "ws://localhost:8000/", "hey world from 8000") |
| 289 | + to_send2 = send_on_status(status, "ws://localhost:8001/", "hey world from 8001") |
| 290 | + ws.send(to_send, connection_request=conn_request1) |
| 291 | + ws.send(to_send2, connection_request=conn_request2) |
| 292 | + |
| 293 | + recv = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request1) |
| 294 | + recv2 = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request2) |
| 295 | + |
| 296 | + csp.add_graph_output("recv", recv) |
| 297 | + csp.add_graph_output("recv2", recv2) |
| 298 | + |
| 299 | + merged = csp.flatten([recv, recv2]) |
| 300 | + stop = csp.filter(csp.count(merged) == 2, merged) |
| 301 | + csp.stop_engine(stop) |
| 302 | + |
| 303 | + msgs = csp.run(g, starttime=datetime.now(pytz.UTC), endtime=timedelta(seconds=5), realtime=True) |
| 304 | + assert len(msgs["recv"]) == 1 |
| 305 | + assert msgs["recv"][0][1].msg == "hey world from 8000" |
| 306 | + assert msgs["recv"][0][1].uri == "ws://localhost:8000/" |
| 307 | + assert len(msgs["recv2"]) == 1 |
| 308 | + assert msgs["recv2"][0][1].msg == "hey world from 8001" |
| 309 | + assert msgs["recv2"][0][1].uri == "ws://localhost:8001/" |
294 | 310 |
|
295 | 311 | @pytest.mark.parametrize("dynamic", [False, True])
|
296 | 312 | def test_send_recv_burst_json(self, dynamic):
|
|
0 commit comments