Rename storage classes (#12913)
This commit is contained in:
@@ -29,7 +29,7 @@ class TestEventContext(unittest.HomeserverTestCase):
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
self.user_id = self.register_user("u1", "pass")
|
||||
self.user_tok = self.login("u1", "pass")
|
||||
@@ -87,7 +87,7 @@ class TestEventContext(unittest.HomeserverTestCase):
|
||||
def _check_serialize_deserialize(self, event, context):
|
||||
serialized = self.get_success(context.serialize(event, self.store))
|
||||
|
||||
d_context = EventContext.deserialize(self.storage, serialized)
|
||||
d_context = EventContext.deserialize(self._storage_controllers, serialized)
|
||||
|
||||
self.assertEqual(context.state_group, d_context.state_group)
|
||||
self.assertEqual(context.rejected, d_context.rejected)
|
||||
|
||||
@@ -50,7 +50,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||
hs = self.setup_test_homeserver(federation_http_client=None)
|
||||
self.handler = hs.get_federation_handler()
|
||||
self.store = hs.get_datastores().main
|
||||
self.state_storage = hs.get_storage().state
|
||||
self.state_storage_controller = hs.get_storage_controllers().state
|
||||
self._event_auth_handler = hs.get_event_auth_handler()
|
||||
return hs
|
||||
|
||||
@@ -338,7 +338,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||
# mapping from (type, state_key) -> state_event_id
|
||||
assert most_recent_prev_event_id is not None
|
||||
prev_state_map = self.get_success(
|
||||
self.state_storage.get_state_ids_for_event(most_recent_prev_event_id)
|
||||
self.state_storage_controller.get_state_ids_for_event(
|
||||
most_recent_prev_event_id
|
||||
)
|
||||
)
|
||||
# List of state event ID's
|
||||
prev_state_ids = list(prev_state_map.values())
|
||||
|
||||
@@ -70,7 +70,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||
) -> None:
|
||||
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
|
||||
main_store = self.hs.get_datastores().main
|
||||
state_storage = self.hs.get_storage().state
|
||||
state_storage_controller = self.hs.get_storage_controllers().state
|
||||
|
||||
# create the room
|
||||
user_id = self.register_user("kermit", "test")
|
||||
@@ -146,10 +146,11 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||
)
|
||||
if prev_exists_as_outlier:
|
||||
prev_event.internal_metadata.outlier = True
|
||||
persistence = self.hs.get_storage().persistence
|
||||
persistence = self.hs.get_storage_controllers().persistence
|
||||
self.get_success(
|
||||
persistence.persist_event(
|
||||
prev_event, EventContext.for_outlier(self.hs.get_storage())
|
||||
prev_event,
|
||||
EventContext.for_outlier(self.hs.get_storage_controllers()),
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -216,7 +217,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||
|
||||
# check that the state at that event is as expected
|
||||
state = self.get_success(
|
||||
state_storage.get_state_ids_for_event(pulled_event.event_id)
|
||||
state_storage_controller.get_state_ids_for_event(pulled_event.event_id)
|
||||
)
|
||||
expected_state = {
|
||||
(e.type, e.state_key): e.event_id for e in state_at_prev_event
|
||||
|
||||
@@ -37,7 +37,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.handler = self.hs.get_event_creation_handler()
|
||||
self.persist_event_storage = self.hs.get_storage().persistence
|
||||
self._persist_event_storage_controller = (
|
||||
self.hs.get_storage_controllers().persistence
|
||||
)
|
||||
|
||||
self.user_id = self.register_user("tester", "foobar")
|
||||
self.access_token = self.login("tester", "foobar")
|
||||
@@ -65,7 +67,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
self.get_success(
|
||||
self.persist_event_storage.persist_event(memberEvent, memberEventContext)
|
||||
self._persist_event_storage_controller.persist_event(
|
||||
memberEvent, memberEventContext
|
||||
)
|
||||
)
|
||||
|
||||
return memberEvent, memberEventContext
|
||||
@@ -129,7 +133,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||
self.assertNotEqual(event1.event_id, event3.event_id)
|
||||
|
||||
ret_event3, event_pos3, _ = self.get_success(
|
||||
self.persist_event_storage.persist_event(event3, context)
|
||||
self._persist_event_storage_controller.persist_event(event3, context)
|
||||
)
|
||||
|
||||
# Assert that the returned values match those from the initial event
|
||||
@@ -143,7 +147,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||
self.assertNotEqual(event1.event_id, event3.event_id)
|
||||
|
||||
events, _ = self.get_success(
|
||||
self.persist_event_storage.persist_events([(event3, context)])
|
||||
self._persist_event_storage_controller.persist_events([(event3, context)])
|
||||
)
|
||||
ret_event4 = events[0]
|
||||
|
||||
@@ -166,7 +170,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||
self.assertNotEqual(event1.event_id, event2.event_id)
|
||||
|
||||
events, _ = self.get_success(
|
||||
self.persist_event_storage.persist_events(
|
||||
self._persist_event_storage_controller.persist_events(
|
||||
[(event1, context1), (event2, context2)]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -954,7 +954,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_storage().persistence.persist_event(event, context)
|
||||
self.hs.get_storage_controllers().persistence.persist_event(event, context)
|
||||
)
|
||||
|
||||
def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
|
||||
|
||||
@@ -32,7 +32,7 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
|
||||
|
||||
self.master_store = hs.get_datastores().main
|
||||
self.slaved_store = self.worker_hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
def replicate(self):
|
||||
"""Tell the master side of replication that something has happened, and then
|
||||
|
||||
@@ -262,7 +262,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
)
|
||||
msg, msgctx = self.build_event()
|
||||
self.get_success(
|
||||
self.storage.persistence.persist_events([(j2, j2ctx), (msg, msgctx)])
|
||||
self._storage_controllers.persistence.persist_events(
|
||||
[(j2, j2ctx), (msg, msgctx)]
|
||||
)
|
||||
)
|
||||
self.replicate()
|
||||
|
||||
@@ -323,12 +325,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
|
||||
if backfill:
|
||||
self.get_success(
|
||||
self.storage.persistence.persist_events(
|
||||
self._storage_controllers.persistence.persist_events(
|
||||
[(event, context)], backfilled=True
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
||||
self.get_success(
|
||||
self._storage_controllers.persistence.persist_event(event, context)
|
||||
)
|
||||
|
||||
return event
|
||||
|
||||
|
||||
@@ -31,7 +31,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
super().prepare(reactor, clock, homeserver)
|
||||
self.room_creator = homeserver.get_room_creation_handler()
|
||||
self.persist_event_storage = self.hs.get_storage().persistence
|
||||
self.persist_event_storage_controller = (
|
||||
self.hs.get_storage_controllers().persistence
|
||||
)
|
||||
|
||||
# Create a test user
|
||||
self.ourUser = UserID.from_string(OUR_USER_ID)
|
||||
@@ -61,7 +63,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
|
||||
)
|
||||
)
|
||||
self.get_success(
|
||||
self.persist_event_storage.persist_event(memberEvent, memberEventContext)
|
||||
self.persist_event_storage_controller.persist_event(
|
||||
memberEvent, memberEventContext
|
||||
)
|
||||
)
|
||||
|
||||
# Join the second user to the second room
|
||||
@@ -76,7 +80,9 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
|
||||
)
|
||||
)
|
||||
self.get_success(
|
||||
self.persist_event_storage.persist_event(memberEvent, memberEventContext)
|
||||
self.persist_event_storage_controller.persist_event(
|
||||
memberEvent, memberEventContext
|
||||
)
|
||||
)
|
||||
|
||||
def test_return_empty_with_no_data(self):
|
||||
|
||||
@@ -2579,7 +2579,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
|
||||
other_user_tok = self.login("user", "pass")
|
||||
event_builder_factory = self.hs.get_event_builder_factory()
|
||||
event_creation_handler = self.hs.get_event_creation_handler()
|
||||
storage = self.hs.get_storage()
|
||||
storage_controllers = self.hs.get_storage_controllers()
|
||||
|
||||
# Create two rooms, one with a local user only and one with both a local
|
||||
# and remote user.
|
||||
@@ -2604,7 +2604,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
|
||||
event_creation_handler.create_new_client_event(builder)
|
||||
)
|
||||
|
||||
self.get_success(storage.persistence.persist_event(event, context))
|
||||
self.get_success(storage_controllers.persistence.persist_event(event, context))
|
||||
|
||||
# Now get rooms
|
||||
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
|
||||
|
||||
@@ -130,7 +130,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
We do this by setting a very long time between purge jobs.
|
||||
"""
|
||||
store = self.hs.get_datastores().main
|
||||
storage = self.hs.get_storage()
|
||||
storage_controllers = self.hs.get_storage_controllers()
|
||||
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
|
||||
|
||||
# Send a first event, which should be filtered out at the end of the test.
|
||||
@@ -155,7 +155,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(2, len(events), "events retrieved from database")
|
||||
filtered_events = self.get_success(
|
||||
filter_events_for_client(storage, self.user_id, events)
|
||||
filter_events_for_client(storage_controllers, self.user_id, events)
|
||||
)
|
||||
|
||||
# We should only get one event back.
|
||||
|
||||
@@ -88,7 +88,7 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.clock = clock
|
||||
self.storage = hs.get_storage()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
self.virtual_user_id, _ = self.register_appservice_user(
|
||||
"as_user_potato", self.appservice.token
|
||||
@@ -168,7 +168,9 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Fetch the state_groups
|
||||
state_group_map = self.get_success(
|
||||
self.storage.state.get_state_groups_ids(room_id, historical_event_ids)
|
||||
self._storage_controllers.state.get_state_groups_ids(
|
||||
room_id, historical_event_ids
|
||||
)
|
||||
)
|
||||
|
||||
# We expect all of the historical events to be using the same state_group
|
||||
|
||||
@@ -393,7 +393,8 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
||||
# We need to persist the events to the events and state_events
|
||||
# tables.
|
||||
persist_events_store._store_event_txn(
|
||||
txn, [(e, EventContext(self.hs.get_storage())) for e in events]
|
||||
txn,
|
||||
[(e, EventContext(self.hs.get_storage_controllers())) for e in events],
|
||||
)
|
||||
|
||||
# Actually call the function that calculates the auth chain stuff.
|
||||
|
||||
@@ -31,7 +31,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
self.state = self.hs.get_state_handler()
|
||||
self.persistence = self.hs.get_storage().persistence
|
||||
self._persistence = self.hs.get_storage_controllers().persistence
|
||||
self.store = self.hs.get_datastores().main
|
||||
|
||||
self.register_user("user", "pass")
|
||||
@@ -71,7 +71,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||
context = self.get_success(
|
||||
self.state.compute_event_context(event, state_ids_before_event=state)
|
||||
)
|
||||
self.get_success(self.persistence.persist_event(event, context))
|
||||
self.get_success(self._persistence.persist_event(event, context))
|
||||
|
||||
def assert_extremities(self, expected_extremities):
|
||||
"""Assert the current extremities for the room"""
|
||||
@@ -148,7 +148,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
self.get_success(self.persistence.persist_event(remote_event_2, context))
|
||||
self.get_success(self._persistence.persist_event(remote_event_2, context))
|
||||
|
||||
# Check that we haven't dropped the old extremity.
|
||||
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
|
||||
@@ -353,7 +353,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
self.state = self.hs.get_state_handler()
|
||||
self.persistence = self.hs.get_storage().persistence
|
||||
self._persistence = self.hs.get_storage_controllers().persistence
|
||||
self.store = self.hs.get_datastores().main
|
||||
|
||||
def test_remote_user_rooms_cache_invalidated(self):
|
||||
@@ -390,7 +390,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
context = self.get_success(self.state.compute_event_context(remote_event_1))
|
||||
self.get_success(self.persistence.persist_event(remote_event_1, context))
|
||||
self.get_success(self._persistence.persist_event(remote_event_1, context))
|
||||
|
||||
# Call `get_rooms_for_user` to add the remote user to the cache
|
||||
rooms = self.get_success(self.store.get_rooms_for_user(remote_user))
|
||||
@@ -437,7 +437,7 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
context = self.get_success(self.state.compute_event_context(remote_event_1))
|
||||
self.get_success(self.persistence.persist_event(remote_event_1, context))
|
||||
self.get_success(self._persistence.persist_event(remote_event_1, context))
|
||||
|
||||
# Call `get_users_in_room` to add the remote user to the cache
|
||||
users = self.get_success(self.store.get_users_in_room(room_id))
|
||||
|
||||
@@ -31,7 +31,7 @@ class PurgeTests(HomeserverTestCase):
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = self.hs.get_storage()
|
||||
self._storage_controllers = self.hs.get_storage_controllers()
|
||||
|
||||
def test_purge_history(self):
|
||||
"""
|
||||
@@ -51,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
|
||||
|
||||
# Purge everything before this topological token
|
||||
self.get_success(
|
||||
self.storage.purge_events.purge_history(self.room_id, token_str, True)
|
||||
self._storage_controllers.purge_events.purge_history(
|
||||
self.room_id, token_str, True
|
||||
)
|
||||
)
|
||||
|
||||
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
|
||||
@@ -79,7 +81,9 @@ class PurgeTests(HomeserverTestCase):
|
||||
|
||||
# Purge everything before this topological token
|
||||
f = self.get_failure(
|
||||
self.storage.purge_events.purge_history(self.room_id, event, True),
|
||||
self._storage_controllers.purge_events.purge_history(
|
||||
self.room_id, event, True
|
||||
),
|
||||
SynapseError,
|
||||
)
|
||||
self.assertIn("greater than forward", f.value.args[0])
|
||||
@@ -105,7 +109,9 @@ class PurgeTests(HomeserverTestCase):
|
||||
self.assertIsNotNone(create_event)
|
||||
|
||||
# Purge everything before this topological token
|
||||
self.get_success(self.storage.purge_events.purge_room(self.room_id))
|
||||
self.get_success(
|
||||
self._storage_controllers.purge_events.purge_room(self.room_id)
|
||||
)
|
||||
|
||||
# The events aren't found.
|
||||
self.store._invalidate_get_event_cache(create_event.event_id)
|
||||
|
||||
@@ -31,7 +31,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self._storage = hs.get_storage_controllers()
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
||||
@@ -71,7 +71,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||
self.event_creation_handler.create_new_client_event(builder)
|
||||
)
|
||||
|
||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
||||
self.get_success(self._storage.persistence.persist_event(event, context))
|
||||
|
||||
return event
|
||||
|
||||
@@ -93,7 +93,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||
self.event_creation_handler.create_new_client_event(builder)
|
||||
)
|
||||
|
||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
||||
self.get_success(self._storage.persistence.persist_event(event, context))
|
||||
|
||||
return event
|
||||
|
||||
@@ -114,7 +114,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||
self.event_creation_handler.create_new_client_event(builder)
|
||||
)
|
||||
|
||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
||||
self.get_success(self._storage.persistence.persist_event(event, context))
|
||||
|
||||
return event
|
||||
|
||||
@@ -268,7 +268,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
self.get_success(self.storage.persistence.persist_event(event_1, context_1))
|
||||
self.get_success(self._storage.persistence.persist_event(event_1, context_1))
|
||||
|
||||
event_2, context_2 = self.get_success(
|
||||
self.event_creation_handler.create_new_client_event(
|
||||
@@ -287,7 +287,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
)
|
||||
self.get_success(self.storage.persistence.persist_event(event_2, context_2))
|
||||
self.get_success(self._storage.persistence.persist_event(event_2, context_2))
|
||||
|
||||
# fetch one of the redactions
|
||||
fetched = self.get_success(self.store.get_event(redaction_event_id1))
|
||||
@@ -411,7 +411,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.get_success(
|
||||
self.storage.persistence.persist_event(redaction_event, context)
|
||||
self._storage.persistence.persist_event(redaction_event, context)
|
||||
)
|
||||
|
||||
# Now lets jump to the future where we have censored the redaction event
|
||||
|
||||
@@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
|
||||
# Room events need the full datastore, for persist_event() and
|
||||
# get_room_state()
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self._storage = hs.get_storage_controllers()
|
||||
self.event_factory = hs.get_event_factory()
|
||||
|
||||
self.room = RoomID.from_string("!abcde:test")
|
||||
@@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
|
||||
|
||||
def inject_room_event(self, **kwargs):
|
||||
self.get_success(
|
||||
self.storage.persistence.persist_event(
|
||||
self._storage.persistence.persist_event(
|
||||
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -99,7 +99,9 @@ class EventSearchInsertionTest(HomeserverTestCase):
|
||||
prev_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
|
||||
prev_event = self.get_success(store.get_event(prev_event_ids[0]))
|
||||
prev_state_map = self.get_success(
|
||||
self.hs.get_storage().state.get_state_ids_for_event(prev_event_ids[0])
|
||||
self.hs.get_storage_controllers().state.get_state_ids_for_event(
|
||||
prev_event_ids[0]
|
||||
)
|
||||
)
|
||||
|
||||
event_dict = {
|
||||
|
||||
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
||||
class StateStoreTestCase(HomeserverTestCase):
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage = hs.get_storage()
|
||||
self.storage = hs.get_storage_controllers()
|
||||
self.state_datastore = self.storage.state.stores.state
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
||||
@@ -179,12 +179,12 @@ class Graph:
|
||||
class StateTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.dummy_store = _DummyStore()
|
||||
storage = Mock(main=self.dummy_store, state=self.dummy_store)
|
||||
storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
|
||||
hs = Mock(
|
||||
spec_set=[
|
||||
"config",
|
||||
"get_datastores",
|
||||
"get_storage",
|
||||
"get_storage_controllers",
|
||||
"get_auth",
|
||||
"get_state_handler",
|
||||
"get_clock",
|
||||
@@ -199,7 +199,7 @@ class StateTestCase(unittest.TestCase):
|
||||
hs.get_clock.return_value = MockClock()
|
||||
hs.get_auth.return_value = Auth(hs)
|
||||
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
|
||||
hs.get_storage.return_value = storage
|
||||
hs.get_storage_controllers.return_value = storage_controllers
|
||||
|
||||
self.state = StateHandler(hs)
|
||||
self.event_id = 0
|
||||
|
||||
@@ -70,7 +70,7 @@ async def inject_event(
|
||||
"""
|
||||
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
|
||||
|
||||
persistence = hs.get_storage().persistence
|
||||
persistence = hs.get_storage_controllers().persistence
|
||||
assert persistence is not None
|
||||
|
||||
await persistence.persist_event(event, context)
|
||||
|
||||
@@ -34,7 +34,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
super(FilterEventsForServerTestCase, self).setUp()
|
||||
self.event_creation_handler = self.hs.get_event_creation_handler()
|
||||
self.event_builder_factory = self.hs.get_event_builder_factory()
|
||||
self.storage = self.hs.get_storage()
|
||||
self._storage_controllers = self.hs.get_storage_controllers()
|
||||
|
||||
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
|
||||
|
||||
@@ -60,7 +60,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
events_to_filter.append(evt)
|
||||
|
||||
filtered = self.get_success(
|
||||
filter_events_for_server(self.storage, "test_server", events_to_filter)
|
||||
filter_events_for_server(
|
||||
self._storage_controllers, "test_server", events_to_filter
|
||||
)
|
||||
)
|
||||
|
||||
# the result should be 5 redacted events, and 5 unredacted events.
|
||||
@@ -80,7 +82,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
outlier = self._inject_outlier()
|
||||
self.assertEqual(
|
||||
self.get_success(
|
||||
filter_events_for_server(self.storage, "remote_hs", [outlier])
|
||||
filter_events_for_server(
|
||||
self._storage_controllers, "remote_hs", [outlier]
|
||||
)
|
||||
),
|
||||
[outlier],
|
||||
)
|
||||
@@ -89,7 +93,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
evt = self._inject_message("@unerased:local_hs")
|
||||
|
||||
filtered = self.get_success(
|
||||
filter_events_for_server(self.storage, "remote_hs", [outlier, evt])
|
||||
filter_events_for_server(
|
||||
self._storage_controllers, "remote_hs", [outlier, evt]
|
||||
)
|
||||
)
|
||||
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
|
||||
self.assertEqual(filtered[0], outlier)
|
||||
@@ -99,7 +105,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
# ... but other servers should only be able to see the outlier (the other should
|
||||
# be redacted)
|
||||
filtered = self.get_success(
|
||||
filter_events_for_server(self.storage, "other_server", [outlier, evt])
|
||||
filter_events_for_server(
|
||||
self._storage_controllers, "other_server", [outlier, evt]
|
||||
)
|
||||
)
|
||||
self.assertEqual(filtered[0], outlier)
|
||||
self.assertEqual(filtered[1].event_id, evt.event_id)
|
||||
@@ -132,7 +140,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# ... and the filtering happens.
|
||||
filtered = self.get_success(
|
||||
filter_events_for_server(self.storage, "test_server", events_to_filter)
|
||||
filter_events_for_server(
|
||||
self._storage_controllers, "test_server", events_to_filter
|
||||
)
|
||||
)
|
||||
|
||||
for i in range(0, len(events_to_filter)):
|
||||
@@ -168,7 +178,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
event, context = self.get_success(
|
||||
self.event_creation_handler.create_new_client_event(builder)
|
||||
)
|
||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
||||
self.get_success(
|
||||
self._storage_controllers.persistence.persist_event(event, context)
|
||||
)
|
||||
return event
|
||||
|
||||
def _inject_room_member(
|
||||
@@ -194,7 +206,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
self.event_creation_handler.create_new_client_event(builder)
|
||||
)
|
||||
|
||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
||||
self.get_success(
|
||||
self._storage_controllers.persistence.persist_event(event, context)
|
||||
)
|
||||
return event
|
||||
|
||||
def _inject_message(
|
||||
@@ -216,7 +230,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
self.event_creation_handler.create_new_client_event(builder)
|
||||
)
|
||||
|
||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
||||
self.get_success(
|
||||
self._storage_controllers.persistence.persist_event(event, context)
|
||||
)
|
||||
return event
|
||||
|
||||
def _inject_outlier(self) -> EventBase:
|
||||
@@ -234,8 +250,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
|
||||
event.internal_metadata.outlier = True
|
||||
self.get_success(
|
||||
self.storage.persistence.persist_event(
|
||||
event, EventContext.for_outlier(self.storage)
|
||||
self._storage_controllers.persistence.persist_event(
|
||||
event, EventContext.for_outlier(self._storage_controllers)
|
||||
)
|
||||
)
|
||||
return event
|
||||
@@ -293,7 +309,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
|
||||
self.assertEqual(
|
||||
self.get_success(
|
||||
filter_events_for_client(
|
||||
self.hs.get_storage(), "@user:test", [invite_event, reject_event]
|
||||
self.hs.get_storage_controllers(),
|
||||
"@user:test",
|
||||
[invite_event, reject_event],
|
||||
)
|
||||
),
|
||||
[invite_event, reject_event],
|
||||
@@ -303,7 +321,9 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
|
||||
self.assertEqual(
|
||||
self.get_success(
|
||||
filter_events_for_client(
|
||||
self.hs.get_storage(), "@other:test", [invite_event, reject_event]
|
||||
self.hs.get_storage_controllers(),
|
||||
"@other:test",
|
||||
[invite_event, reject_event],
|
||||
)
|
||||
),
|
||||
[],
|
||||
|
||||
@@ -264,7 +264,7 @@ class MockClock:
|
||||
async def create_room(hs, room_id: str, creator_id: str):
|
||||
"""Creates and persist a creation event for the given room"""
|
||||
|
||||
persistence_store = hs.get_storage().persistence
|
||||
persistence_store = hs.get_storage_controllers().persistence
|
||||
store = hs.get_datastores().main
|
||||
event_builder_factory = hs.get_event_builder_factory()
|
||||
event_creation_handler = hs.get_event_creation_handler()
|
||||
|
||||
Reference in New Issue
Block a user