1
0

Always try to invite the user when creating a server notice room

This commit is contained in:
Brendan Abolivier
2022-05-11 11:29:48 +01:00
parent 84facf769e
commit 99496fcf3d
2 changed files with 47 additions and 7 deletions
@@ -66,7 +66,6 @@ class ServerNoticesManager:
txn_id: The transaction ID.
"""
room_id = await self.get_or_create_notice_room_for_user(user_id)
await self.maybe_invite_user_to_room(user_id, room_id)
assert self.server_notices_mxid is not None
requester = create_requester(
@@ -90,13 +89,33 @@ class ServerNoticesManager:
)
return event
@cached()
async def get_or_create_notice_room_for_user(self, user_id: str) -> str:
"""Get the room for notices for a given user
If we have not yet created a notice room for this user, create it, but don't
invite the user to it.
Also checks if the user needs to be invited into the room, and invites them if
necessary.
Args:
user_id: complete user id for the user we want a room for
Returns:
room id of notice room.
"""
room_id = await self._get_or_create_notice_room_for_user(user_id)
await self.maybe_invite_user_to_room(user_id, room_id)
return room_id
@cached()
async def _get_or_create_notice_room_for_user(self, user_id: str) -> str:
"""Get the room for notices for a given user
If we have not yet created a notice room for this user, create it, but don't
invite the user to it.
Args:
user_id: complete user id for the user we want a room for
+26 -5
View File
@@ -214,7 +214,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
self.assertEqual(messages[0]["sender"], "@notices:test")
# invalidate cache of server notices room_ids
self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
self.server_notices_manager._get_or_create_notice_room_for_user.invalidate_all()
# send second message
channel = self.make_request(
@@ -289,7 +289,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
# invalidate cache of server notices room_ids
# if server tries to send to a cached room_id the user gets the message
# in old room
self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
self.server_notices_manager._get_or_create_notice_room_for_user.invalidate_all()
# send second message
channel = self.make_request(
@@ -376,7 +376,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
# invalidate cache of server notices room_ids
# if server tries to send to a cached room_id it gives an error
self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
self.server_notices_manager._get_or_create_notice_room_for_user.invalidate_all()
# send second message
channel = self.make_request(
@@ -432,7 +432,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
self.server_notices_manager._config.servernotices.server_notices_mxid_display_name = (
new_display_name
)
self.server_notices_manager.get_or_create_notice_room_for_user.cache.invalidate_all()
self.server_notices_manager._get_or_create_notice_room_for_user.invalidate_all()
self.make_request(
"POST",
@@ -478,7 +478,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
self.server_notices_manager._config.servernotices.server_notices_mxid_avatar_url = (
new_avatar_url
)
self.server_notices_manager.get_or_create_notice_room_for_user.cache.invalidate_all()
self.server_notices_manager._get_or_create_notice_room_for_user.invalidate_all()
self.make_request(
"POST",
@@ -501,6 +501,27 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(notice_user_state["avatar_url"], new_avatar_url)
@override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_always_invite(self) -> None:
"""Tests that calling get_or_create_notice_room_for_user always end up inviting
the user if necessary.
"""
room_id = self.get_success(
self.server_notices_manager.get_or_create_notice_room_for_user(
self.other_user,
)
)
invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
invited = False
for invited_room in invited_rooms:
if invited_room.room_id == room_id:
invited = True
break
self.assertTrue(invited)
def _check_invite_and_join_status(
self, user_id: str, expected_invites: int, expected_memberships: int
) -> List[RoomsForUser]: