Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix @trace not wrapping some state methods that return coroutines correctly #15647

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions synapse/storage/controllers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ async def get_state_groups(

@trace
@tag_args
def _get_state_groups_from_groups(
async def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
) -> Awaitable[Dict[int, StateMap[str]]]:
MadLittleMods marked this conversation as resolved.
Show resolved Hide resolved
Comment on lines 175 to -180
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we somehow type this to catch it in CI before we see it in the logs?

  • def trace(func: Callable[P, R]) -> Callable[P, R]:
    """
    Decorator to trace a function.
    Sets the operation name to that of the function's name.
    See the module's doc string for usage examples.
    """
    return trace_with_opname(func.__name__)(func)
  • def _custom_sync_async_decorator(
    func: Callable[P, R],
    wrapping_logic: Callable[Concatenate[Callable[P, R], P], ContextManager[None]],
    ) -> Callable[P, R]:

Copy link
Contributor

@squahtx squahtx May 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The @trace implementation could probably be fixed up to handle this case.

if inspect.iscoroutine(result):
    def await_coroutine():
        try:
            await result
        finally:
            scope.__exit__(None, None, None)
    # the original method returned a coroutine,
    # so we create another coroutine wrapping it, that calls __exit__.
    return await_coroutine()

(wholly untested)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@squahtx Good shout! It looks plausible and I can introduce it in another PR ⏩

I know this PR wouldn't be necessary if that worked but this is the quick and easy fix and it aligns with the rest of what we're doing in synapse/storage/controllers/state.py (no more Awaitable)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tackling this in #15650

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, I'm in favour of landing this PR.

) -> Dict[int, StateMap[str]]:
"""Returns the state groups for a given set of groups, filtering on
types of state events.

Expand All @@ -190,7 +190,9 @@ def _get_state_groups_from_groups(
Dict of state group to state map.
"""

return self.stores.state._get_state_groups_from_groups(groups, state_filter)
return await self.stores.state._get_state_groups_from_groups(
groups, state_filter
)

@trace
@tag_args
Expand Down Expand Up @@ -349,9 +351,9 @@ async def get_state_ids_for_event(

@trace
@tag_args
def get_state_for_groups(
async def get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
) -> Dict[int, MutableStateMap[str]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key

Expand All @@ -363,7 +365,7 @@ def get_state_for_groups(
Returns:
Dict of state group to state map.
"""
return self.stores.state._get_state_for_groups(
return await self.stores.state._get_state_for_groups(
groups, state_filter or StateFilter.all()
)

Expand Down