Merge commit '3f58fc848' into anoa/dinsic_release_1_31_0
This commit is contained in:
1
changelog.d/9150.feature
Normal file
1
changelog.d/9150.feature
Normal file
@@ -0,0 +1 @@
|
||||
New API /_synapse/admin/rooms/{roomId}/context/{eventId}.
|
||||
1
changelog.d/9299.misc
Normal file
1
changelog.d/9299.misc
Normal file
@@ -0,0 +1 @@
|
||||
Update the `Cursor` type hints to better match PEP 249.
|
||||
1
changelog.d/9321.bugfix
Normal file
1
changelog.d/9321.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Assert a maximum length for the `client_secret` parameter for spec compliance.
|
||||
1
changelog.d/9333.bugfix
Normal file
1
changelog.d/9333.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix additional errors when previewing URLs: "AttributeError 'NoneType' object has no attribute 'xpath'" and "ValueError: Unicode strings with encoding declaration are not supported. Please use bytes input or XML fragments without declaration.".
|
||||
@@ -10,6 +10,7 @@
|
||||
* [Undoing room shutdowns](#undoing-room-shutdowns)
|
||||
- [Make Room Admin API](#make-room-admin-api)
|
||||
- [Forward Extremities Admin API](#forward-extremities-admin-api)
|
||||
- [Event Context API](#event-context-api)
|
||||
|
||||
# List Room API
|
||||
|
||||
@@ -594,3 +595,121 @@ that were deleted.
|
||||
"deleted": 1
|
||||
}
|
||||
```
|
||||
|
||||
# Event Context API
|
||||
|
||||
This API lets a client find the context of an event. This is designed primarily to investigate abuse reports.
|
||||
|
||||
```
|
||||
GET /_synapse/admin/v1/rooms/<room_id>/context/<event_id>
|
||||
```
|
||||
|
||||
This API mimmicks [GET /_matrix/client/r0/rooms/{roomId}/context/{eventId}](https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-rooms-roomid-context-eventid). Please refer to the link for all details on parameters and reseponse.
|
||||
|
||||
Example response:
|
||||
|
||||
```json
|
||||
{
|
||||
"end": "t29-57_2_0_2",
|
||||
"events_after": [
|
||||
{
|
||||
"content": {
|
||||
"body": "This is an example text message",
|
||||
"msgtype": "m.text",
|
||||
"format": "org.matrix.custom.html",
|
||||
"formatted_body": "<b>This is an example text message</b>"
|
||||
},
|
||||
"type": "m.room.message",
|
||||
"event_id": "$143273582443PhrSn:example.org",
|
||||
"room_id": "!636q39766251:example.com",
|
||||
"sender": "@example:example.org",
|
||||
"origin_server_ts": 1432735824653,
|
||||
"unsigned": {
|
||||
"age": 1234
|
||||
}
|
||||
}
|
||||
],
|
||||
"event": {
|
||||
"content": {
|
||||
"body": "filename.jpg",
|
||||
"info": {
|
||||
"h": 398,
|
||||
"w": 394,
|
||||
"mimetype": "image/jpeg",
|
||||
"size": 31037
|
||||
},
|
||||
"url": "mxc://example.org/JWEIFJgwEIhweiWJE",
|
||||
"msgtype": "m.image"
|
||||
},
|
||||
"type": "m.room.message",
|
||||
"event_id": "$f3h4d129462ha:example.com",
|
||||
"room_id": "!636q39766251:example.com",
|
||||
"sender": "@example:example.org",
|
||||
"origin_server_ts": 1432735824653,
|
||||
"unsigned": {
|
||||
"age": 1234
|
||||
}
|
||||
},
|
||||
"events_before": [
|
||||
{
|
||||
"content": {
|
||||
"body": "something-important.doc",
|
||||
"filename": "something-important.doc",
|
||||
"info": {
|
||||
"mimetype": "application/msword",
|
||||
"size": 46144
|
||||
},
|
||||
"msgtype": "m.file",
|
||||
"url": "mxc://example.org/FHyPlCeYUSFFxlgbQYZmoEoe"
|
||||
},
|
||||
"type": "m.room.message",
|
||||
"event_id": "$143273582443PhrSn:example.org",
|
||||
"room_id": "!636q39766251:example.com",
|
||||
"sender": "@example:example.org",
|
||||
"origin_server_ts": 1432735824653,
|
||||
"unsigned": {
|
||||
"age": 1234
|
||||
}
|
||||
}
|
||||
],
|
||||
"start": "t27-54_2_0_2",
|
||||
"state": [
|
||||
{
|
||||
"content": {
|
||||
"creator": "@example:example.org",
|
||||
"room_version": "1",
|
||||
"m.federate": true,
|
||||
"predecessor": {
|
||||
"event_id": "$something:example.org",
|
||||
"room_id": "!oldroom:example.org"
|
||||
}
|
||||
},
|
||||
"type": "m.room.create",
|
||||
"event_id": "$143273582443PhrSn:example.org",
|
||||
"room_id": "!636q39766251:example.com",
|
||||
"sender": "@example:example.org",
|
||||
"origin_server_ts": 1432735824653,
|
||||
"unsigned": {
|
||||
"age": 1234
|
||||
},
|
||||
"state_key": ""
|
||||
},
|
||||
{
|
||||
"content": {
|
||||
"membership": "join",
|
||||
"avatar_url": "mxc://example.org/SEsfnsuifSDFSSEF",
|
||||
"displayname": "Alice Margatroid"
|
||||
},
|
||||
"type": "m.room.member",
|
||||
"event_id": "$143273582443PhrSn:example.org",
|
||||
"room_id": "!636q39766251:example.com",
|
||||
"sender": "@example:example.org",
|
||||
"origin_server_ts": 1432735824653,
|
||||
"unsigned": {
|
||||
"age": 1234
|
||||
},
|
||||
"state_key": "@alice:example.org"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
import logging
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
|
||||
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
|
||||
from synapse.util.async_helpers import concurrently_execute
|
||||
|
||||
@@ -32,6 +33,11 @@ logger = logging.getLogger(__name__)
|
||||
# TODO: Flairs
|
||||
|
||||
|
||||
# Note that the maximum lengths are somewhat arbitrary.
|
||||
MAX_SHORT_DESC_LEN = 1000
|
||||
MAX_LONG_DESC_LEN = 10000
|
||||
|
||||
|
||||
class GroupsServerWorkerHandler:
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
@@ -508,11 +514,26 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
||||
)
|
||||
|
||||
profile = {}
|
||||
for keyname in ("name", "avatar_url", "short_description", "long_description"):
|
||||
for keyname, max_length in (
|
||||
("name", MAX_DISPLAYNAME_LEN),
|
||||
("avatar_url", MAX_AVATAR_URL_LEN),
|
||||
("short_description", MAX_SHORT_DESC_LEN),
|
||||
("long_description", MAX_LONG_DESC_LEN),
|
||||
):
|
||||
if keyname in content:
|
||||
value = content[keyname]
|
||||
if not isinstance(value, str):
|
||||
raise SynapseError(400, "%r value is not a string" % (keyname,))
|
||||
raise SynapseError(
|
||||
400,
|
||||
"%r value is not a string" % (keyname,),
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
if len(value) > max_length:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Invalid %s parameter" % (keyname,),
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
profile[keyname] = value
|
||||
|
||||
await self.store.update_group_profile(group_id, profile)
|
||||
|
||||
@@ -38,6 +38,7 @@ from synapse.api.filtering import Filter
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import copy_power_levels_contents
|
||||
from synapse.rest.admin._base import assert_user_is_admin
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
@@ -1025,41 +1026,51 @@ class RoomCreationHandler(BaseHandler):
|
||||
class RoomContextHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.storage = hs.get_storage()
|
||||
self.state_store = self.storage.state
|
||||
|
||||
async def get_event_context(
|
||||
self,
|
||||
user: UserID,
|
||||
requester: Requester,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
limit: int,
|
||||
event_filter: Optional[Filter],
|
||||
use_admin_priviledge: bool = False,
|
||||
) -> Optional[JsonDict]:
|
||||
"""Retrieves events, pagination tokens and state around a given event
|
||||
in a room.
|
||||
|
||||
Args:
|
||||
user
|
||||
requester
|
||||
room_id
|
||||
event_id
|
||||
limit: The maximum number of events to return in total
|
||||
(excluding state).
|
||||
event_filter: the filter to apply to the events returned
|
||||
(excluding the target event_id)
|
||||
|
||||
use_admin_priviledge: if `True`, return all events, regardless
|
||||
of whether `user` has access to them. To be used **ONLY**
|
||||
from the admin API.
|
||||
Returns:
|
||||
dict, or None if the event isn't found
|
||||
"""
|
||||
user = requester.user
|
||||
if use_admin_priviledge:
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
before_limit = math.floor(limit / 2.0)
|
||||
after_limit = limit - before_limit
|
||||
|
||||
users = await self.store.get_users_in_room(room_id)
|
||||
is_peeking = user.to_string() not in users
|
||||
|
||||
def filter_evts(events):
|
||||
return filter_events_for_client(
|
||||
async def filter_evts(events):
|
||||
if use_admin_priviledge:
|
||||
return events
|
||||
return await filter_events_for_client(
|
||||
self.storage, user.to_string(), events, is_peeking=is_peeking
|
||||
)
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ from synapse.rest.admin.rooms import (
|
||||
JoinRoomAliasServlet,
|
||||
ListRoomRestServlet,
|
||||
MakeRoomAdminRestServlet,
|
||||
RoomEventContextServlet,
|
||||
RoomMembersRestServlet,
|
||||
RoomRestServlet,
|
||||
RoomStateRestServlet,
|
||||
@@ -238,6 +239,7 @@ def register_servlets(hs, http_server):
|
||||
MakeRoomAdminRestServlet(hs).register(http_server)
|
||||
ShadowBanRestServlet(hs).register(http_server)
|
||||
ForwardExtremitiesRestServlet(hs).register(http_server)
|
||||
RoomEventContextServlet(hs).register(http_server)
|
||||
|
||||
|
||||
def register_servlets_for_client_rest_resource(hs, http_server):
|
||||
|
||||
@@ -14,10 +14,12 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
from urllib import parse as urlparse
|
||||
|
||||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
@@ -34,6 +36,7 @@ from synapse.rest.admin._base import (
|
||||
)
|
||||
from synapse.storage.databases.main.room import RoomSortOrder
|
||||
from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
|
||||
from synapse.util import json_decoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -604,3 +607,65 @@ class ForwardExtremitiesRestServlet(RestServlet):
|
||||
|
||||
extremities = await self.store.get_forward_extremities_for_room(room_id)
|
||||
return 200, {"count": len(extremities), "results": extremities}
|
||||
|
||||
|
||||
class RoomEventContextServlet(RestServlet):
|
||||
"""
|
||||
Provide the context for an event.
|
||||
This API is designed to be used when system administrators wish to look at
|
||||
an abuse report and understand what happened during and immediately prior
|
||||
to this event.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super().__init__()
|
||||
self.clock = hs.get_clock()
|
||||
self.room_context_handler = hs.get_room_context_handler()
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_GET(self, request, room_id, event_id):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
limit = parse_integer(request, "limit", default=10)
|
||||
|
||||
# picking the API shape for symmetry with /messages
|
||||
filter_str = parse_string(request, b"filter", encoding="utf-8")
|
||||
if filter_str:
|
||||
filter_json = urlparse.unquote(filter_str)
|
||||
event_filter = Filter(
|
||||
json_decoder.decode(filter_json)
|
||||
) # type: Optional[Filter]
|
||||
else:
|
||||
event_filter = None
|
||||
|
||||
results = await self.room_context_handler.get_event_context(
|
||||
requester,
|
||||
room_id,
|
||||
event_id,
|
||||
limit,
|
||||
event_filter,
|
||||
use_admin_priviledge=True,
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
results["events_before"] = await self._event_serializer.serialize_events(
|
||||
results["events_before"], time_now
|
||||
)
|
||||
results["event"] = await self._event_serializer.serialize_event(
|
||||
results["event"], time_now
|
||||
)
|
||||
results["events_after"] = await self._event_serializer.serialize_events(
|
||||
results["events_after"], time_now
|
||||
)
|
||||
results["state"] = await self._event_serializer.serialize_events(
|
||||
results["state"], time_now
|
||||
)
|
||||
|
||||
return 200, results
|
||||
|
||||
@@ -648,7 +648,7 @@ class RoomEventContextServlet(RestServlet):
|
||||
event_filter = None
|
||||
|
||||
results = await self.room_context_handler.get_event_context(
|
||||
requester.user, room_id, event_id, limit, event_filter
|
||||
requester, room_id, event_id, limit, event_filter
|
||||
)
|
||||
|
||||
if not results:
|
||||
|
||||
@@ -16,13 +16,24 @@
|
||||
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from twisted.web.http import Request
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.types import GroupID
|
||||
from synapse.handlers.groups_local import GroupsLocalHandler
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.types import GroupID, JsonDict
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -33,7 +44,7 @@ def _validate_group_id(f):
|
||||
"""
|
||||
|
||||
@wraps(f)
|
||||
def wrapper(self, request, group_id, *args, **kwargs):
|
||||
def wrapper(self, request: Request, group_id: str, *args, **kwargs):
|
||||
if not GroupID.is_valid(group_id):
|
||||
raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
|
||||
|
||||
@@ -48,14 +59,14 @@ class GroupServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request, group_id):
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -66,11 +77,15 @@ class GroupServlet(RestServlet):
|
||||
return 200, group_description
|
||||
|
||||
@_validate_group_id
|
||||
async def on_POST(self, request, group_id):
|
||||
async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(
|
||||
content, ("name", "avatar_url", "short_description", "long_description")
|
||||
)
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
await self.groups_handler.update_group_profile(
|
||||
group_id, requester_user_id, content
|
||||
)
|
||||
@@ -84,14 +99,14 @@ class GroupSummaryServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request, group_id):
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -116,18 +131,21 @@ class GroupSummaryRoomsCatServlet(RestServlet):
|
||||
"/rooms/(?P<room_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request, group_id, category_id, room_id):
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, category_id: str, room_id: str
|
||||
):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
resp = await self.groups_handler.update_group_summary_room(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
@@ -139,10 +157,13 @@ class GroupSummaryRoomsCatServlet(RestServlet):
|
||||
return 200, resp
|
||||
|
||||
@_validate_group_id
|
||||
async def on_DELETE(self, request, group_id, category_id, room_id):
|
||||
async def on_DELETE(
|
||||
self, request: Request, group_id: str, category_id: str, room_id: str
|
||||
):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
resp = await self.groups_handler.delete_group_summary_room(
|
||||
group_id, requester_user_id, room_id=room_id, category_id=category_id
|
||||
)
|
||||
@@ -158,14 +179,16 @@ class GroupCategoryServlet(RestServlet):
|
||||
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request, group_id, category_id):
|
||||
async def on_GET(
|
||||
self, request: Request, group_id: str, category_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -176,11 +199,14 @@ class GroupCategoryServlet(RestServlet):
|
||||
return 200, category
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request, group_id, category_id):
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, category_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
resp = await self.groups_handler.update_group_category(
|
||||
group_id, requester_user_id, category_id=category_id, content=content
|
||||
)
|
||||
@@ -188,10 +214,13 @@ class GroupCategoryServlet(RestServlet):
|
||||
return 200, resp
|
||||
|
||||
@_validate_group_id
|
||||
async def on_DELETE(self, request, group_id, category_id):
|
||||
async def on_DELETE(
|
||||
self, request: Request, group_id: str, category_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
resp = await self.groups_handler.delete_group_category(
|
||||
group_id, requester_user_id, category_id=category_id
|
||||
)
|
||||
@@ -205,14 +234,14 @@ class GroupCategoriesServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request, group_id):
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -229,14 +258,16 @@ class GroupRoleServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request, group_id, role_id):
|
||||
async def on_GET(
|
||||
self, request: Request, group_id: str, role_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -247,11 +278,14 @@ class GroupRoleServlet(RestServlet):
|
||||
return 200, category
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request, group_id, role_id):
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, role_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
resp = await self.groups_handler.update_group_role(
|
||||
group_id, requester_user_id, role_id=role_id, content=content
|
||||
)
|
||||
@@ -259,10 +293,13 @@ class GroupRoleServlet(RestServlet):
|
||||
return 200, resp
|
||||
|
||||
@_validate_group_id
|
||||
async def on_DELETE(self, request, group_id, role_id):
|
||||
async def on_DELETE(
|
||||
self, request: Request, group_id: str, role_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
resp = await self.groups_handler.delete_group_role(
|
||||
group_id, requester_user_id, role_id=role_id
|
||||
)
|
||||
@@ -276,14 +313,14 @@ class GroupRolesServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request, group_id):
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -308,18 +345,21 @@ class GroupSummaryUsersRoleServlet(RestServlet):
|
||||
"/users/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request, group_id, role_id, user_id):
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, role_id: str, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
resp = await self.groups_handler.update_group_summary_user(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
@@ -331,10 +371,13 @@ class GroupSummaryUsersRoleServlet(RestServlet):
|
||||
return 200, resp
|
||||
|
||||
@_validate_group_id
|
||||
async def on_DELETE(self, request, group_id, role_id, user_id):
|
||||
async def on_DELETE(
|
||||
self, request: Request, group_id: str, role_id: str, user_id: str
|
||||
):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
resp = await self.groups_handler.delete_group_summary_user(
|
||||
group_id, requester_user_id, user_id=user_id, role_id=role_id
|
||||
)
|
||||
@@ -348,14 +391,14 @@ class GroupRoomServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request, group_id):
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -372,14 +415,14 @@ class GroupUsersServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request, group_id):
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -396,14 +439,14 @@ class GroupInvitedUsersServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request, group_id):
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -420,18 +463,19 @@ class GroupSettingJoinPolicyServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request, group_id):
|
||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
result = await self.groups_handler.set_group_join_policy(
|
||||
group_id, requester_user_id, content
|
||||
)
|
||||
@@ -445,14 +489,14 @@ class GroupCreateServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/create_group$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
self.server_name = hs.hostname
|
||||
|
||||
async def on_POST(self, request):
|
||||
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -461,6 +505,7 @@ class GroupCreateServlet(RestServlet):
|
||||
localpart = content.pop("localpart")
|
||||
group_id = GroupID(localpart, self.server_name).to_string()
|
||||
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
result = await self.groups_handler.create_group(
|
||||
group_id, requester_user_id, content
|
||||
)
|
||||
@@ -476,18 +521,21 @@ class GroupAdminRoomsServlet(RestServlet):
|
||||
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request, group_id, room_id):
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
result = await self.groups_handler.add_room_to_group(
|
||||
group_id, requester_user_id, room_id, content
|
||||
)
|
||||
@@ -495,10 +543,13 @@ class GroupAdminRoomsServlet(RestServlet):
|
||||
return 200, result
|
||||
|
||||
@_validate_group_id
|
||||
async def on_DELETE(self, request, group_id, room_id):
|
||||
async def on_DELETE(
|
||||
self, request: Request, group_id: str, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
result = await self.groups_handler.remove_room_from_group(
|
||||
group_id, requester_user_id, room_id
|
||||
)
|
||||
@@ -515,18 +566,21 @@ class GroupAdminRoomsConfigServlet(RestServlet):
|
||||
"/config/(?P<config_key>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request, group_id, room_id, config_key):
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, room_id: str, config_key: str
|
||||
):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
result = await self.groups_handler.update_room_in_group(
|
||||
group_id, requester_user_id, room_id, config_key, content
|
||||
)
|
||||
@@ -542,7 +596,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
|
||||
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
@@ -551,12 +605,13 @@ class GroupAdminUsersInviteServlet(RestServlet):
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request, group_id, user_id):
|
||||
async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
config = content.get("config", {})
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
result = await self.groups_handler.invite(
|
||||
group_id, user_id, requester_user_id, config
|
||||
)
|
||||
@@ -572,18 +627,19 @@ class GroupAdminUsersKickServlet(RestServlet):
|
||||
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request, group_id, user_id):
|
||||
async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
result = await self.groups_handler.remove_user_from_group(
|
||||
group_id, user_id, requester_user_id, content
|
||||
)
|
||||
@@ -597,18 +653,19 @@ class GroupSelfLeaveServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request, group_id):
|
||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
result = await self.groups_handler.remove_user_from_group(
|
||||
group_id, requester_user_id, requester_user_id, content
|
||||
)
|
||||
@@ -622,18 +679,19 @@ class GroupSelfJoinServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request, group_id):
|
||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
result = await self.groups_handler.join_group(
|
||||
group_id, requester_user_id, content
|
||||
)
|
||||
@@ -647,18 +705,19 @@ class GroupSelfAcceptInviteServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request, group_id):
|
||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
assert isinstance(self.groups_handler, GroupsLocalHandler)
|
||||
result = await self.groups_handler.accept_invite(
|
||||
group_id, requester_user_id, content
|
||||
)
|
||||
@@ -672,14 +731,14 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request, group_id):
|
||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -696,14 +755,14 @@ class PublicisedGroupsForUserServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
async def on_GET(self, request, user_id):
|
||||
async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
|
||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
result = await self.groups_handler.get_publicised_groups_for_user(user_id)
|
||||
@@ -717,14 +776,14 @@ class PublicisedGroupsForUsersServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/publicised_groups$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
async def on_POST(self, request):
|
||||
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
@@ -741,13 +800,13 @@ class GroupsForUserServlet(RestServlet):
|
||||
|
||||
PATTERNS = client_patterns("/joined_groups$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
async def on_GET(self, request):
|
||||
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -756,7 +815,7 @@ class GroupsForUserServlet(RestServlet):
|
||||
return 200, result
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
def register_servlets(hs: "HomeServer", http_server):
|
||||
GroupServlet(hs).register(http_server)
|
||||
GroupSummaryServlet(hs).register(http_server)
|
||||
GroupInvitedUsersServlet(hs).register(http_server)
|
||||
|
||||
@@ -195,6 +195,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
||||
body, ["client_secret", "country", "phone_number", "send_attempt"]
|
||||
)
|
||||
client_secret = body["client_secret"]
|
||||
assert_valid_client_secret(client_secret)
|
||||
country = body["country"]
|
||||
phone_number = body["phone_number"]
|
||||
send_attempt = body["send_attempt"]
|
||||
@@ -297,6 +298,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
|
||||
|
||||
sid = parse_string(request, "sid", required=True)
|
||||
client_secret = parse_string(request, "client_secret", required=True)
|
||||
assert_valid_client_secret(client_secret)
|
||||
token = parse_string(request, "token", required=True)
|
||||
|
||||
# Attempt to validate a 3PID session
|
||||
|
||||
@@ -58,7 +58,10 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
|
||||
_charset_match = re.compile(br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I)
|
||||
_xml_encoding_match = re.compile(
|
||||
br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I
|
||||
)
|
||||
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
|
||||
|
||||
OG_TAG_NAME_MAXLEN = 50
|
||||
@@ -299,24 +302,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||
with open(media_info["filename"], "rb") as file:
|
||||
body = file.read()
|
||||
|
||||
encoding = None
|
||||
|
||||
# Let's try and figure out if it has an encoding set in a meta tag.
|
||||
# Limit it to the first 1kb, since it ought to be in the meta tags
|
||||
# at the top.
|
||||
match = _charset_match.search(body[:1000])
|
||||
|
||||
# If we find a match, it should take precedence over the
|
||||
# Content-Type header, so set it here.
|
||||
if match:
|
||||
encoding = match.group(1).decode("ascii")
|
||||
|
||||
# If we don't find a match, we'll look at the HTTP Content-Type, and
|
||||
# if that doesn't exist, we'll fall back to UTF-8.
|
||||
if not encoding:
|
||||
content_match = _content_type_match.match(media_info["media_type"])
|
||||
encoding = content_match.group(1) if content_match else "utf-8"
|
||||
|
||||
encoding = get_html_media_encoding(body, media_info["media_type"])
|
||||
og = decode_and_calc_og(body, media_info["uri"], encoding)
|
||||
|
||||
# pre-cache the image for posterity
|
||||
@@ -688,6 +674,48 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||
logger.debug("No media removed from url cache")
|
||||
|
||||
|
||||
def get_html_media_encoding(body: bytes, content_type: str) -> str:
|
||||
"""
|
||||
Get the encoding of the body based on the (presumably) HTML body or media_type.
|
||||
|
||||
The precedence used for finding a character encoding is:
|
||||
|
||||
1. meta tag with a charset declared.
|
||||
2. The XML document's character encoding attribute.
|
||||
3. The Content-Type header.
|
||||
4. Fallback to UTF-8.
|
||||
|
||||
Args:
|
||||
body: The HTML document, as bytes.
|
||||
content_type: The Content-Type header.
|
||||
|
||||
Returns:
|
||||
The character encoding of the body, as a string.
|
||||
"""
|
||||
# Limit searches to the first 1kb, since it ought to be at the top.
|
||||
body_start = body[:1024]
|
||||
|
||||
# Let's try and figure out if it has an encoding set in a meta tag.
|
||||
match = _charset_match.search(body_start)
|
||||
if match:
|
||||
return match.group(1).decode("ascii")
|
||||
|
||||
# TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
|
||||
|
||||
# If we didn't find a match, see if it an XML document with an encoding.
|
||||
match = _xml_encoding_match.match(body_start)
|
||||
if match:
|
||||
return match.group(1).decode("ascii")
|
||||
|
||||
# If we don't find a match, we'll look at the HTTP Content-Type, and
|
||||
# if that doesn't exist, we'll fall back to UTF-8.
|
||||
content_match = _content_type_match.match(content_type)
|
||||
if content_match:
|
||||
return content_match.group(1)
|
||||
|
||||
return "utf-8"
|
||||
|
||||
|
||||
def decode_and_calc_og(
|
||||
body: bytes, media_uri: str, request_encoding: Optional[str] = None
|
||||
) -> Dict[str, Optional[str]]:
|
||||
@@ -724,6 +752,11 @@ def decode_and_calc_og(
|
||||
def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
|
||||
# Attempt to parse the body. If this fails, log and return no metadata.
|
||||
tree = etree.fromstring(body_attempt, parser)
|
||||
|
||||
# The data was successfully parsed, but no tree was found.
|
||||
if tree is None:
|
||||
return {}
|
||||
|
||||
return _calc_og(tree, media_uri)
|
||||
|
||||
# Attempt to parse the body. If this fails, log and return no metadata.
|
||||
|
||||
@@ -24,7 +24,17 @@
|
||||
import abc
|
||||
import functools
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import twisted.internet.base
|
||||
import twisted.internet.tcp
|
||||
@@ -582,7 +592,9 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
return UserDirectoryHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_groups_local_handler(self):
|
||||
def get_groups_local_handler(
|
||||
self,
|
||||
) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]:
|
||||
if self.config.worker_app:
|
||||
return GroupsLocalWorkerHandler(self)
|
||||
else:
|
||||
|
||||
@@ -158,8 +158,8 @@ class LoggingDatabaseConnection:
|
||||
def commit(self) -> None:
|
||||
self.conn.commit()
|
||||
|
||||
def rollback(self, *args, **kwargs) -> None:
|
||||
self.conn.rollback(*args, **kwargs)
|
||||
def rollback(self) -> None:
|
||||
self.conn.rollback()
|
||||
|
||||
def __enter__(self) -> "Connection":
|
||||
self.conn.__enter__()
|
||||
@@ -244,12 +244,15 @@ class LoggingTransaction:
|
||||
assert self.exception_callbacks is not None
|
||||
self.exception_callbacks.append((callback, args, kwargs))
|
||||
|
||||
def fetchone(self) -> Optional[Tuple]:
|
||||
return self.txn.fetchone()
|
||||
|
||||
def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
|
||||
return self.txn.fetchmany(size=size)
|
||||
|
||||
def fetchall(self) -> List[Tuple]:
|
||||
return self.txn.fetchall()
|
||||
|
||||
def fetchone(self) -> Tuple:
|
||||
return self.txn.fetchone()
|
||||
|
||||
def __iter__(self) -> Iterator[Tuple]:
|
||||
return self.txn.__iter__()
|
||||
|
||||
@@ -754,6 +757,7 @@ class DatabasePool:
|
||||
Returns:
|
||||
A list of dicts where the key is the column header.
|
||||
"""
|
||||
assert cursor.description is not None, "cursor.description was None"
|
||||
col_headers = [intern(str(column[0])) for column in cursor.description]
|
||||
results = [dict(zip(col_headers, row)) for row in cursor]
|
||||
return results
|
||||
|
||||
@@ -619,9 +619,9 @@ def _get_or_create_schema_state(
|
||||
|
||||
txn.execute("SELECT version, upgraded FROM schema_version")
|
||||
row = txn.fetchone()
|
||||
current_version = int(row[0]) if row else None
|
||||
|
||||
if current_version:
|
||||
if row is not None:
|
||||
current_version = int(row[0])
|
||||
txn.execute(
|
||||
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
|
||||
(current_version,),
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Iterable, Iterator, List, Optional, Tuple
|
||||
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
|
||||
|
||||
from typing_extensions import Protocol
|
||||
|
||||
@@ -20,23 +20,44 @@ from typing_extensions import Protocol
|
||||
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
|
||||
"""
|
||||
|
||||
_Parameters = Union[Sequence[Any], Mapping[str, Any]]
|
||||
|
||||
|
||||
class Cursor(Protocol):
|
||||
def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
|
||||
def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
|
||||
...
|
||||
|
||||
def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
|
||||
def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any:
|
||||
...
|
||||
|
||||
def fetchone(self) -> Optional[Tuple]:
|
||||
...
|
||||
|
||||
def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
|
||||
...
|
||||
|
||||
def fetchall(self) -> List[Tuple]:
|
||||
...
|
||||
|
||||
def fetchone(self) -> Tuple:
|
||||
...
|
||||
|
||||
@property
|
||||
def description(self) -> Any:
|
||||
return None
|
||||
def description(
|
||||
self,
|
||||
) -> Optional[
|
||||
Sequence[
|
||||
# Note that this is an approximate typing based on sqlite3 and other
|
||||
# drivers, and may not be entirely accurate.
|
||||
Tuple[
|
||||
str,
|
||||
Optional[Any],
|
||||
Optional[int],
|
||||
Optional[int],
|
||||
Optional[int],
|
||||
Optional[int],
|
||||
Optional[int],
|
||||
]
|
||||
]
|
||||
]:
|
||||
...
|
||||
|
||||
@property
|
||||
def rowcount(self) -> int:
|
||||
@@ -59,7 +80,7 @@ class Connection(Protocol):
|
||||
def commit(self) -> None:
|
||||
...
|
||||
|
||||
def rollback(self, *args, **kwargs) -> None:
|
||||
def rollback(self) -> None:
|
||||
...
|
||||
|
||||
def __enter__(self) -> "Connection":
|
||||
|
||||
@@ -106,7 +106,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
|
||||
|
||||
def get_next_id_txn(self, txn: Cursor) -> int:
|
||||
txn.execute("SELECT nextval(?)", (self._sequence_name,))
|
||||
return txn.fetchone()[0]
|
||||
fetch_res = txn.fetchone()
|
||||
assert fetch_res is not None
|
||||
return fetch_res[0]
|
||||
|
||||
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
|
||||
txn.execute(
|
||||
@@ -147,7 +149,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
|
||||
txn.execute(
|
||||
"SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
|
||||
)
|
||||
last_value, is_called = txn.fetchone()
|
||||
fetch_res = txn.fetchone()
|
||||
assert fetch_res is not None
|
||||
last_value, is_called = fetch_res
|
||||
|
||||
# If we have an associated stream check the stream_positions table.
|
||||
max_in_stream_positions = None
|
||||
|
||||
@@ -25,7 +25,7 @@ from synapse.api.errors import Codes, SynapseError
|
||||
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
|
||||
|
||||
# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
|
||||
client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
|
||||
CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
|
||||
|
||||
# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
|
||||
# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
|
||||
@@ -42,28 +42,31 @@ MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
|
||||
rand = random.SystemRandom()
|
||||
|
||||
|
||||
def random_string(length):
|
||||
def random_string(length: int) -> str:
|
||||
return "".join(rand.choice(string.ascii_letters) for _ in range(length))
|
||||
|
||||
|
||||
def random_string_with_symbols(length):
|
||||
def random_string_with_symbols(length: int) -> str:
|
||||
return "".join(rand.choice(_string_with_symbols) for _ in range(length))
|
||||
|
||||
|
||||
def is_ascii(s):
|
||||
if isinstance(s, bytes):
|
||||
try:
|
||||
s.decode("ascii").encode("ascii")
|
||||
except UnicodeDecodeError:
|
||||
return False
|
||||
except UnicodeEncodeError:
|
||||
return False
|
||||
return True
|
||||
def is_ascii(s: bytes) -> bool:
|
||||
try:
|
||||
s.decode("ascii").encode("ascii")
|
||||
except UnicodeDecodeError:
|
||||
return False
|
||||
except UnicodeEncodeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def assert_valid_client_secret(client_secret):
|
||||
"""Validate that a given string matches the client_secret regex defined by the spec"""
|
||||
if client_secret_regex.match(client_secret) is None:
|
||||
def assert_valid_client_secret(client_secret: str) -> None:
|
||||
"""Validate that a given string matches the client_secret defined by the spec"""
|
||||
if (
|
||||
len(client_secret) <= 0
|
||||
or len(client_secret) > 255
|
||||
or CLIENT_SECRET_REGEX.match(client_secret) is None
|
||||
):
|
||||
raise SynapseError(
|
||||
400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
@@ -80,6 +80,7 @@ async def filter_events_for_client(
|
||||
events = [e for e in events if not e.internal_metadata.is_soft_failed()]
|
||||
|
||||
types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
|
||||
|
||||
event_id_to_state = await storage.state.get_state_for_events(
|
||||
frozenset(e.event_id for e in events),
|
||||
state_filter=StateFilter.from_types(types),
|
||||
|
||||
@@ -1445,6 +1445,90 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
|
||||
|
||||
def test_context_as_non_admin(self):
|
||||
"""
|
||||
Test that, without being admin, one cannot use the context admin API
|
||||
"""
|
||||
# Create a room.
|
||||
user_id = self.register_user("test", "test")
|
||||
user_tok = self.login("test", "test")
|
||||
|
||||
self.register_user("test_2", "test")
|
||||
user_tok_2 = self.login("test_2", "test")
|
||||
|
||||
room_id = self.helper.create_room_as(user_id, tok=user_tok)
|
||||
|
||||
# Populate the room with events.
|
||||
events = []
|
||||
for i in range(30):
|
||||
events.append(
|
||||
self.helper.send_event(
|
||||
room_id, "com.example.test", content={"index": i}, tok=user_tok
|
||||
)
|
||||
)
|
||||
|
||||
# Now attempt to find the context using the admin API without being admin.
|
||||
midway = (len(events) - 1) // 2
|
||||
for tok in [user_tok, user_tok_2]:
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_synapse/admin/v1/rooms/%s/context/%s"
|
||||
% (room_id, events[midway]["event_id"]),
|
||||
access_token=tok,
|
||||
)
|
||||
self.assertEquals(
|
||||
403, int(channel.result["code"]), msg=channel.result["body"]
|
||||
)
|
||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||
|
||||
def test_context_as_admin(self):
|
||||
"""
|
||||
Test that, as admin, we can find the context of an event without having joined the room.
|
||||
"""
|
||||
|
||||
# Create a room. We're not part of it.
|
||||
user_id = self.register_user("test", "test")
|
||||
user_tok = self.login("test", "test")
|
||||
room_id = self.helper.create_room_as(user_id, tok=user_tok)
|
||||
|
||||
# Populate the room with events.
|
||||
events = []
|
||||
for i in range(30):
|
||||
events.append(
|
||||
self.helper.send_event(
|
||||
room_id, "com.example.test", content={"index": i}, tok=user_tok
|
||||
)
|
||||
)
|
||||
|
||||
# Now let's fetch the context for this room.
|
||||
midway = (len(events) - 1) // 2
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_synapse/admin/v1/rooms/%s/context/%s"
|
||||
% (room_id, events[midway]["event_id"]),
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEquals(
|
||||
channel.json_body["event"]["event_id"], events[midway]["event_id"]
|
||||
)
|
||||
|
||||
for i, found_event in enumerate(channel.json_body["events_before"]):
|
||||
for j, posted_event in enumerate(events):
|
||||
if found_event["event_id"] == posted_event["event_id"]:
|
||||
self.assertTrue(j < midway)
|
||||
break
|
||||
else:
|
||||
self.fail("Event %s from events_before not found" % j)
|
||||
|
||||
for i, found_event in enumerate(channel.json_body["events_after"]):
|
||||
for j, posted_event in enumerate(events):
|
||||
if found_event["event_id"] == posted_event["event_id"]:
|
||||
self.assertTrue(j > midway)
|
||||
break
|
||||
else:
|
||||
self.fail("Event %s from events_after not found" % j)
|
||||
|
||||
|
||||
class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
from synapse.rest.media.v1.preview_url_resource import (
|
||||
decode_and_calc_og,
|
||||
get_html_media_encoding,
|
||||
summarize_paragraphs,
|
||||
)
|
||||
|
||||
@@ -26,7 +27,7 @@ except ImportError:
|
||||
lxml = None
|
||||
|
||||
|
||||
class PreviewTestCase(unittest.TestCase):
|
||||
class SummarizeTestCase(unittest.TestCase):
|
||||
if not lxml:
|
||||
skip = "url preview feature requires lxml"
|
||||
|
||||
@@ -144,12 +145,12 @@ class PreviewTestCase(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class PreviewUrlTestCase(unittest.TestCase):
|
||||
class CalcOgTestCase(unittest.TestCase):
|
||||
if not lxml:
|
||||
skip = "url preview feature requires lxml"
|
||||
|
||||
def test_simple(self):
|
||||
html = """
|
||||
html = b"""
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
<body>
|
||||
@@ -163,7 +164,7 @@ class PreviewUrlTestCase(unittest.TestCase):
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
|
||||
def test_comment(self):
|
||||
html = """
|
||||
html = b"""
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
<body>
|
||||
@@ -178,7 +179,7 @@ class PreviewUrlTestCase(unittest.TestCase):
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
|
||||
def test_comment2(self):
|
||||
html = """
|
||||
html = b"""
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
<body>
|
||||
@@ -202,7 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_script(self):
|
||||
html = """
|
||||
html = b"""
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
<body>
|
||||
@@ -217,7 +218,7 @@ class PreviewUrlTestCase(unittest.TestCase):
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
|
||||
def test_missing_title(self):
|
||||
html = """
|
||||
html = b"""
|
||||
<html>
|
||||
<body>
|
||||
Some text.
|
||||
@@ -230,7 +231,7 @@ class PreviewUrlTestCase(unittest.TestCase):
|
||||
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
|
||||
|
||||
def test_h1_as_title(self):
|
||||
html = """
|
||||
html = b"""
|
||||
<html>
|
||||
<meta property="og:description" content="Some text."/>
|
||||
<body>
|
||||
@@ -244,7 +245,7 @@ class PreviewUrlTestCase(unittest.TestCase):
|
||||
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
|
||||
|
||||
def test_missing_title_and_broken_h1(self):
|
||||
html = """
|
||||
html = b"""
|
||||
<html>
|
||||
<body>
|
||||
<h1><a href="foo"/></h1>
|
||||
@@ -258,13 +259,20 @@ class PreviewUrlTestCase(unittest.TestCase):
|
||||
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
|
||||
|
||||
def test_empty(self):
|
||||
html = ""
|
||||
"""Test a body with no data in it."""
|
||||
html = b""
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
self.assertEqual(og, {})
|
||||
|
||||
def test_no_tree(self):
|
||||
"""A valid body with no tree in it."""
|
||||
html = b"\x00"
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
self.assertEqual(og, {})
|
||||
|
||||
def test_invalid_encoding(self):
|
||||
"""An invalid character encoding should be ignored and treated as UTF-8, if possible."""
|
||||
html = """
|
||||
html = b"""
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
<body>
|
||||
@@ -290,3 +298,76 @@ class PreviewUrlTestCase(unittest.TestCase):
|
||||
"""
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
|
||||
|
||||
|
||||
class MediaEncodingTestCase(unittest.TestCase):
|
||||
def test_meta_charset(self):
|
||||
"""A character encoding is found via the meta tag."""
|
||||
encoding = get_html_media_encoding(
|
||||
b"""
|
||||
<html>
|
||||
<head><meta charset="ascii">
|
||||
</head>
|
||||
</html>
|
||||
""",
|
||||
"text/html",
|
||||
)
|
||||
self.assertEqual(encoding, "ascii")
|
||||
|
||||
# A less well-formed version.
|
||||
encoding = get_html_media_encoding(
|
||||
b"""
|
||||
<html>
|
||||
<head>< meta charset = ascii>
|
||||
</head>
|
||||
</html>
|
||||
""",
|
||||
"text/html",
|
||||
)
|
||||
self.assertEqual(encoding, "ascii")
|
||||
|
||||
def test_xml_encoding(self):
|
||||
"""A character encoding is found via the meta tag."""
|
||||
encoding = get_html_media_encoding(
|
||||
b"""
|
||||
<?xml version="1.0" encoding="ascii"?>
|
||||
<html>
|
||||
</html>
|
||||
""",
|
||||
"text/html",
|
||||
)
|
||||
self.assertEqual(encoding, "ascii")
|
||||
|
||||
def test_meta_xml_encoding(self):
|
||||
"""Meta tags take precedence over XML encoding."""
|
||||
encoding = get_html_media_encoding(
|
||||
b"""
|
||||
<?xml version="1.0" encoding="ascii"?>
|
||||
<html>
|
||||
<head><meta charset="UTF-16">
|
||||
</head>
|
||||
</html>
|
||||
""",
|
||||
"text/html",
|
||||
)
|
||||
self.assertEqual(encoding, "UTF-16")
|
||||
|
||||
def test_content_type(self):
|
||||
"""A character encoding is found via the Content-Type header."""
|
||||
# Test a few variations of the header.
|
||||
headers = (
|
||||
'text/html; charset="ascii";',
|
||||
"text/html;charset=ascii;",
|
||||
'text/html; charset="ascii"',
|
||||
"text/html; charset=ascii",
|
||||
'text/html; charset="ascii;',
|
||||
'text/html; charset=ascii";',
|
||||
)
|
||||
for header in headers:
|
||||
encoding = get_html_media_encoding(b"", header)
|
||||
self.assertEqual(encoding, "ascii")
|
||||
|
||||
def test_fallback(self):
|
||||
"""A character encoding cannot be found in the body or header."""
|
||||
encoding = get_html_media_encoding(b"", "text/html")
|
||||
self.assertEqual(encoding, "utf-8")
|
||||
|
||||
Reference in New Issue
Block a user