Skip to content

Commit

Permalink
Ensure log_event of non-msgpack serializable object do not kill serve…
Browse files Browse the repository at this point in the history
…rs (#7472)

Co-authored-by: Thomas Grainger <[email protected]>
  • Loading branch information
fjetter and graingert authored May 11, 2023
1 parent 2d80271 commit b68d71d
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 2 deletions.
5 changes: 5 additions & 0 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from dask.widgets import get_template

from distributed.core import ErrorMessage
from distributed.protocol.serialize import _is_dumpable
from distributed.utils import wait_for

try:
Expand Down Expand Up @@ -4412,6 +4413,10 @@ def log_event(self, topic: str | Collection[str], msg: Any):
>>> from time import time
>>> client.log_event("current-time", time())
"""
if not _is_dumpable(msg):
raise TypeError(
f"Message must be msgpack serializable. Got {type(msg)=} instead."
)
return self.sync(self.scheduler.log_event, topic=topic, msg=msg)

def get_events(self, topic: str | None = None):
Expand Down
19 changes: 19 additions & 0 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from distributed.process import AsyncProcess
from distributed.proctitle import enable_proctitle_on_children
from distributed.protocol import pickle
from distributed.protocol.serialize import _is_dumpable
from distributed.security import Security
from distributed.utils import (
get_ip,
Expand Down Expand Up @@ -610,6 +611,24 @@ async def _log_event(self, topic, msg):
)

def log_event(self, topic, msg):
"""Log an event under a given topic
Parameters
----------
topic : str, list[str]
Name of the topic under which to log an event. To log the same
event under multiple topics, pass a list of topic names.
msg
Event message to log. Note this must be msgpack serializable.
See also
--------
Client.log_event
"""
if not _is_dumpable(msg):
raise TypeError(
f"Message must be msgpack serializable. Got {type(msg)=} instead."
)
self._ongoing_background_tasks.call_soon(self._log_event, topic, msg)


Expand Down
21 changes: 21 additions & 0 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,27 @@ def _is_msgpack_serializable(v):
)


def _is_dumpable(v):
typ = type(v)
return (
v is None
or typ is str
or typ is bool
or typ is bytes
or typ is int
or typ is float
or typ is Pickled
or typ is Serialize
or typ is Serialized
or typ is ToPickle
or isinstance(v, dict)
and all(map(_is_dumpable, v.values()))
and all(type(x) is str for x in v.keys())
or isinstance(v, (list, tuple))
and all(map(_is_dumpable, v))
)


class ObjectDictSerializer:
def __init__(self, serializer):
self.serializer = serializer
Expand Down
14 changes: 14 additions & 0 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7818,6 +7818,20 @@ async def get_worker_logs(self, n=None, workers=None, nanny=False):
return results

def log_event(self, topic: str | Collection[str], msg: Any) -> None:
"""Log an event under a given topic
Parameters
----------
topic : str, list[str]
Name of the topic under which to log an event. To log the same
event under multiple topics, pass a list of topic names.
msg
Event message to log. Note this must be msgpack serializable.
See also
--------
Client.log_event
"""
event = (time(), msg)
if not isinstance(topic, str):
for t in topic:
Expand Down
28 changes: 27 additions & 1 deletion distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
from distributed.cluster_dump import load_cluster_dump
from distributed.comm import CommClosedError
from distributed.compatibility import LINUX, WINDOWS
from distributed.core import Status
from distributed.core import Status, error_message
from distributed.diagnostics.plugin import WorkerPlugin
from distributed.metrics import time
from distributed.scheduler import CollectTaskMetaDataPlugin, KilledWorker, Scheduler
Expand Down Expand Up @@ -7550,6 +7550,32 @@ def no_category():
await c.submit(no_category)


@gen_cluster(client=True, nthreads=[])
async def test_log_event_msgpack(c, s, a, b):
await c.log_event("test-topic", "foo")
with pytest.raises(TypeError, match="msgpack"):

class C:
pass

await c.log_event("test-topic", C())
await c.log_event("test-topic", "bar")
await c.log_event("test-topic", error_message(Exception()))

# assertion reversed for mock.ANY.__eq__(Serialized())
assert [
"foo",
"bar",
{
"status": "error",
"exception": mock.ANY,
"traceback": mock.ANY,
"exception_text": "Exception()",
"traceback_text": "",
},
] == [msg[1] for msg in s.get_events("test-topic")]


@gen_cluster(client=True)
async def test_log_event_warn_dask_warns(c, s, a, b):
from dask.distributed import warn
Expand Down
32 changes: 31 additions & 1 deletion distributed/tests/test_nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from distributed import Nanny, Scheduler, Worker, profile, rpc, wait, worker
from distributed.compatibility import LINUX, WINDOWS
from distributed.core import CommClosedError, Status
from distributed.core import CommClosedError, Status, error_message
from distributed.diagnostics import SchedulerPlugin
from distributed.diagnostics.plugin import WorkerPlugin
from distributed.metrics import time
Expand Down Expand Up @@ -751,3 +751,33 @@ async def test_worker_inherits_temp_config(c, s):
async with Nanny(s.address):
out = await c.submit(lambda: dask.config.get("test123"))
assert out == 123


@gen_cluster(client=True, nthreads=[])
async def test_log_event(c, s):
async with Nanny(s.address) as n:
n.log_event("test-topic1", "foo")

class C:
pass

with pytest.raises(TypeError, match="msgpack"):
n.log_event("test-topic2", C())
n.log_event("test-topic3", "bar")
n.log_event("test-topic4", error_message(Exception()))

# Worker unaffected
assert await c.submit(lambda x: x + 1, 1) == 2

assert [msg[1] for msg in s.get_events("test-topic1")] == ["foo"]
assert [msg[1] for msg in s.get_events("test-topic3")] == ["bar"]
# assertion reversed for mock.ANY.__eq__(Serialized())
assert [
{
"status": "error",
"exception": mock.ANY,
"traceback": mock.ANY,
"exception_text": "Exception()",
"traceback_text": "",
},
] == [msg[1] for msg in s.get_events("test-topic4")]
33 changes: 33 additions & 0 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,39 @@ async def test_dont_overlap_communications_to_same_worker(c, s, a, b):
assert l1["stop"] < l2["start"]


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_log_event(c, s, a):
def log_event(msg):
w = get_worker()
w.log_event("test-topic", msg)

await c.submit(log_event, "foo")

class C:
pass

with pytest.raises(TypeError, match="msgpack"):
await c.submit(log_event, C())

# Worker still works
await c.submit(log_event, "bar")
await c.submit(log_event, error_message(Exception()))

# assertion reversed for mock.ANY.__eq__(Serialized())
assert [
"foo",
"bar",
{
"status": "error",
"exception": mock.ANY,
"traceback": mock.ANY,
"exception_text": "Exception()",
"traceback_text": "",
"worker": a.address,
},
] == [msg[1] for msg in s.get_events("test-topic")]


@gen_cluster(client=True)
async def test_log_exception_on_failed_task(c, s, a, b):
with captured_logger("distributed.worker") as logger:
Expand Down
19 changes: 19 additions & 0 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.shuffle import ShuffleWorkerExtension
Expand Down Expand Up @@ -957,6 +958,24 @@ def logs(self):
return self._deque_handler.deque

def log_event(self, topic: str | Collection[str], msg: Any) -> None:
"""Log an event under a given topic
Parameters
----------
topic : str, list[str]
Name of the topic under which to log an event. To log the same
event under multiple topics, pass a list of topic names.
msg
Event message to log. Note this must be msgpack serializable.
See also
--------
Client.log_event
"""
if not _is_dumpable(msg):
raise TypeError(
f"Message must be msgpack serializable. Got {type(msg)=} instead."
)
full_msg = {
"op": "log-event",
"topic": topic,
Expand Down

0 comments on commit b68d71d

Please sign in to comment.