Draw: readd --extras, arbitrary resolutions
This commit is contained in:
@@ -3,7 +3,7 @@ import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pprint import pformat
|
||||
from typing import Awaitable, Callable, Collection, Optional, Tuple, cast
|
||||
from typing import Awaitable, Callable, Collection, Dict, List, Optional, Tuple, cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import dictdiffer
|
||||
@@ -73,13 +73,15 @@ def node(
|
||||
event: EventBase, suffix: Optional[str] = None, **kwargs: object
|
||||
) -> pydot.Node:
|
||||
if "label" not in kwargs:
|
||||
label = f"{event.event_id}\n{event.sender}: {(event.type,event.state_key)}"
|
||||
label = (
|
||||
f"{event.event_id}\n{event.sender}: {(event.type,event.get_state_key())}"
|
||||
)
|
||||
if event.type == "m.room.member":
|
||||
label += f" ({event.membership.upper()})"
|
||||
if suffix:
|
||||
label += f"\n{suffix}"
|
||||
kwargs["label"] = label
|
||||
type_to_shape = {} # {"m.room.member": "oval"}
|
||||
type_to_shape: Dict[str, str] = {} # {"m.room.member": "oval"}
|
||||
if event.type in type_to_shape:
|
||||
kwargs.setdefault("shape", type_to_shape[event.type])
|
||||
|
||||
@@ -97,9 +99,10 @@ def edge(source: EventBase, target: EventBase, **kwargs: object) -> pydot.Edge:
|
||||
|
||||
async def dump_mainlines(
|
||||
hs: MockHomeserver,
|
||||
starting_event: EventBase,
|
||||
resolve_point: Optional[EventBase],
|
||||
events: Collection[EventBase],
|
||||
extras: Collection[str],
|
||||
watch_func: Optional[Callable[[EventBase], Awaitable[str]]] = None,
|
||||
extras: Collection[EventBase] = (),
|
||||
) -> None:
|
||||
"""Visualise the auth DAG above a given `starting_event`.
|
||||
|
||||
@@ -123,21 +126,29 @@ async def dump_mainlines(
|
||||
suffix = await watch_func(event) if watch_func else None
|
||||
return node(event, suffix, **kwargs)
|
||||
|
||||
graph.add_node(await new_node(starting_event, fillcolor="#6699cc"))
|
||||
seen = {starting_event.event_id}
|
||||
seen = set()
|
||||
todo: List[EventBase] = []
|
||||
|
||||
todo = []
|
||||
for extra in extras:
|
||||
graph.add_node(await new_node(extra, fillcolor="#cc9966"))
|
||||
seen.add(extra.event_id)
|
||||
todo.append(extra)
|
||||
if resolve_point:
|
||||
graph.add_node(await new_node(resolve_point, fillcolor="#6699cc"))
|
||||
seen.add(resolve_point.event_id)
|
||||
|
||||
for pid in starting_event.prev_event_ids():
|
||||
parent = await hs.get_datastores().main.get_event(pid)
|
||||
for parent in events:
|
||||
graph.add_node(await new_node(parent, fillcolor="#6699cc"))
|
||||
seen.add(pid)
|
||||
graph.add_edge(edge(starting_event, parent, style="dashed"))
|
||||
seen.add(parent.event_id)
|
||||
todo.append(parent)
|
||||
if resolve_point:
|
||||
graph.add_edge(edge(resolve_point, parent, style="dashed"))
|
||||
|
||||
if extras:
|
||||
logger.debug(extras)
|
||||
extra_events = await hs.get_datastores().main.get_events(extras)
|
||||
logger.debug(extra_events)
|
||||
for extra_event in extra_events.values():
|
||||
if extra_event.event_id in seen:
|
||||
continue
|
||||
graph.add_node(await new_node(extra_event, fillcolor="#6699ee"))
|
||||
todo.append(extra_event)
|
||||
|
||||
async def fetch_auth_events(event: EventBase) -> StateMap[EventBase]:
|
||||
return {
|
||||
@@ -155,6 +166,8 @@ async def dump_mainlines(
|
||||
(("m.room.power_levels", ""), "solid"),
|
||||
(("m.room.join_rules", ""), "solid"),
|
||||
(("m.room.member", event.sender), "dotted"),
|
||||
# TODO: handle that state_key might be missing
|
||||
# (("m.room.member", event.state_key), "solid"),
|
||||
]:
|
||||
auth_event = auth_events.get(key)
|
||||
if auth_event:
|
||||
@@ -189,13 +202,30 @@ parser.add_argument(
|
||||
"config_file", help="Synapse config file", type=argparse.FileType("r")
|
||||
)
|
||||
parser.add_argument("--verbose", "-v", help="Log verbosely", action="store_true")
|
||||
parser.add_argument("-d", "--draw", help="Render auth DAG", action="store_true")
|
||||
parser.add_argument(
|
||||
"--debug", "-d", help="Enter debugger after state is resolved", action="store_true"
|
||||
"event_ids",
|
||||
help="""\
|
||||
The event ID(s) to be resolved.\
|
||||
|
||||
If a single event is given, resolve across all of its parents to compute the state
|
||||
before the given event. If multiple events are given, resolve across them directly.
|
||||
""",
|
||||
nargs="+",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--extra",
|
||||
dest="extras",
|
||||
help=(
|
||||
"An extra event to include in the auth DAG when using the `--draw` flag. "
|
||||
"Can be provided multiple times."
|
||||
),
|
||||
action="append",
|
||||
)
|
||||
parser.add_argument("event_id", help="The event ID to be resolved")
|
||||
parser.add_argument(
|
||||
"--watch",
|
||||
help="Track a piece of state in the auth DAG",
|
||||
help="Track a piece of state in the auth DAG when using the `--draw` flag.",
|
||||
default=None,
|
||||
nargs=2,
|
||||
metavar=("TYPE", "STATE_KEY"),
|
||||
@@ -213,19 +243,22 @@ async def debug_specific_stateres(
|
||||
- the recomputed and stored state, written to stdout, and
|
||||
- their difference, written to stdout.
|
||||
"""
|
||||
# Fetch the event in question.
|
||||
event = await hs.get_datastores().main.get_event(args.event_id)
|
||||
assert event is not None
|
||||
logger.info(
|
||||
"event %s has %d parents, %s",
|
||||
event.event_id,
|
||||
len(event.prev_event_ids()),
|
||||
event.prev_event_ids(),
|
||||
)
|
||||
DEBUG_AT_EVENT = len(args.event_ids) == 1
|
||||
|
||||
if DEBUG_AT_EVENT:
|
||||
resolve_point = await hs.get_datastores().main.get_event(args.event_ids[0])
|
||||
prev_event_ids = resolve_point.prev_event_ids()
|
||||
else:
|
||||
resolve_point = None
|
||||
prev_event_ids = args.event_ids
|
||||
|
||||
parent_events = (await hs.get_datastores().main.get_events(prev_event_ids)).values()
|
||||
sample_event = next(iter(parent_events))
|
||||
|
||||
logger.info("Resolving across %d parents, %s", len(prev_event_ids), prev_event_ids)
|
||||
state_after_parents = [
|
||||
await hs.get_storage_controllers().state.get_state_ids_for_event(prev_event_id)
|
||||
for prev_event_id in event.prev_event_ids()
|
||||
for prev_event_id in prev_event_ids
|
||||
]
|
||||
|
||||
if args.watch is not None:
|
||||
@@ -236,8 +269,10 @@ async def debug_specific_stateres(
|
||||
|
||||
async def watch_func(event: EventBase) -> str:
|
||||
try:
|
||||
result = await hs.get_storage_controllers().state.get_state_ids_for_event(
|
||||
event.event_id, filter
|
||||
result = (
|
||||
await hs.get_storage_controllers().state.get_state_ids_for_event(
|
||||
event.event_id, filter
|
||||
)
|
||||
)
|
||||
except RuntimeError:
|
||||
return f"\n{key_pair}: <Event unavailable :(>"
|
||||
@@ -247,37 +282,31 @@ async def debug_specific_stateres(
|
||||
else:
|
||||
watch_func = None
|
||||
|
||||
await dump_mainlines(hs, event, watch_func)
|
||||
if args.draw:
|
||||
await dump_mainlines(hs, resolve_point, parent_events, args.extras, watch_func)
|
||||
|
||||
result = await hs.get_state_resolution_handler().resolve_events_with_store(
|
||||
event.room_id,
|
||||
event.room_version.identifier,
|
||||
sample_event.room_id,
|
||||
sample_event.room_version.identifier,
|
||||
state_after_parents,
|
||||
event_map=None,
|
||||
state_res_store=StateResolutionStore(hs.get_datastores().main),
|
||||
)
|
||||
|
||||
logger.info("State resolved at %s:", event.event_id)
|
||||
logger.info("State resolved:")
|
||||
logger.info(pformat(result))
|
||||
|
||||
logger.info("Stored state at %s:", event.event_id)
|
||||
stored_state = await hs.get_storage_controllers().state.get_state_ids_for_event(
|
||||
event.event_id
|
||||
)
|
||||
logger.info(pformat(stored_state))
|
||||
|
||||
# TODO make this a like-for-like comparison.
|
||||
logger.info("Diff from stored (after event) to resolved (before event):")
|
||||
for change in dictdiffer.diff(stored_state, result):
|
||||
logger.info(pformat(change))
|
||||
|
||||
if args.debug:
|
||||
print(
|
||||
f"see `state_after_parents[i]` for 0 <= i < {len(state_after_parents)}"
|
||||
" and `result`",
|
||||
file=sys.stderr,
|
||||
if DEBUG_AT_EVENT:
|
||||
logger.info("Stored state at %s:", sample_event.event_id)
|
||||
stored_state = await hs.get_storage_controllers().state.get_state_ids_for_event(
|
||||
sample_event.event_id
|
||||
)
|
||||
breakpoint()
|
||||
logger.info(pformat(stored_state))
|
||||
|
||||
# TODO make this a like-for-like comparison.
|
||||
logger.info("Diff from stored (after event) to resolved (before event):")
|
||||
for change in dictdiffer.diff(stored_state, result):
|
||||
logger.info(pformat(change))
|
||||
|
||||
|
||||
# Entrypoint.
|
||||
@@ -288,7 +317,7 @@ if __name__ == "__main__":
|
||||
level=logging.DEBUG if args.verbose else logging.INFO,
|
||||
stream=sys.stdout,
|
||||
)
|
||||
# Suppress logs weren't not interested in.
|
||||
# Suppress logs we aren't interested in.
|
||||
logging.getLogger("synapse.util").setLevel(logging.ERROR)
|
||||
logging.getLogger("synapse.storage").setLevel(logging.ERROR)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user