Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Easy type hints in synapse.logging.opentracing #12894

Merged
merged 5 commits into from
May 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12894.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type annotations to `synapse.logging.opentracing`.
6 changes: 4 additions & 2 deletions synapse/config/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Set
from typing import Any, List, Set

from synapse.types import JsonDict
from synapse.util.check_dependencies import DependencyException, check_requirements
Expand Down Expand Up @@ -49,7 +49,9 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:

# The tracer is enabled so sanitize the config

self.opentracer_whitelist = opentracing_config.get("homeserver_whitelist", [])
self.opentracer_whitelist: List[str] = opentracing_config.get(
"homeserver_whitelist", []
)
if not isinstance(self.opentracer_whitelist, list):
raise ConfigError("Tracer homeserver_whitelist config is malformed")

Expand Down
114 changes: 65 additions & 49 deletions synapse/logging/opentracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,24 @@ def set_fates(clotho, lachesis, atropos, father="Zues", mother="Themis"):
import logging
import re
from functools import wraps
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Type
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Dict,
Generator,
Iterable,
List,
Optional,
Pattern,
Type,
TypeVar,
Union,
)

import attr
from typing_extensions import ParamSpec

from twisted.internet import defer
from twisted.web.http import Request
Expand Down Expand Up @@ -256,7 +271,7 @@ class _WrappedRustReporter(BaseReporter):
def set_process(self, *args, **kwargs):
return self._reporter.set_process(*args, **kwargs)

def report_span(self, span):
def report_span(self, span: "opentracing.Span") -> None:
try:
return self._reporter.report_span(span)
except Exception:
Expand Down Expand Up @@ -307,15 +322,19 @@ class SynapseBaggage:
Sentinel = object()


def only_if_tracing(func):
P = ParamSpec("P")
R = TypeVar("R")


def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
"""Executes the function only if we're tracing. Otherwise returns None."""

@wraps(func)
def _only_if_tracing_inner(*args, **kwargs):
def _only_if_tracing_inner(*args: P.args, **kwargs: P.kwargs) -> Optional[R]:
if opentracing:
return func(*args, **kwargs)
else:
return
return None

return _only_if_tracing_inner

Expand Down Expand Up @@ -356,17 +375,10 @@ def ensure_active_span_inner_2(*args, **kwargs):
return ensure_active_span_inner_1


@contextlib.contextmanager
def noop_context_manager(*args, **kwargs):
"""Does exactly what it says on the tin"""
# TODO: replace with contextlib.nullcontext once we drop support for Python 3.6
yield


# Setup


def init_tracer(hs: "HomeServer"):
def init_tracer(hs: "HomeServer") -> None:
"""Set the whitelists and initialise the JaegerClient tracer"""
global opentracing
if not hs.config.tracing.opentracer_enabled:
Expand Down Expand Up @@ -408,11 +420,11 @@ def init_tracer(hs: "HomeServer"):


@only_if_tracing
def set_homeserver_whitelist(homeserver_whitelist):
def set_homeserver_whitelist(homeserver_whitelist: Iterable[str]) -> None:
"""Sets the homeserver whitelist

Args:
homeserver_whitelist (Iterable[str]): regex of whitelisted homeservers
homeserver_whitelist: regexes specifying whitelisted homeservers
"""
global _homeserver_whitelist
if homeserver_whitelist:
Expand All @@ -423,15 +435,15 @@ def set_homeserver_whitelist(homeserver_whitelist):


@only_if_tracing
def whitelisted_homeserver(destination):
def whitelisted_homeserver(destination: str) -> bool:
"""Checks if a destination matches the whitelist

Args:
destination (str)
destination
"""

if _homeserver_whitelist:
return _homeserver_whitelist.match(destination)
return _homeserver_whitelist.match(destination) is not None
return False


