Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wrapper for Celery().send_task to support behavior as Task.apply_async #2377

Merged
merged 18 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions sentry_sdk/integrations/celery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

try:
from celery import VERSION as CELERY_VERSION # type: ignore
from celery.app.task import Task # type: ignore
from celery.app.trace import task_has_custom
from celery.exceptions import ( # type: ignore
Ignore,
Expand Down Expand Up @@ -83,6 +84,7 @@ def setup_once():

_patch_build_tracer()
_patch_task_apply_async()
_patch_celery_send_task()
_patch_worker_exit()
_patch_producer_publish()

Expand Down Expand Up @@ -243,7 +245,7 @@ def __exit__(self, exc_type, exc_value, traceback):
return None


def _wrap_apply_async(f):
def _wrap_task_run(f):
# type: (F) -> F
@wraps(f)
@ensure_integration_enabled(CeleryIntegration, f)
Expand All @@ -260,14 +262,19 @@ def apply_async(*args, **kwargs):
if not propagate_traces:
return f(*args, **kwargs)

task = args[0]
if isinstance(args[0], Task):
task_name = args[0].name # type: str
elif len(args) > 1 and isinstance(args[1], str):
task_name = args[1]
else:
task_name = "<unknown Celery task>"

task_started_from_beat = sentry_sdk.get_isolation_scope()._name == "celery-beat"

span_mgr = (
sentry_sdk.start_span(
op=OP.QUEUE_SUBMIT_CELERY,
description=task.name,
description=task_name,
origin=CeleryIntegration.origin,
)
if not task_started_from_beat
Expand Down Expand Up @@ -437,9 +444,14 @@ def sentry_build_tracer(name, task, *args, **kwargs):

def _patch_task_apply_async():
# type: () -> None
from celery.app.task import Task # type: ignore
Task.apply_async = _wrap_task_run(Task.apply_async)


def _patch_celery_send_task():
# type: () -> None
from celery import Celery

Task.apply_async = _wrap_apply_async(Task.apply_async)
Celery.send_task = _wrap_task_run(Celery.send_task)


def _patch_worker_exit():
Expand Down
52 changes: 50 additions & 2 deletions tests/integrations/celery/test_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sentry_sdk import start_transaction, get_current_span
from sentry_sdk.integrations.celery import (
CeleryIntegration,
_wrap_apply_async,
_wrap_task_run,
)
from sentry_sdk.integrations.celery.beat import _get_headers
from tests.conftest import ApproxDict
Expand Down Expand Up @@ -568,7 +568,7 @@ def dummy_function(*args, **kwargs):
assert "sentry-trace" in headers
assert "baggage" in headers

wrapped = _wrap_apply_async(dummy_function)
wrapped = _wrap_task_run(dummy_function)
wrapped(mock.MagicMock(), (), headers={})


Expand Down Expand Up @@ -783,3 +783,51 @@ def task(): ...
assert span["origin"] == "auto.queue.celery"

monkeypatch.setattr(kombu.messaging.Producer, "_publish", old_publish)


@pytest.mark.forked
@mock.patch("celery.Celery.send_task")
def test_send_task_wrapped(
patched_send_task,
sentry_init,
capture_events,
reset_integrations,
):
sentry_init(integrations=[CeleryIntegration()], enable_tracing=True)
celery = Celery(__name__, broker="redis://example.com") # noqa: E231

events = capture_events()

with sentry_sdk.start_transaction(name="custom_transaction"):
celery.send_task("very_creative_task_name", args=(1, 2), kwargs={"foo": "bar"})

(call,) = patched_send_task.call_args_list # We should have exactly one call
(args, kwargs) = call

assert args == (celery, "very_creative_task_name")
assert kwargs["args"] == (1, 2)
assert kwargs["kwargs"] == {"foo": "bar"}
assert set(kwargs["headers"].keys()) == {
"sentry-task-enqueued-time",
"sentry-trace",
"baggage",
"headers",
}
assert set(kwargs["headers"]["headers"].keys()) == {
"sentry-trace",
"baggage",
"sentry-task-enqueued-time",
}
assert (
kwargs["headers"]["sentry-trace"]
== kwargs["headers"]["headers"]["sentry-trace"]
)

(event,) = events # We should have exactly one event (the transaction)
assert event["type"] == "transaction"
assert event["transaction"] == "custom_transaction"

(span,) = event["spans"] # We should have exactly one span
assert span["description"] == "very_creative_task_name"
assert span["op"] == "queue.submit.celery"
assert span["trace_id"] == kwargs["headers"]["sentry-trace"].split("-")[0]
3 changes: 2 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,9 @@ deps =
celery-v5.4: Celery~=5.4.0
celery-latest: Celery

{py3.7}-celery: importlib-metadata<5.0
{py3.6,py3.7,py3.8,py3.9,py3.10,py3.11,py3.12}-celery: newrelic
celery: pytest<7
{py3.7}-celery: importlib-metadata<5.0

# Chalice
chalice-v1.16: chalice~=1.16.0
Expand Down
Loading