|
23 | 23 | RawTextMessageMapper,
|
24 | 24 | Status,
|
25 | 25 | WebsocketAdapterManager,
|
| 26 | + WebsocketStatus, |
26 | 27 | )
|
27 | 28 |
|
28 | 29 | class EchoWebsocketHandler(tornado.websocket.WebSocketHandler):
|
@@ -263,34 +264,52 @@ def g():
|
263 | 264 | csp.run(g, starttime=datetime.now(pytz.UTC), endtime=timedelta(seconds=1), realtime=True)
|
264 | 265 |
|
265 | 266 | def test_dynamic_multiple_subscribers(self):
|
266 |
| - with tornado_server(): |
| 267 | + @csp.node |
| 268 | + def send_on_status(status: ts[Status], uri: str, val: str) -> ts[str]: |
| 269 | + if csp.ticked(status): |
| 270 | + if uri in status.msg and status.status_code == WebsocketStatus.ACTIVE.value: |
| 271 | + return val |
267 | 272 |
|
268 |
| - @csp.graph |
269 |
| - def g(): |
270 |
| - ws = WebsocketAdapterManager(dynamic=True) |
271 |
| - conn_request1 = csp.const( |
272 |
| - ConnectionRequest(uri="ws://localhost:8000/", on_connect_payload="hey world from 8000") |
273 |
| - ) |
274 |
| - conn_request2 = csp.const( |
275 |
| - ConnectionRequest(uri="ws://localhost:8001/", on_connect_payload="hey world from 8001") |
276 |
| - ) |
277 |
| - recv = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request1) |
278 |
| - recv2 = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request2) |
279 |
| - |
280 |
| - csp.add_graph_output("recv", recv) |
281 |
| - csp.add_graph_output("recv2", recv2) |
282 |
| - |
283 |
| - merged = csp.flatten([recv, recv2]) |
284 |
| - stop = csp.filter(csp.count(merged) == 2, merged) |
285 |
| - csp.stop_engine(stop) |
286 |
| - |
287 |
| - msgs = csp.run(g, starttime=datetime.now(pytz.UTC), endtime=timedelta(seconds=5), realtime=True) |
288 |
| - assert len(msgs["recv"]) == 1 |
289 |
| - assert msgs["recv"][0][1].msg == "hey world from 8000" |
290 |
| - assert msgs["recv"][0][1].uri == "ws://localhost:8000/" |
291 |
| - assert len(msgs["recv2"]) == 1 |
292 |
| - assert msgs["recv2"][0][1].msg == "hey world from 8001" |
293 |
| - assert msgs["recv2"][0][1].uri == "ws://localhost:8001/" |
| 273 | + with tornado_server(): |
| 274 | + # We do this to only spawn the tornado server once for both options |
| 275 | + for use_on_connect_payload in [True, False]: |
| 276 | + |
| 277 | + @csp.graph |
| 278 | + def g(): |
| 279 | + ws = WebsocketAdapterManager(dynamic=True) |
| 280 | + if use_on_connect_payload: |
| 281 | + conn_request1 = csp.const( |
| 282 | + ConnectionRequest(uri="ws://localhost:8000/", on_connect_payload="hey world from 8000") |
| 283 | + ) |
| 284 | + conn_request2 = csp.const( |
| 285 | + ConnectionRequest(uri="ws://localhost:8001/", on_connect_payload="hey world from 8001") |
| 286 | + ) |
| 287 | + else: |
| 288 | + conn_request1 = csp.const(ConnectionRequest(uri="ws://localhost:8000/")) |
| 289 | + conn_request2 = csp.const(ConnectionRequest(uri="ws://localhost:8001/")) |
| 290 | + status = ws.status() |
| 291 | + to_send = send_on_status(status, "ws://localhost:8000/", "hey world from 8000") |
| 292 | + to_send2 = send_on_status(status, "ws://localhost:8001/", "hey world from 8001") |
| 293 | + ws.send(to_send, connection_request=conn_request1) |
| 294 | + ws.send(to_send2, connection_request=conn_request2) |
| 295 | + |
| 296 | + recv = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request1) |
| 297 | + recv2 = ws.subscribe(str, RawTextMessageMapper(), connection_request=conn_request2) |
| 298 | + |
| 299 | + csp.add_graph_output("recv", recv) |
| 300 | + csp.add_graph_output("recv2", recv2) |
| 301 | + |
| 302 | + merged = csp.flatten([recv, recv2]) |
| 303 | + stop = csp.filter(csp.count(merged) == 2, merged) |
| 304 | + csp.stop_engine(stop) |
| 305 | + |
| 306 | + msgs = csp.run(g, starttime=datetime.now(pytz.UTC), endtime=timedelta(seconds=5), realtime=True) |
| 307 | + assert len(msgs["recv"]) == 1 |
| 308 | + assert msgs["recv"][0][1].msg == "hey world from 8000" |
| 309 | + assert msgs["recv"][0][1].uri == "ws://localhost:8000/" |
| 310 | + assert len(msgs["recv2"]) == 1 |
| 311 | + assert msgs["recv2"][0][1].msg == "hey world from 8001" |
| 312 | + assert msgs["recv2"][0][1].uri == "ws://localhost:8001/" |
294 | 313 |
|
295 | 314 | @pytest.mark.parametrize("dynamic", [False, True])
|
296 | 315 | def test_send_recv_burst_json(self, dynamic):
|
|
0 commit comments