diff --git a/synapse/logging/tracing.py b/synapse/logging/tracing.py index 109ae185e1..552f5f8504 100644 --- a/synapse/logging/tracing.py +++ b/synapse/logging/tracing.py @@ -166,6 +166,7 @@ from functools import wraps from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, ContextManager, Dict, @@ -789,67 +790,108 @@ def extract_text_map( # Tracing decorators +def create_decorator( + func: Callable[P, R], + # TODO: What is the correct type for these `Any`? `P.args, P.kwargs` isn't allowed here + wrapping_logic: Callable[[Callable[P, R], Any, Any], ContextManager[None]], +) -> Callable[P, R]: + """ + Creates a decorator that is able to handle sync functions, async functions + (coroutines), and inlineDeferred from Twisted. + + Example usage: + ```py + # Decorator to time the functiona and log it out + def duration(func: Callable[P, R]) -> Callable[P, R]: + @contextlib.contextmanager + def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs): + start_ts = time.time() + yield + end_ts = time.time() + duration = end_ts - start_ts + logger.info("%s took %s seconds", func.__name__, duration) + + return create_decorator(func, _wrapping_logic) + ``` + + Args: + func: The function to be decorated + wrapping_logic: The business logic of your custom decorator. + This should be a ContextManager so you are able to run your logic + before/after the function as desired. + """ + + @wraps(func) + async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + if inspect.iscoroutinefunction(func): + with wrapping_logic(func, *args, **kwargs): + return await func(*args, **kwargs) + else: + # The other case here handles both sync functions and those + # decorated with inlineDeferred. + scope = wrapping_logic(func, *args, **kwargs) + scope.__enter__() + + try: + result = func(*args, **kwargs) + if isinstance(result, defer.Deferred): + + def call_back(result: R) -> R: + scope.__exit__(None, None, None) + return result + + def err_back(result: R) -> R: + scope.__exit__(None, None, None) + return result + + result.addCallbacks(call_back, err_back) + + else: + if inspect.isawaitable(result): + logger.error( + "@trace may not have wrapped %s correctly! " + "The function is not async but returned a %s.", + func.__qualname__, + type(result).__name__, + ) + + scope.__exit__(None, None, None) + + return result + + except Exception as e: + scope.__exit__(type(e), None, e.__traceback__) + raise + + return _wrapper # type: ignore[return-value] + + def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]: """ Decorator to trace a function with a custom opname. See the module's doc string for usage examples. - """ - def decorator(func: Callable[P, R]) -> Callable[P, R]: + @contextlib.contextmanager + def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs): if opentelemetry is None: - return func # type: ignore[unreachable] + return None - if inspect.iscoroutinefunction(func): + scope = start_active_span(opname) + scope.__enter__() + try: + yield + except Exception as e: + scope.__exit__(type(e), None, e.__traceback__) + raise + finally: + scope.__exit__(None, None, None) - @wraps(func) - async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R: - with start_active_span(opname): - return await func(*args, **kwargs) # type: ignore[misc] + def _decorator(func: Callable[P, R]): + return create_decorator(func, _wrapping_logic) - else: - # The other case here handles both sync functions and those - # decorated with inlineDeferred. - @wraps(func) - def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R: - scope = start_active_span(opname) - scope.__enter__() - - try: - result = func(*args, **kwargs) - if isinstance(result, defer.Deferred): - - def call_back(result: R) -> R: - scope.__exit__(None, None, None) - return result - - def err_back(result: R) -> R: - scope.__exit__(None, None, None) - return result - - result.addCallbacks(call_back, err_back) - - else: - if inspect.isawaitable(result): - logger.error( - "@trace may not have wrapped %s correctly! " - "The function is not async but returned a %s.", - func.__qualname__, - type(result).__name__, - ) - - scope.__exit__(None, None, None) - - return result - - except Exception as e: - scope.__exit__(type(e), None, e.__traceback__) - raise - - return _trace_inner # type: ignore[return-value] - - return decorator + return _decorator def trace(func: Callable[P, R]) -> Callable[P, R]: @@ -866,22 +908,22 @@ def trace(func: Callable[P, R]) -> Callable[P, R]: def tag_args(func: Callable[P, R]) -> Callable[P, R]: """ - Tags all of the args to the active span. + Decorator to tag all of the args to the active span. """ if not opentelemetry: return func - @wraps(func) - def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R: + @contextlib.contextmanager + def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs): argspec = inspect.getfullargspec(func) for i, arg in enumerate(argspec.args[1:]): set_attribute("ARG_" + arg, str(args[i + 1])) # type: ignore[index] set_attribute("args", str(args[len(argspec.args) :])) # type: ignore[index] set_attribute("kwargs", str(kwargs)) - return func(*args, **kwargs) + yield - return _tag_args_inner + return create_decorator(func, _wrapping_logic) @contextlib.contextmanager diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index eec55b6478..22e72a31de 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -33,6 +33,7 @@ from synapse.api.constants import MAX_DEPTH, EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.events import EventBase, make_event_from_dict +from synapse.logging.tracing import tag_args, trace from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -709,6 +710,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # Return all events where not all sets can reach them. return {eid for eid, n in event_to_missing_sets.items() if n} + @trace + @tag_args async def get_oldest_event_ids_with_depth_in_room( self, room_id: str ) -> List[Tuple[str, int]]: