From 13da1dca0a2c1687fd8da2207dbf32c6c4baebf7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Sat, 22 Mar 2025 18:49:47 +0000 Subject: [PATCH] Join policy server PoC Adds the ability to enforce users to go through a web flow before being able to join a room (unless invited). Configured by specifying a "policy server" in the join rule event content: ``` { "join_rule": "invite", "re.jki.join_policy_server": "localhost:8865" } ``` The server will then return a 403 when a client tries to join, including a URL that the client can redirect the user to, which eventually returns a token (very much like an OAuth2 flow). This token then can be included when calling `/join` again and the join will be successful. --- policy_server.py | 114 ++++++++++++++++++++++++ synapse/event_auth.py | 12 +++ synapse/federation/federation_client.py | 37 ++++++++ synapse/federation/transport/client.py | 40 +++++++++ synapse/handlers/room_member.py | 48 +++++++++- synapse/rest/client/room.py | 7 ++ 6 files changed, 257 insertions(+), 1 deletion(-) create mode 100644 policy_server.py diff --git a/policy_server.py b/policy_server.py new file mode 100644 index 0000000000..db4b3894d1 --- /dev/null +++ b/policy_server.py @@ -0,0 +1,114 @@ +import secrets +import ssl +from dataclasses import dataclass, field + +from aiohttp import web +from signedjson.key import decode_signing_key_base64 +from signedjson.types import SigningKey + +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.crypto.event_signing import compute_event_signature + +routes = web.RouteTableDef() + +JOIN_FLOW_PAGE = """ + + +Accept policy and join room + + +""" + + +SIGNING_KEY = decode_signing_key_base64( + "ed25519", "p_afG2", "E+EmxfcqLYjlS20I5ZzjoYeN7oR9Qt/zitPGomU0hmA" +) + + +@dataclass +class PolicyServer: + server_name: str + signing_key: SigningKey + base_url: str + token_store: dict[str, str] = field(default_factory=dict) + + +@routes.get("/") +async def hello(request): + return web.Response(text="Hello, world") + + +@routes.post("/_matrix/federation/unstable/re.jki.join_policy/request_join") +async def request_join(request: web.Request) -> web.Response: + policy_server: PolicyServer = request.app["policy_server"] + return web.json_response({"url": policy_server.base_url + "/join_flow"}) + + +@routes.post("/_matrix/federation/unstable/re.jki.join_policy/sign_join") +async def sign_join(request: web.Request) -> web.Response: + policy_server: PolicyServer = request.app["policy_server"] + + json_body = await request.json() + if json_body["token"] not in policy_server.token_store: + return web.json_response({}, status=403) + + room_version_id = json_body["room_version"] + event_json = json_body["event"] + + room_version = KNOWN_ROOM_VERSIONS[room_version_id] + + signatures = compute_event_signature( + room_version=room_version, + event_dict=event_json, + signature_name=policy_server.server_name, + signing_key=policy_server.signing_key, + ) + + return web.json_response({"signatures": signatures[policy_server.server_name]}) + + +@routes.get("/join_flow") +async def join_flow(request: web.Request) -> web.Response: + redirect_url = request.query["redirect_url"] + return web.Response( + text=JOIN_FLOW_PAGE.format(redirect_url=redirect_url), content_type="text/html" + ) + + +@routes.get("/accept") +async def accept(request: web.Request) -> web.Response: + policy_server: PolicyServer = request.app["policy_server"] + + redirect_url = request.query["redirect_url"] + + token = secrets.token_hex(16) + policy_server.token_store[token] = "user_id" + + # TODO: Use less dodgy URL creation + if "?" in redirect_url: + redirect_url += f"&token={token}" + else: + redirect_url += f"?token={token}" + + return web.Response( + text="Done!", + status=307, + headers={"location": redirect_url}, + ) + + +context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) +context.load_cert_chain( + certfile="/home/erikj/git/synapse/demo/8080/localhost:8080.tls.crt", + keyfile="/home/erikj/git/synapse/demo/8080/localhost:8080.tls.key", +) + + +app = web.Application() +app["policy_server"] = PolicyServer( + server_name="localhost:8865", + signing_key=SIGNING_KEY, + base_url="https://localhost:8865", +) +app.add_routes(routes) +web.run_app(app, port=8865, ssl_context=context) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 5ecf493f98..aff77f38d4 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -552,11 +552,23 @@ def _is_membership_change_allowed( key = (EventTypes.JoinRules, "") join_rule_event = auth_events.get(key) + join_policy_server: Optional[str] = None if join_rule_event: join_rule = join_rule_event.content.get("join_rule", JoinRules.INVITE) + join_policy_server = join_rule_event.content.get("re.jki.join_policy_server") else: join_rule = JoinRules.INVITE + if ( + join_policy_server + and membership == Membership.JOIN + and not (caller_in_room or caller_invited) + ): + logger.info("Checking sigs") + if not event.signatures.get(join_policy_server): + raise AuthError(403, "Not signed by join policy server") + caller_invited = True + user_level = get_user_power_level(event.user_id, auth_events) target_level = get_user_power_level(target_user_id, auth_events) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 7d80ff6998..811caa49af 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1960,6 +1960,43 @@ class FederationClient(FederationBase): ip_address=ip_address, ) + async def join_policy_server_get_url( + self, policy_server: str, room_id: str, room_version: RoomVersion, user_id: str + ) -> Optional[str]: + result = await self.transport_layer.join_policy_server_get_url( + policy_server=policy_server, + room_id=room_id, + room_version=room_version, + user_id=user_id, + ) + + url = result.get("url") + if isinstance(url, str): + return url + return None + + async def join_policy_server_sign_join( + self, + policy_server: str, + room_id: str, + user_id: str, + token: str, + room_version: RoomVersion, + event: EventBase, + ) -> None: + result = await self.transport_layer.join_policy_server_sign_join( + policy_server=policy_server, + room_id=room_id, + user_id=user_id, + token=token, + room_version=room_version, + event=event, + ) + + signatures = result.get("signatures") + if signatures: + event.signatures[policy_server] = signatures + @attr.s(frozen=True, slots=True, auto_attribs=True) class TimestampToEventResponse: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 206e91ed14..68f68aee9c 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -894,6 +894,46 @@ class TransportLayerClient: ip_address=ip_address, ) + async def join_policy_server_get_url( + self, policy_server: str, room_id: str, room_version: RoomVersion, user_id: str + ) -> JsonDict: + path = _create_path( + FEDERATION_UNSTABLE_PREFIX, "/re.jki.join_policy/request_join" + ) + return await self.client.post_json( + policy_server, + path, + data={ + "room_id": room_id, + "room_version": room_version.identifier, + "user_id": user_id, + }, + ignore_backoff=True, + ) + + async def join_policy_server_sign_join( + self, + policy_server: str, + room_id: str, + user_id: str, + token: str, + room_version: RoomVersion, + event: EventBase, + ) -> JsonDict: + path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/re.jki.join_policy/sign_join") + return await self.client.post_json( + policy_server, + path, + data={ + "room_id": room_id, + "user_id": user_id, + "token": token, + "room_version": room_version.identifier, + "event": event.get_pdu_json(), + }, + ignore_backoff=True, + ) + def _create_path(federation_prefix: str, path: str, *args: str) -> str: """ diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 70cbbc352b..475b0c6808 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -398,6 +398,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent: bool = True, outlier: bool = False, origin_server_ts: Optional[int] = None, + join_policy_token: Optional[str] = None, ) -> Tuple[str, int]: """ Internal membership update function to get an existing event or create @@ -491,9 +492,49 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) context = await unpersisted_context.persist(event) prev_state_ids = await context.get_prev_state_ids( - StateFilter.from_types([(EventTypes.Member, user_id)]) + StateFilter.from_types( + [(EventTypes.Member, user_id), (EventTypes.JoinRules, "")] + ) ) + if membership == Membership.JOIN: + join_rule_id = prev_state_ids.get((EventTypes.JoinRules, "")) + if join_rule_id is not None: + join_rule_event = await self.store.get_event( + join_rule_id, allow_none=True + ) + if join_rule_event: + join_policy_server = join_rule_event.content.get( + "re.jki.join_policy_server" + ) + if isinstance(join_policy_server, str): + if join_policy_token is None: + policy_url = await self.federation_handler.federation_client.join_policy_server_get_url( + policy_server=join_policy_server, + room_id=room_id, + room_version=event.room_version, + user_id=target.to_string(), + ) + + if policy_url is not None: + raise SynapseError( + 403, + "Cannot join room", + errcode="RE_JKI_JOIN_POLICY_URL", + additional_fields={ + "re.jki.join_policy_url": policy_url + }, + ) + else: + await self.federation_handler.federation_client.join_policy_server_sign_join( + policy_server=join_policy_server, + room_id=room_id, + room_version=event.room_version, + user_id=target.to_string(), + token=join_policy_token, + event=event, + ) + prev_member_event_id = prev_state_ids.get( (EventTypes.Member, user_id), None ) @@ -584,6 +625,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): state_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, origin_server_ts: Optional[int] = None, + join_policy_token: Optional[str] = None, ) -> Tuple[str, int]: """Update a user's membership in a room. @@ -681,6 +723,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): state_event_ids=state_event_ids, depth=depth, origin_server_ts=origin_server_ts, + join_policy_token=join_policy_token, ) return result @@ -704,6 +747,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): state_event_ids: Optional[List[str]] = None, depth: Optional[int] = None, origin_server_ts: Optional[int] = None, + join_policy_token: Optional[str] = None, ) -> Tuple[str, int]: """Helper for update_membership. @@ -929,6 +973,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent=require_consent, outlier=outlier, origin_server_ts=origin_server_ts, + join_policy_token=join_policy_token, ) latest_event_ids = await self.store.get_prev_events_for_room(room_id) @@ -1188,6 +1233,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): require_consent=require_consent, outlier=outlier, origin_server_ts=origin_server_ts, + join_policy_token=join_policy_token, ) async def check_for_any_membership_in_room( diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 03e7bc0a24..9eb5304b24 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -528,6 +528,12 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): remote_room_hosts, ) + join_policy_token = parse_string( + request, "re.jki.join_policy_token", required=False + ) + + logger.info("re.jki.join_policy_token: %s", join_policy_token) + await self.room_member_handler.update_membership( requester=requester, target=requester.user, @@ -537,6 +543,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): remote_room_hosts=remote_room_hosts, content=content, third_party_signed=content.get("third_party_signed", None), + join_policy_token=join_policy_token, ) return 200, {"room_id": room_id}