Expand All @@ -457,11 +469,11 @@ def start_active_span(
Args:
See opentracing.tracer
Returns:
scope (Scope) or noop_context_manager
scope (Scope) or contextlib.nullcontext
"""

if opentracing is None:
return noop_context_manager() # type: ignore[unreachable]
return contextlib.nullcontext() # type: ignore[unreachable]

if tracer is None:
# use the global tracer by default
Expand Down Expand Up @@ -505,7 +517,7 @@ def start_active_span_follows_from(
tracer: override the opentracing tracer. By default the global tracer is used.
"""
if opentracing is None:
return noop_context_manager() # type: ignore[unreachable]
return contextlib.nullcontext() # type: ignore[unreachable]

references = [opentracing.follows_from(context) for context in contexts]
scope = start_active_span(
Expand All @@ -525,27 +537,27 @@ def start_active_span_follows_from(


def start_active_span_from_edu(
edu_content,
operation_name,
references: Optional[list] = None,
tags=None,
start_time=None,
ignore_active_span=False,
finish_on_close=True,
):
edu_content: Dict[str, Any],
operation_name: str,
references: Optional[List["opentracing.Reference"]] = None,
tags: Optional[Dict] = None,
start_time: Optional[float] = None,
ignore_active_span: bool = False,
finish_on_close: bool = True,
) -> "opentracing.Scope":
"""
Extracts a span context from an edu and uses it to start a new active span

Args:
edu_content (dict): and edu_content with a `context` field whose value is
edu_content: an edu_content with a `context` field whose value is
canonical json for a dict which contains opentracing information.

For the other args see opentracing.tracer
"""
references = references or []

if opentracing is None:
return noop_context_manager() # type: ignore[unreachable]
return contextlib.nullcontext() # type: ignore[unreachable]

carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
"opentracing", {}
Expand Down Expand Up @@ -578,27 +590,27 @@ def start_active_span_from_edu(

# Opentracing setters for tags, logs, etc
@only_if_tracing
def active_span():
def active_span() -> Optional["opentracing.Span"]:
"""Get the currently active span, if any"""
return opentracing.tracer.active_span


@ensure_active_span("set a tag")
def set_tag(key, value):
def set_tag(key: str, value: Union[str, bool, int, float]) -> None:
"""Sets a tag on the active span"""
assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.set_tag(key, value)


@ensure_active_span("log")
def log_kv(key_values, timestamp=None):
def log_kv(key_values: Dict[str, Any], timestamp: Optional[float] = None) -> None:
"""Log to the active span"""
assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.log_kv(key_values, timestamp)


@ensure_active_span("set the traces operation name")
def set_operation_name(operation_name):
def set_operation_name(operation_name: str) -> None:
"""Sets the operation name of the active span"""
assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.set_operation_name(operation_name)
Expand All @@ -624,7 +636,9 @@ def force_tracing(span=Sentinel) -> None:
span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1")


def is_context_forced_tracing(span_context) -> bool:
def is_context_forced_tracing(
span_context: Optional["opentracing.SpanContext"],
) -> bool:
"""Check if sampling has been force for the given span context."""
if span_context is None:
return False
Expand Down Expand Up @@ -696,13 +710,13 @@ def inject_response_headers(response_headers: Headers) -> None:


@ensure_active_span("get the active span context as a dict", ret={})
def get_active_span_text_map(destination=None):
def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
"""
Gets a span context as a dict. This can be used instead of manually
injecting a span into an empty carrier.

Args:
destination (str): the name of the remote server.
destination: the name of the remote server.

Returns:
dict: the active span's context if opentracing is enabled, otherwise empty.
Expand All @@ -721,7 +735,7 @@ def get_active_span_text_map(destination=None):


@ensure_active_span("get the span context as a string.", ret={})
def active_span_context_as_string():
def active_span_context_as_string() -> str:
"""
Returns:
The active span context encoded as a string.
Expand Down Expand Up @@ -750,21 +764,21 @@ def span_context_from_request(request: Request) -> "Optional[opentracing.SpanCon


@only_if_tracing
def span_context_from_string(carrier):
def span_context_from_string(carrier: str) -> Optional["opentracing.SpanContext"]:
"""
Returns:
The active span context decoded from a string.
"""
carrier = json_decoder.decode(carrier)
return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
payload: Dict[str, str] = json_decoder.decode(carrier)
return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, payload)


@only_if_tracing
def extract_text_map(carrier):
def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanContext"]:
"""
Wrapper method for opentracing's tracer.extract for TEXT_MAP.
Args:
carrier (dict): a dict possibly containing a span context.
carrier: a dict possibly containing a span context.

Returns:
The active span context extracted from carrier.
Expand Down Expand Up @@ -843,7 +857,7 @@ def err_back(result):
return decorator


def tag_args(func):
def tag_args(func: Callable[P, R]) -> Callable[P, R]:
"""
Tags all of the args to the active span.
"""
Expand All @@ -852,19 +866,21 @@ def tag_args(func):
return func

@wraps(func)
def _tag_args_inner(*args, **kwargs):
def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R:
argspec = inspect.getfullargspec(func)
for i, arg in enumerate(argspec.args[1:]):
set_tag("ARG_" + arg, args[i])
set_tag("args", args[len(argspec.args) :])
set_tag("ARG_" + arg, args[i]) # type: ignore[index]
set_tag("args", args[len(argspec.args) :]) # type: ignore[index]
set_tag("kwargs", kwargs)
return func(*args, **kwargs)

return _tag_args_inner


@contextlib.contextmanager
def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
def trace_servlet(
request: "SynapseRequest", extract_context: bool = False
) -> Generator[None, None, None]:
"""Returns a context manager which traces a request. It starts a span
with some servlet specific tags such as the request metrics name and
request information.
Expand Down
9 changes: 3 additions & 6 deletions synapse/metrics/background_process_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import logging
import threading
from contextlib import nullcontext
from functools import wraps
from types import TracebackType
from typing import (
Expand Down Expand Up @@ -41,11 +42,7 @@
LoggingContext,
PreserveLoggingContext,
)
from synapse.logging.opentracing import (
SynapseTags,
noop_context_manager,
start_active_span,
)
from synapse.logging.opentracing import SynapseTags, start_active_span
from synapse.metrics._types import Collector

if TYPE_CHECKING:
Expand Down Expand Up @@ -238,7 +235,7 @@ async def run() -> Optional[R]:
f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)}
)
else:
ctx = noop_context_manager()
ctx = nullcontext()
with ctx:
return await func(*args, **kwargs)
except Exception:
Expand Down