1
0

Leave the room if the client calls /forget

This commit is contained in:
Andrew Morgan
2022-08-10 19:20:31 +01:00
parent 51c01d450a
commit 36a6a93df7
4 changed files with 15 additions and 8 deletions

View File

@@ -1812,7 +1812,7 @@ class RoomShutdownHandler:
stream_id,
)
await self.room_member_handler.forget(target_requester.user, room_id)
await self.room_member_handler.forget(target_requester, room_id)
# Join users to new room
if new_room_user_id:

View File

@@ -261,7 +261,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
async def forget(self, user: UserID, room_id: str) -> None:
async def forget(self, requester: Requester, room_id: str) -> None:
raise NotImplementedError()
async def ratelimit_multiple_invites(
@@ -1909,19 +1909,25 @@ class RoomMemberMasterHandler(RoomMemberHandler):
"""Implements RoomMemberHandler._user_left_room"""
user_left_room(self.distributor, target, room_id)
async def forget(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
async def forget(self, requester: Requester, room_id: str) -> None:
user_id = requester.user.to_string()
member = await self._storage_controllers.state.get_current_state_event(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None
if membership is not None and membership not in [
if membership and membership not in [
Membership.LEAVE,
Membership.BAN,
]:
raise SynapseError(400, "User %s in room %s" % (user_id, room_id))
# Have the user leave the room.
await self.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
action=Membership.LEAVE,
)
if membership:
await self.store.forget(user_id, room_id)

View File

@@ -137,5 +137,5 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
user_id=target.to_string(), room_id=room_id, change="left"
)
async def forget(self, target: UserID, room_id: str) -> None:
async def forget(self, target: Requester, room_id: str) -> None:
raise RuntimeError("Cannot forget rooms on workers.")

View File

@@ -799,6 +799,7 @@ class RoomForgetRestServlet(TransactionRestServlet):
super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
@@ -809,7 +810,7 @@ class RoomForgetRestServlet(TransactionRestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
await self.room_member_handler.forget(user=requester.user, room_id=room_id)
await self.room_member_handler.forget(requester=requester, room_id=room_id)
return 200, {}