Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 64 additions & 38 deletions logfire/_internal/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import inspect
import warnings
from collections.abc import Iterable, Sequence
from contextlib import AbstractContextManager, asynccontextmanager, contextmanager
from contextlib import AbstractContextManager, asynccontextmanager, contextmanager, nullcontext
from typing import TYPE_CHECKING, Any, Callable, TypeVar

from opentelemetry import trace
from opentelemetry.util import types as otel_types
from typing_extensions import LiteralString, ParamSpec
from typing_extensions import Concatenate, LiteralString, ParamSpec

from .constants import ATTRIBUTES_MESSAGE_TEMPLATE_KEY, ATTRIBUTES_TAGS_KEY
from .stack_info import get_filepath_attribute
Expand Down Expand Up @@ -50,7 +51,10 @@ def instrument(
extract_args: bool | Iterable[str],
record_return: bool,
allow_generator: bool,
new_context: bool,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
from logfire.propagate import attach_context

from .main import set_user_attributes_on_raw_span

def decorator(func: Callable[P, R]) -> Callable[P, R]:
Expand All @@ -60,6 +64,13 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]:
stacklevel=2,
)

if new_context:
context_manager = attach_context({} if new_context is True else new_context())
link_to_current = True
else:
link_to_current = False
context_manager = nullcontext()

attributes = get_attributes(func, msg_template, tags)
open_span = get_open_span(logfire, attributes, span_name, extract_args, func)

Expand All @@ -68,44 +79,52 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]:
warnings.warn(GENERATOR_WARNING_MESSAGE, stacklevel=2)

def wrapper(*func_args: P.args, **func_kwargs: P.kwargs): # type: ignore
with open_span(*func_args, **func_kwargs):
yield from func(*func_args, **func_kwargs)
prev_context = trace.get_current_span().get_span_context() if link_to_current else None
with context_manager:
with open_span(prev_context, *func_args, **func_kwargs):
yield from func(*func_args, **func_kwargs)
elif inspect.isasyncgenfunction(func):
if not allow_generator:
warnings.warn(GENERATOR_WARNING_MESSAGE, stacklevel=2)

async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs): # type: ignore
with open_span(*func_args, **func_kwargs):
# `yield from` is invalid syntax in an async function.
# This loop is not quite equivalent, because `yield from` also handles things like
# sending values to the subgenerator.
# Fixing this would at least mean porting https://peps.python.org/pep-0380/#formal-semantics
# which is quite messy, and it's not clear if that would be correct based on
# https://discuss.python.org/t/yield-from-in-async-functions/47050.
# So instead we have an extra warning in the docs about this.
async for x in func(*func_args, **func_kwargs):
yield x
prev_context = trace.get_current_span().get_span_context() if link_to_current else None
with context_manager:
with open_span(prev_context, *func_args, **func_kwargs):
# `yield from` is invalid syntax in an async function.
# This loop is not quite equivalent, because `yield from` also handles things like
# sending values to the subgenerator.
# Fixing this would at least mean porting https://peps.python.org/pep-0380/#formal-semantics
# which is quite messy, and it's not clear if that would be correct based on
# https://discuss.python.org/t/yield-from-in-async-functions/47050.
# So instead we have an extra warning in the docs about this.
async for x in func(*func_args, **func_kwargs):
yield x

elif inspect.iscoroutinefunction(func):

async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: # type: ignore
with open_span(*func_args, **func_kwargs) as span:
result = await func(*func_args, **func_kwargs)
if record_return:
# open_span returns a FastLogfireSpan, so we can't use span.set_attribute for complex types.
# This isn't great because it has to parse the JSON schema.
# Not sure if making get_open_span return a LogfireSpan when record_return is True
# would be faster overall or if it would be worth the added complexity.
set_user_attributes_on_raw_span(span._span, {'return': result})
return result
prev_context = trace.get_current_span().get_span_context() if link_to_current else None
with context_manager:
with open_span(prev_context, *func_args, **func_kwargs) as span:
result = await func(*func_args, **func_kwargs)
if record_return:
# open_span returns a FastLogfireSpan, so we can't use span.set_attribute for complex types.
# This isn't great because it has to parse the JSON schema.
# Not sure if making get_open_span return a LogfireSpan when record_return is True
# would be faster overall or if it would be worth the added complexity.
set_user_attributes_on_raw_span(span._span, {'return': result})
return result
else:
# Same as the above, but without the async/await
def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R:
with open_span(*func_args, **func_kwargs) as span:
result = func(*func_args, **func_kwargs)
if record_return:
set_user_attributes_on_raw_span(span._span, {'return': result})
return result
prev_context = trace.get_current_span().get_span_context() if link_to_current else None
with context_manager:
with open_span(prev_context, *func_args, **func_kwargs) as span:
result = func(*func_args, **func_kwargs)
if record_return:
set_user_attributes_on_raw_span(span._span, {'return': result})
return result

