1
0

Make StateFilter frozen

This commit is contained in:
Olivier Wilkinson (reivilibre)
2021-08-04 15:01:17 +01:00
parent 15db8b7c7f
commit 0c1f2156c8
2 changed files with 43 additions and 27 deletions

View File

@@ -25,6 +25,7 @@ from typing import (
)
import attr
from frozendict import frozendict
from synapse.api.constants import EventTypes
from synapse.events import EventBase
@@ -40,7 +41,7 @@ logger = logging.getLogger(__name__)
T = TypeVar("T")
@attr.s(slots=True)
@attr.s(slots=True, frozen=True)
class StateFilter:
"""A filter used when querying for state.
@@ -53,14 +54,16 @@ class StateFilter:
appear in `types`.
"""
types = attr.ib(type=Dict[str, Optional[Set[str]]])
types = attr.ib(type=frozendict[str, Optional[Set[str]]])
include_others = attr.ib(default=False, type=bool)
def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
self.types = {k: v for k, v in self.types.items() if v is not None}
self.types = frozendict(
{k: v for k, v in self.types.items() if v is not None}
)
@staticmethod
def all() -> "StateFilter":
@@ -69,7 +72,7 @@ class StateFilter:
Returns:
The new state filter.
"""
return StateFilter(types={}, include_others=True)
return StateFilter(types=frozendict(), include_others=True)
@staticmethod
def none() -> "StateFilter":
@@ -78,7 +81,7 @@ class StateFilter:
Returns:
The new state filter.
"""
return StateFilter(types={}, include_others=False)
return StateFilter(types=frozendict(), include_others=False)
@staticmethod
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
@@ -103,7 +106,7 @@ class StateFilter:
type_dict.setdefault(typ, set()).add(s) # type: ignore
return StateFilter(types=type_dict)
return StateFilter(types=frozendict(type_dict))
@staticmethod
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
@@ -116,7 +119,9 @@ class StateFilter:
Returns:
The new state filter
"""
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
return StateFilter(
types=frozendict({EventTypes.Member: set(members)}), include_others=True
)
def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed
@@ -173,7 +178,7 @@ class StateFilter:
# We want to return all non-members, but only particular
# memberships
return StateFilter(
types={EventTypes.Member: self.types[EventTypes.Member]},
types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
include_others=True,
)
@@ -324,14 +329,16 @@ class StateFilter:
if state_keys is None:
member_filter = StateFilter.all()
else:
member_filter = StateFilter({EventTypes.Member: state_keys})
member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
elif self.include_others:
member_filter = StateFilter.all()
else:
member_filter = StateFilter.none()
non_member_filter = StateFilter(
types={k: v for k, v in self.types.items() if k != EventTypes.Member},
types=frozendict(
{k: v for k, v in self.types.items() if k != EventTypes.Member}
),
include_others=self.include_others,
)

View File

@@ -14,6 +14,8 @@
import logging
from frozendict import frozendict
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.storage.state import StateFilter
@@ -183,7 +185,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}},
types=frozendict({EventTypes.Member: {self.u_alice.to_string()}}),
include_others=True,
),
)
@@ -203,7 +205,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
types=frozendict({EventTypes.Member: set()}), include_others=True
),
)
)
@@ -228,7 +230,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
types=frozendict({EventTypes.Member: set()}), include_others=True
),
)
@@ -245,7 +247,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
types=frozendict({EventTypes.Member: set()}), include_others=True
),
)
@@ -258,7 +260,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
types=frozendict({EventTypes.Member: None}), include_others=True
),
)
@@ -275,7 +277,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
types=frozendict({EventTypes.Member: None}), include_others=True
),
)
@@ -295,7 +297,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
types=frozendict({EventTypes.Member: {e5.state_key}}),
include_others=True,
),
)
@@ -312,7 +315,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
types=frozendict({EventTypes.Member: {e5.state_key}}),
include_others=True,
),
)
@@ -325,7 +329,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
types=frozendict({EventTypes.Member: {e5.state_key}}),
include_others=False,
),
)
@@ -375,7 +380,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
types=frozendict({EventTypes.Member: set()}), include_others=True
),
)
@@ -387,7 +392,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
types=frozendict({EventTypes.Member: set()}), include_others=True
),
)
@@ -400,7 +405,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
types=frozendict({EventTypes.Member: None}), include_others=True
),
)
@@ -411,7 +416,7 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: None}, include_others=True
types=frozendict({EventTypes.Member: None}), include_others=True
),
)
@@ -430,7 +435,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
types=frozendict({EventTypes.Member: {e5.state_key}}),
include_others=True,
),
)
@@ -441,7 +447,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=True
types=frozendict({EventTypes.Member: {e5.state_key}}),
include_others=True,
),
)
@@ -454,7 +461,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
types=frozendict({EventTypes.Member: {e5.state_key}}),
include_others=False,
),
)
@@ -465,7 +473,8 @@ class StateStoreTestCase(HomeserverTestCase):
self.state_datastore._state_group_members_cache,
group,
state_filter=StateFilter(
types={EventTypes.Member: {e5.state_key}}, include_others=False
types=frozendict({EventTypes.Member: {e5.state_key}}),
include_others=False,
),
)