1+ import pika
2+ from contextlib import contextmanager
3+ from .tracer import RequestIdContext , trace_id_ctx
4+
5+ @contextmanager
6+ def rabbitmq_trace_context (channel , properties ):
7+ request_id = properties .headers .get ("X-Request-ID" ) if properties .headers else None
8+ with RequestIdContext (request_id ):
9+ yield
10+ if properties .headers is None :
11+ properties .headers = {}
12+ properties .headers ["X-Request-ID" ] = trace_id_ctx .get ()
13+
14+ class RabbitMQMiddleware :
15+ def __init__ (self , connection_parameters ):
16+ self .connection = pika .BlockingConnection (connection_parameters )
17+ self .channel = self .connection .channel ()
18+
19+ def publish (self , exchange , routing_key , body , properties = None ):
20+ if properties is None :
21+ properties = pika .BasicProperties ()
22+ with rabbitmq_trace_context (self .channel , properties ):
23+ self .channel .basic_publish (exchange = exchange , routing_key = routing_key , body = body , properties = properties )
24+
25+ def consume (self , queue , on_message_callback , auto_ack = False ):
26+ def callback (ch , method , properties , body ):
27+ with rabbitmq_trace_context (ch , properties ):
28+ on_message_callback (ch , method , properties , body )
29+ self .channel .basic_consume (queue = queue , on_message_callback = callback , auto_ack = auto_ack )
30+ self .channel .start_consuming ()
0 commit comments