wrapper = functools.wraps(func)(wrapper) # type: ignore
return wrapper
Expand All @@ -119,28 +138,32 @@ def get_open_span(
span_name: str | None,
extract_args: bool | Iterable[str],
func: Callable[P, R],
) -> Callable[P, AbstractContextManager[Any]]:
) -> Callable[Concatenate[trace.SpanContext | None, P], AbstractContextManager[Any]]:
final_span_name: str = span_name or attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] # type: ignore

# This is the fast case for when there are no arguments to extract
def open_span(*_: P.args, **__: P.kwargs): # type: ignore
return logfire._fast_span(final_span_name, attributes) # type: ignore
def open_span(span_context: trace.SpanContext | None, *_: P.args, **__: P.kwargs): # type: ignore
span = logfire._fast_span(final_span_name, attributes) # type: ignore
if span_context is not None:
span._span.add_link(span_context) # pyright: ignore[reportPrivateUsage]
return span

if extract_args is True:
sig = inspect.signature(func)
if sig.parameters: # only extract args if there are any

def open_span(*func_args: P.args, **func_kwargs: P.kwargs):
def open_span(span_context: trace.SpanContext | None, *func_args: P.args, **func_kwargs: P.kwargs):
bound = sig.bind(*func_args, **func_kwargs)
bound.apply_defaults()
args_dict = bound.arguments
return logfire._instrument_span_with_args( # type: ignore
span = logfire._instrument_span_with_args( # type: ignore
final_span_name, attributes, args_dict
)
if span_context is not None:
span._span.add_link(span_context) # pyright: ignore[reportPrivateUsage]
return span

return open_span

if extract_args: # i.e. extract_args should be an iterable of argument names
elif extract_args: # i.e. extract_args should be an iterable of argument names
sig = inspect.signature(func)

if isinstance(extract_args, str):
Expand All @@ -157,17 +180,20 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs):

if extract_args_final: # check that there are still arguments to extract

def open_span(*func_args: P.args, **func_kwargs: P.kwargs):
def open_span(span_context: trace.SpanContext | None, *func_args: P.args, **func_kwargs: P.kwargs):
bound = sig.bind(*func_args, **func_kwargs)
bound.apply_defaults()
args_dict = bound.arguments

# This line is the only difference from the extract_args=True case
args_dict = {k: args_dict[k] for k in extract_args_final}

return logfire._instrument_span_with_args( # type: ignore
span = logfire._instrument_span_with_args( # type: ignore
final_span_name, attributes, args_dict
)
if span_context is not None:
span._span.add_link(span_context) # pyright: ignore[reportPrivateUsage]
return span

return open_span

Expand Down
6 changes: 5 additions & 1 deletion logfire/_internal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ def instrument(
extract_args: bool | Iterable[str] = True,
record_return: bool = False,
allow_generator: bool = False,
new_context: bool = False,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Decorator for instrumenting a function as a span.

Expand All @@ -603,6 +604,7 @@ def my_function(a: int):
Ignored for generators.
allow_generator: Set to `True` to prevent a warning when instrumenting a generator function.
Read https://logfire.pydantic.dev/docs/guides/advanced/generators/#using-logfireinstrument first.
new_context: Set to `True` to clear context before starting instrumentation, and link back to the previous span.
"""

@overload
Expand All @@ -629,6 +631,7 @@ def instrument( # type: ignore[reportInconsistentOverload]
extract_args: bool | Iterable[str] = True,
record_return: bool = False,
allow_generator: bool = False,
new_context: bool = False,
) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R]:
"""Decorator for instrumenting a function as a span.

Expand All @@ -652,11 +655,12 @@ def my_function(a: int):
Ignored for generators.
allow_generator: Set to `True` to prevent a warning when instrumenting a generator function.
Read https://logfire.pydantic.dev/docs/guides/advanced/generators/#using-logfireinstrument first.
new_context: Set to `True` to clear context before starting instrumentation, and link back to the previous span.
"""
if callable(msg_template):
return self.instrument()(msg_template)
return instrument(
self, tuple(self._tags), msg_template, span_name, extract_args, record_return, allow_generator
self, tuple(self._tags), msg_template, span_name, extract_args, record_return, allow_generator, new_context
)

def log(
Expand Down
Loading
Loading