1
0

Merge commit 'f14428b25' into anoa/dinsic_release_1_31_0

This commit is contained in:
Andrew Morgan
2021-04-16 15:02:53 +01:00
45 changed files with 521 additions and 249 deletions
+4 -3
View File
@@ -7,6 +7,8 @@ jobs:
- checkout
- docker_prepare
- run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
# for release builds, we want to get the amd64 image out asap, so first
# we do an amd64-only build, before following up with a multiarch build.
- docker_build:
tag: -t matrixdotorg/synapse:${CIRCLE_TAG}
platforms: linux/amd64
@@ -21,9 +23,8 @@ jobs:
- checkout
- docker_prepare
- run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
- docker_build:
tag: -t matrixdotorg/synapse:latest
platforms: linux/amd64
# for `latest`, we don't want the arm images to disappear, so don't update the tag
# until all of the platforms are built.
- docker_build:
tag: -t matrixdotorg/synapse:latest
platforms: linux/amd64,linux/arm/v7,linux/arm64
+1
View File
@@ -0,0 +1 @@
Add number of local devices to Room Details Admin API. Contributed by @dklimpel.
+1
View File
@@ -0,0 +1 @@
Spam-checkers may now define their methods as `async`.
+1
View File
@@ -0,0 +1 @@
Add type hints to push module.
+1
View File
@@ -0,0 +1 @@
Don't publish `latest` docker image until all archs are built.
+1
View File
@@ -0,0 +1 @@
Improve structured logging tests.
+1
View File
@@ -0,0 +1 @@
Fix occasional deadlock when handling SIGHUP.
+1
View File
@@ -0,0 +1 @@
Fix login API to not ratelimit application services that have ratelimiting disabled.
+1
View File
@@ -0,0 +1 @@
Fix bug where we ratelimited auto joining of rooms on registration (using `auto_join_rooms` config).
+13 -11
View File
@@ -87,7 +87,7 @@ GET /_synapse/admin/v1/rooms
Response:
```
```jsonc
{
"rooms": [
{
@@ -139,7 +139,7 @@ GET /_synapse/admin/v1/rooms?search_term=TWIM
Response:
```
```json
{
"rooms": [
{
@@ -174,7 +174,7 @@ GET /_synapse/admin/v1/rooms?order_by=size
Response:
```
```jsonc
{
"rooms": [
{
@@ -230,14 +230,14 @@ GET /_synapse/admin/v1/rooms?order_by=size&from=100
Response:
```
```jsonc
{
"rooms": [
{
"room_id": "!mscvqgqpHYjBGDxNym:matrix.org",
"name": "Music Theory",
"canonical_alias": "#musictheory:matrix.org",
"joined_members": 127
"joined_members": 127,
"joined_local_members": 2,
"version": "1",
"creator": "@foo:matrix.org",
@@ -254,7 +254,7 @@ Response:
"room_id": "!twcBhHVdZlQWuuxBhN:termina.org.uk",
"name": "weechat-matrix",
"canonical_alias": "#weechat-matrix:termina.org.uk",
"joined_members": 137
"joined_members": 137,
"joined_local_members": 20,
"version": "4",
"creator": "@foo:termina.org.uk",
@@ -289,6 +289,7 @@ The following fields are possible in the JSON response body:
* `canonical_alias` - The canonical (main) alias address of the room.
* `joined_members` - How many users are currently in the room.
* `joined_local_members` - How many local users are currently in the room.
* `joined_local_devices` - How many local devices are currently in the room.
* `version` - The version of the room as a string.
* `creator` - The `user_id` of the room creator.
* `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active.
@@ -311,15 +312,16 @@ GET /_synapse/admin/v1/rooms/<room_id>
Response:
```
```json
{
"room_id": "!mscvqgqpHYjBGDxNym:matrix.org",
"name": "Music Theory",
"avatar": "mxc://matrix.org/AQDaVFlbkQoErdOgqWRgiGSV",
"topic": "Theory, Composition, Notation, Analysis",
"canonical_alias": "#musictheory:matrix.org",
"joined_members": 127
"joined_members": 127,
"joined_local_members": 2,
"joined_local_devices": 2,
"version": "1",
"creator": "@foo:matrix.org",
"encryption": null,
@@ -353,13 +355,13 @@ GET /_synapse/admin/v1/rooms/<room_id>/members
Response:
```
```json
{
"members": [
"@foo:matrix.org",
"@bar:matrix.org",
"@foobar:matrix.org
],
"@foobar:matrix.org"
],
"total": 3
}
```
+13 -6
View File
@@ -22,6 +22,8 @@ well as some specific methods:
* `user_may_create_room`
* `user_may_create_room_alias`
* `user_may_publish_room`
* `check_username_for_spam`
* `check_registration_for_spam`
The details of the each of these methods (as well as their inputs and outputs)
are documented in the `synapse.events.spamcheck.SpamChecker` class.
@@ -32,28 +34,33 @@ call back into the homeserver internals.
### Example
```python
from synapse.spam_checker_api import RegistrationBehaviour
class ExampleSpamChecker:
def __init__(self, config, api):
self.config = config
self.api = api
def check_event_for_spam(self, foo):
async def check_event_for_spam(self, foo):
return False # allow all events
def user_may_invite(self, inviter_userid, invitee_userid, room_id):
async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
return True # allow all invites
def user_may_create_room(self, userid):
async def user_may_create_room(self, userid):
return True # allow all room creations
def user_may_create_room_alias(self, userid, room_alias):
async def user_may_create_room_alias(self, userid, room_alias):
return True # allow all room aliases
def user_may_publish_room(self, userid, room_id):
async def user_may_publish_room(self, userid, room_id):
return True # allow publishing of all rooms
def check_username_for_spam(self, user_profile):
async def check_username_for_spam(self, user_profile):
return False # allow all usernames
async def check_registration_for_spam(self, email_threepid, username, request_info):
return RegistrationBehaviour.ALLOW # allow all registrations
```
## Configuration
+1 -6
View File
@@ -56,12 +56,7 @@ files =
synapse/metrics,
synapse/module_api,
synapse/notifier.py,
synapse/push/emailpusher.py,
synapse/push/httppusher.py,
synapse/push/mailer.py,
synapse/push/pusher.py,
synapse/push/pusherpool.py,
synapse/push/push_rule_evaluator.py,
synapse/push,
synapse/replication,
synapse/rest,
synapse/server.py,
+2
View File
@@ -31,6 +31,8 @@ class SynapsePlugin(Plugin):
) -> Optional[Callable[[MethodSigContext], CallableType]]:
if fullname.startswith(
"synapse.util.caches.descriptors._CachedFunction.__call__"
) or fullname.startswith(
"synapse.util.caches.descriptors._LruCachedFunction.__call__"
):
return cached_function_method_signature
return None
+3 -1
View File
@@ -31,7 +31,9 @@ from synapse.api.errors import (
MissingClientTokenError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
@@ -479,7 +481,7 @@ class Auth:
now = self.hs.get_clock().time_msec()
return now < expiry
def get_appservice_by_req(self, request):
def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)
if not service:
+5 -1
View File
@@ -245,6 +245,8 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
reactor = hs.get_reactor()
@wrap_as_background_process("sighup")
def handle_sighup(*args, **kwargs):
# Tell systemd our state, if we're using it. This will silently fail if
@@ -260,7 +262,9 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
# is so that we're in a sane state, e.g. flushing the logs may fail
# if the sighup happens in the middle of writing a log entry.
def run_sighup(*args, **kwargs):
hs.get_clock().call_later(0, handle_sighup, *args, **kwargs)
# `callFromThread` should be "signal safe" as well as thread
# safe.
reactor.callFromThread(handle_sighup, *args, **kwargs)
signal.signal(signal.SIGHUP, run_sighup)
+45 -27
View File
@@ -15,10 +15,11 @@
# limitations under the License.
import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
import synapse.events
@@ -39,7 +40,9 @@ class SpamChecker:
else:
self.spam_checkers.append(module(config=config))
def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
async def check_event_for_spam(
self, event: "synapse.events.EventBase"
) -> Union[bool, str]:
"""Checks if a given event is considered "spammy" by this server.
If the server considers an event spammy, then it will be rejected if
@@ -50,15 +53,16 @@ class SpamChecker:
event: the event to be checked
Returns:
True if the event is spammy.
True or a string if the event is spammy. If a string is returned it
will be used as the error message returned to the user.
"""
for spam_checker in self.spam_checkers:
if spam_checker.check_event_for_spam(event):
if await maybe_awaitable(spam_checker.check_event_for_spam(event)):
return True
return False
def user_may_invite(
async def user_may_invite(
self,
inviter_userid: str,
invitee_userid: str,
@@ -91,21 +95,22 @@ class SpamChecker:
"""
for spam_checker in self.spam_checkers:
if (
spam_checker.user_may_invite(
inviter_userid,
invitee_userid,
third_party_invite,
room_id,
new_room,
published_room,
)
is False
await maybe_awaitable(
spam_checker.user_may_invite(
inviter_userid,
invitee_userid,
third_party_invite,
room_id,
new_room,
published_room,
)
) is False
):
return False
return True
def user_may_create_room(
async def user_may_create_room(
self,
userid: str,
invite_list: List[str],
@@ -130,16 +135,17 @@ class SpamChecker:
"""
for spam_checker in self.spam_checkers:
if (
spam_checker.user_may_create_room(
userid, invite_list, third_party_invite_list, cloning
)
is False
await maybe_awaitable(
spam_checker.user_may_create_room(
userid, invite_list, third_party_invite_list, cloning
)
) is False
):
return False
return True
def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
"""Checks if a given user may create a room alias
If this method returns false, the association request will be rejected.
@@ -152,12 +158,17 @@ class SpamChecker:
True if the user may create a room alias, otherwise False
"""
for spam_checker in self.spam_checkers:
if spam_checker.user_may_create_room_alias(userid, room_alias) is False:
if (
await maybe_awaitable(
spam_checker.user_may_create_room_alias(userid, room_alias)
)
is False
):
return False
return True
def user_may_publish_room(self, userid: str, room_id: str) -> bool:
async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
"""Checks if a given user may publish a room to the directory
If this method returns false, the publish request will be rejected.
@@ -170,7 +181,12 @@ class SpamChecker:
True if the user may publish the room, otherwise False
"""
for spam_checker in self.spam_checkers:
if spam_checker.user_may_publish_room(userid, room_id) is False:
if (
await maybe_awaitable(
spam_checker.user_may_publish_room(userid, room_id)
)
is False
):
return False
return True
@@ -194,7 +210,7 @@ class SpamChecker:
return True
def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
If the server considers a username spammy, then it will not be included in
@@ -216,12 +232,12 @@ class SpamChecker:
if checker:
# Make a copy of the user profile object to ensure the spam checker
# cannot modify it.
if checker(user_profile.copy()):
if await maybe_awaitable(checker(user_profile.copy())):
return True
return False
def check_registration_for_spam(
async def check_registration_for_spam(
self,
email_threepid: Optional[dict],
username: Optional[str],
@@ -244,7 +260,9 @@ class SpamChecker:
# spam checker
checker = getattr(spam_checker, "check_registration_for_spam", None)
if checker:
behaviour = checker(email_threepid, username, request_info)
behaviour = await maybe_awaitable(
checker(email_threepid, username, request_info)
)
assert isinstance(behaviour, RegistrationBehaviour)
if behaviour != RegistrationBehaviour.ALLOW:
return behaviour
+6 -1
View File
@@ -78,6 +78,7 @@ class FederationBase:
ctx = current_context()
@defer.inlineCallbacks
def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu):
@@ -105,7 +106,11 @@ class FederationBase:
)
return redacted_event
if self.spam_checker.check_event_for_spam(pdu):
result = yield defer.ensureDeferred(
self.spam_checker.check_event_for_spam(pdu)
)
if result:
logger.warning(
"Event contains spam, redacting %s: %s",
pdu.event_id,
+8 -7
View File
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import time
import unicodedata
@@ -22,6 +21,7 @@ import urllib.parse
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
@@ -58,6 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
@@ -851,7 +852,7 @@ class AuthHandler(BaseHandler):
async def validate_login(
self, login_submission: Dict[str, Any], ratelimit: bool = False,
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate auth types which don't
@@ -994,7 +995,7 @@ class AuthHandler(BaseHandler):
async def _validate_userid_login(
self, username: str, login_submission: Dict[str, Any],
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Helper for validate_login
Handles login, once we've mapped 3pids onto userids
@@ -1072,7 +1073,7 @@ class AuthHandler(BaseHandler):
async def check_password_provider_3pid(
self, medium: str, address: str, password: str
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Check if a password provider is able to validate a thirdparty login
Args:
@@ -1628,6 +1629,6 @@ class PasswordProvider:
# This might return an awaitable, if it does block the log out
# until it completes.
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
if inspect.isawaitable(result):
await result
await maybe_awaitable(
g(user_id=user_id, device_id=device_id, access_token=access_token,)
)
+4 -2
View File
@@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler):
403, "You must be in the room to create an alias for it"
)
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
if not await self.spam_checker.user_may_create_room_alias(
user_id, room_alias
):
raise AuthError(403, "This user is not permitted to create this alias")
if not self.config.is_alias_creation_allowed(
@@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler):
"""
user_id = requester.user.to_string()
if not self.spam_checker.user_may_publish_room(user_id, room_id):
if not await self.spam_checker.user_may_publish_room(user_id, room_id):
raise AuthError(
403, "This user is not permitted to publish rooms to the room list"
)
+1 -1
View File
@@ -1666,7 +1666,7 @@ class FederationHandler(BaseHandler):
is_published = await self.store.is_room_published(event.room_id)
if not self.spam_checker.user_may_invite(
if not await self.spam_checker.user_may_invite(
event.sender,
event.state_key,
None,
+1 -1
View File
@@ -746,7 +746,7 @@ class EventCreationHandler:
event.sender,
)
spam_error = self.spam_checker.check_event_for_spam(event)
spam_error = await self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, str):
spam_error = "Spam is not permitted here"
+2 -5
View File
@@ -18,7 +18,6 @@ from typing import List, Tuple
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -98,10 +97,8 @@ class ReceiptsHandler(BaseHandler):
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate.
await maybe_awaitable(
self.hs.get_pusherpool().on_new_receipts(
min_batch_id, max_batch_id, affected_room_ids
)
await self.hs.get_pusherpool().on_new_receipts(
min_batch_id, max_batch_id, affected_room_ids
)
return True
+1 -1
View File
@@ -202,7 +202,7 @@ class RegistrationHandler(BaseHandler):
"""
self.check_registration_ratelimit(address)
result = self.spam_checker.check_registration_for_spam(
result = await self.spam_checker.check_registration_for_spam(
threepid, localpart, user_agent_ips or [],
)
+6 -3
View File
@@ -368,7 +368,7 @@ class RoomCreationHandler(BaseHandler):
else:
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin and not self.spam_checker.user_may_create_room(
if not is_requester_admin and not await self.spam_checker.user_may_create_room(
user_id, invite_list=[], third_party_invite_list=[], cloning=True
):
raise SynapseError(403, "You are not permitted to create rooms")
@@ -452,6 +452,7 @@ class RoomCreationHandler(BaseHandler):
invite_list=[],
initial_state=initial_state,
creation_content=creation_content,
ratelimit=False,
)
# Transfer membership events
@@ -623,7 +624,7 @@ class RoomCreationHandler(BaseHandler):
invite_list = config.get("invite", [])
invite_3pid_list = config.get("invite_3pid", [])
if not is_requester_admin and not self.spam_checker.user_may_create_room(
if not is_requester_admin and not await self.spam_checker.user_may_create_room(
user_id,
invite_list=invite_list,
third_party_invite_list=invite_3pid_list,
@@ -753,6 +754,7 @@ class RoomCreationHandler(BaseHandler):
room_alias=room_alias,
power_level_content_override=power_level_content_override,
creator_join_profile=creator_join_profile,
ratelimit=ratelimit,
)
if "name" in config:
@@ -858,6 +860,7 @@ class RoomCreationHandler(BaseHandler):
room_alias: Optional[RoomAlias] = None,
power_level_content_override: Optional[JsonDict] = None,
creator_join_profile: Optional[JsonDict] = None,
ratelimit: bool = True,
) -> int:
"""Sends the initial events into a new room.
@@ -904,7 +907,7 @@ class RoomCreationHandler(BaseHandler):
creator.user,
room_id,
"join",
ratelimit=False,
ratelimit=ratelimit,
content=creator_join_profile,
new_room=True,
)
+15 -12
View File
@@ -246,7 +246,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates.
if newly_joined:
if newly_joined and ratelimit:
time_now_s = self.clock.time()
(
allowed,
@@ -456,7 +456,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
is_published = await self.store.is_room_published(room_id)
if not self.spam_checker.user_may_invite(
if not await self.spam_checker.user_may_invite(
requester.user.to_string(),
target.to_string(),
third_party_invite=None,
@@ -560,17 +560,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise SynapseError(403, "Not allowed to join this room")
if not is_host_in_room:
time_now_s = self.clock.time()
(
allowed,
time_allowed,
) = self._join_rate_limiter_remote.can_requester_do_action(requester,)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s))
if ratelimit:
time_now_s = self.clock.time()
(
allowed,
time_allowed,
) = self._join_rate_limiter_remote.can_requester_do_action(
requester,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
@@ -947,7 +950,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
is_published = await self.store.is_room_published(room_id)
if not self.spam_checker.user_may_invite(
if not await self.spam_checker.user_may_invite(
requester.user.to_string(),
invitee,
third_party_invite={"medium": medium, "address": address},
+5 -5
View File
@@ -81,11 +81,11 @@ class UserDirectoryHandler(StateDeltasHandler):
results = await self.store.search_user_dir(user_id, search_term, limit)
# Remove any spammy users from the results.
results["results"] = [
user
for user in results["results"]
if not self.spam_checker.check_username_for_spam(user)
]
non_spammy_users = []
for user in results["results"]:
if not await self.spam_checker.check_username_for_spam(user):
non_spammy_users.append(user)
results["results"] = non_spammy_users
return results
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import threading
from functools import wraps
@@ -25,6 +24,7 @@ from twisted.internet import defer
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.opentracing import noop_context_manager, start_active_span
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
import resource
@@ -206,12 +206,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
if bg_start_span:
ctx = start_active_span(desc, tags={"request_id": context.request})
with ctx:
result = func(*args, **kwargs)
if inspect.isawaitable(result):
result = await result
return result
return await maybe_awaitable(func(*args, **kwargs))
except Exception:
logger.exception(
"Background process '%s' threw an exception", desc,
+10 -5
View File
@@ -14,19 +14,22 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.util.metrics import Measure
from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class ActionGenerator:
def __init__(self, hs):
self.hs = hs
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.bulk_evaluator = BulkPushRuleEvaluator(hs)
# really we want to get all user ids and all profile tags too,
# since we want the actions for each profile tag for every user and
@@ -35,6 +38,8 @@ class ActionGenerator:
# event stream, so we just run the rules for a client with no profile
# tag (ie. we just need all the users).
async def handle_push_actions_for_event(self, event, context):
async def handle_push_actions_for_event(
self, event: EventBase, context: EventContext
) -> None:
with Measure(self.clock, "action_for_event_by_user"):
await self.bulk_evaluator.action_for_event_by_user(event, context)
+18 -5
View File
@@ -15,16 +15,19 @@
# limitations under the License.
import copy
from typing import Any, Dict, List
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
def list_with_base_rules(rawrules, use_new_defaults=False):
def list_with_base_rules(
rawrules: List[Dict[str, Any]], use_new_defaults: bool = False
) -> List[Dict[str, Any]]:
"""Combine the list of rules set by the user with the default push rules
Args:
rawrules(list): The rules the user has modified or set.
use_new_defaults(bool): Whether to use the new experimental default rules when
rawrules: The rules the user has modified or set.
use_new_defaults: Whether to use the new experimental default rules when
appending or prepending default rules.
Returns:
@@ -94,7 +97,11 @@ def list_with_base_rules(rawrules, use_new_defaults=False):
return ruleslist
def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
def make_base_append_rules(
kind: str,
modified_base_rules: Dict[str, Dict[str, Any]],
use_new_defaults: bool = False,
) -> List[Dict[str, Any]]:
rules = []
if kind == "override":
@@ -116,6 +123,7 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
rules = copy.deepcopy(rules)
for r in rules:
# Only modify the actions, keep the conditions the same.
assert isinstance(r["rule_id"], str)
modified = modified_base_rules.get(r["rule_id"])
if modified:
r["actions"] = modified["actions"]
@@ -123,7 +131,11 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
return rules
def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
def make_base_prepend_rules(
kind: str,
modified_base_rules: Dict[str, Dict[str, Any]],
use_new_defaults: bool = False,
) -> List[Dict[str, Any]]:
rules = []
if kind == "override":
@@ -133,6 +145,7 @@ def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
rules = copy.deepcopy(rules)
for r in rules:
# Only modify the actions, keep the conditions the same.
assert isinstance(r["rule_id"], str)
modified = modified_base_rules.get(r["rule_id"])
if modified:
r["actions"] = modified["actions"]
+61 -37
View File
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
import attr
from prometheus_client import Counter
@@ -25,18 +26,18 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import register_cache
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.descriptors import lru_cache
from synapse.util.caches.lrucache import LruCache
from .push_rule_evaluator import PushRuleEvaluatorForEvent
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
rules_by_room = {}
push_rules_invalidation_counter = Counter(
"synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
)
@@ -101,7 +102,7 @@ class BulkPushRuleEvaluator:
room at once.
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@@ -113,7 +114,9 @@ class BulkPushRuleEvaluator:
resizable=False,
)
async def _get_rules_for_event(self, event, context):
async def _get_rules_for_event(
self, event: EventBase, context: EventContext
) -> Dict[str, List[Dict[str, Any]]]:
"""This gets the rules for all users in the room at the time of the event,
as well as the push rules for the invitee if the event is an invite.
@@ -140,11 +143,8 @@ class BulkPushRuleEvaluator:
return rules_by_user
@lru_cache()
def _get_rules_for_room(self, room_id):
def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
"""Get the current RulesForRoom object for the given room id
Returns:
RulesForRoom
"""
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
# before any lookup methods get called on it as otherwise there may be
@@ -156,20 +156,21 @@ class BulkPushRuleEvaluator:
self.room_push_rule_cache_metrics,
)
async def _get_power_levels_and_sender_level(self, event, context):
async def _get_power_levels_and_sender_level(
self, event: EventBase, context: EventContext
) -> Tuple[dict, int]:
prev_state_ids = await context.get_prev_state_ids()
pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case
pl_event = await self.store.get_event(pl_event_id)
auth_events = {POWER_KEY: pl_event}
auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
else:
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=False
)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
auth_events_dict = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
sender_level = get_user_power_level(event.sender, auth_events)
@@ -177,7 +178,9 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
async def action_for_event_by_user(self, event, context) -> None:
async def action_for_event_by_user(
self, event: EventBase, context: EventContext
) -> None:
"""Given an event and context, evaluate the push rules, check if the message
should increment the unread count, and insert the results into the
event_push_actions_staging table.
@@ -185,7 +188,7 @@ class BulkPushRuleEvaluator:
count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event, context)
actions_by_user = {}
actions_by_user = {} # type: Dict[str, List[Union[dict, str]]]
room_members = await self.store.get_joined_users_from_context(event, context)
@@ -198,7 +201,7 @@ class BulkPushRuleEvaluator:
event, len(room_members), sender_power_level, power_levels
)
condition_cache = {}
condition_cache = {} # type: Dict[str, bool]
for uid, rules in rules_by_user.items():
if event.sender == uid:
@@ -249,7 +252,13 @@ class BulkPushRuleEvaluator:
)
def _condition_checker(evaluator, conditions, uid, display_name, cache):
def _condition_checker(
evaluator: PushRuleEvaluatorForEvent,
conditions: List[dict],
uid: str,
display_name: str,
cache: Dict[str, bool],
) -> bool:
for cond in conditions:
_id = cond.get("_id", None)
if _id:
@@ -277,15 +286,19 @@ class RulesForRoom:
"""
def __init__(
self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
self,
hs: "HomeServer",
room_id: str,
rules_for_room_cache: LruCache,
room_push_rule_cache_metrics: CacheMetric,
):
"""
Args:
hs (HomeServer)
room_id (str)
hs: The HomeServer object.
room_id: The room ID.
rules_for_room_cache: The cache object that caches these
RoomsForUser objects.
room_push_rule_cache_metrics (CacheMetric)
room_push_rule_cache_metrics: The metrics object
"""
self.room_id = room_id
self.is_mine_id = hs.is_mine_id
@@ -294,8 +307,10 @@ class RulesForRoom:
self.linearizer = Linearizer(name="rules_for_room")
self.member_map = {} # event_id -> (user_id, state)
self.rules_by_user = {} # user_id -> rules
# event_id -> (user_id, state)
self.member_map = {} # type: Dict[str, Tuple[str, str]]
# user_id -> rules
self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]]
# The last state group we updated the caches for. If the state_group of
# a new event comes along, we know that we can just return the cached
@@ -315,7 +330,7 @@ class RulesForRoom:
# calculate push for)
# These never need to be invalidated as we will never set up push for
# them.
self.uninteresting_user_set = set()
self.uninteresting_user_set = set() # type: Set[str]
# We need to be clever on the invalidating caches callbacks, as
# otherwise the invalidation callback holds a reference to the object,
@@ -325,7 +340,9 @@ class RulesForRoom:
# to self around in the callback.
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
async def get_rules(self, event, context):
async def get_rules(
self, event: EventBase, context: EventContext
) -> Dict[str, List[Dict[str, dict]]]:
"""Given an event context return the rules for all users who are
currently in the room.
"""
@@ -356,6 +373,8 @@ class RulesForRoom:
else:
current_state_ids = await context.get_current_state_ids()
push_rules_delta_state_cache_metric.inc_misses()
# Ensure the state IDs exist.
assert current_state_ids is not None
push_rules_state_size_counter.inc(len(current_state_ids))
@@ -420,18 +439,23 @@ class RulesForRoom:
return ret_rules_by_user
async def _update_rules_with_member_event_ids(
self, ret_rules_by_user, member_event_ids, state_group, event
):
self,
ret_rules_by_user: Dict[str, list],
member_event_ids: Dict[str, str],
state_group: Optional[int],
event: EventBase,
) -> None:
"""Update the partially filled rules_by_user dict by fetching rules for
any newly joined users in the `member_event_ids` list.
Args:
ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
ret_rules_by_user: Partially filled dict of push rules. Gets
updated with any new rules.
member_event_ids (dict): Dict of user id to event id for membership events
member_event_ids: Dict of user id to event id for membership events
that have happened since the last time we filled rules_by_user
state_group: The state group we are currently computing push rules
for. Used when updating the cache.
event: The event we are currently computing push rules for.
"""
sequence = self.sequence
@@ -449,19 +473,19 @@ class RulesForRoom:
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
user_ids = {
joined_user_ids = {
user_id
for user_id, membership in members.values()
if membership == Membership.JOIN
}
logger.debug("Joined: %r", user_ids)
logger.debug("Joined: %r", joined_user_ids)
# Previously we only considered users with pushers or read receipts in that
# room. We can't do this anymore because we use push actions to calculate unread
# counts, which don't rely on the user having pushers or sent a read receipt into
# the room. Therefore we just need to filter for local users here.
user_ids = list(filter(self.is_mine_id, user_ids))
user_ids = list(filter(self.is_mine_id, joined_user_ids))
rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb
@@ -473,7 +497,7 @@ class RulesForRoom:
self.update_cache(sequence, members, ret_rules_by_user, state_group)
def invalidate_all(self):
def invalidate_all(self) -> None:
# Note: Don't hand this function directly to an invalidation callback
# as it keeps a reference to self and will stop this instance from being
# GC'd if it gets dropped from the rules_to_user cache. Instead use
@@ -485,7 +509,7 @@ class RulesForRoom:
self.rules_by_user = {}
push_rules_invalidation_counter.inc()
def update_cache(self, sequence, members, rules_by_user, state_group):
def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
if sequence == self.sequence:
self.member_map.update(members)
self.rules_by_user = rules_by_user
@@ -506,7 +530,7 @@ class _Invalidation:
cache = attr.ib(type=LruCache)
room_id = attr.ib(type=str)
def __call__(self):
def __call__(self) -> None:
rules = self.cache.get(self.room_id, None, update_metrics=False)
if rules:
rules.invalidate_all()
+15 -8
View File
@@ -14,24 +14,27 @@
# limitations under the License.
import copy
from typing import Any, Dict, List, Optional
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
from synapse.types import UserID
def format_push_rules_for_user(user, ruleslist):
def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, list]]:
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
# We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(ruleslist)
rules = {"global": {}, "device": {}}
rules = {
"global": {},
"device": {},
} # type: Dict[str, Dict[str, List[Dict[str, Any]]]]
rules["global"] = _add_empty_priority_class_arrays(rules["global"])
for r in ruleslist:
rulearray = None
template_name = _priority_class_to_template_name(r["priority_class"])
# Remove internal stuff.
@@ -57,13 +60,13 @@ def format_push_rules_for_user(user, ruleslist):
return rules
def _add_empty_priority_class_arrays(d):
def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]:
for pc in PRIORITY_CLASS_MAP.keys():
d[pc] = []
return d
def _rule_to_template(rule):
def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
unscoped_rule_id = None
if "rule_id" in rule:
unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
@@ -82,6 +85,10 @@ def _rule_to_template(rule):
return None
templaterule = {"actions": rule["actions"]}
templaterule["pattern"] = thecond["pattern"]
else:
# This should not be reached unless this function is not kept in sync
# with PRIORITY_CLASS_INVERSE_MAP.
raise ValueError("Unexpected template_name: %s" % (template_name,))
if unscoped_rule_id:
templaterule["rule_id"] = unscoped_rule_id
@@ -90,9 +97,9 @@ def _rule_to_template(rule):
return templaterule
def _rule_id_from_namespaced(in_rule_id):
def _rule_id_from_namespaced(in_rule_id: str) -> str:
return in_rule_id.split("/")[-1]
def _priority_class_to_template_name(pc):
def _priority_class_to_template_name(pc: int) -> str:
return PRIORITY_CLASS_INVERSE_MAP[pc]
+29 -19
View File
@@ -15,8 +15,14 @@
import logging
import re
from typing import TYPE_CHECKING, Dict, Iterable, Optional
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import StateMap
if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -28,25 +34,29 @@ ALL_ALONE = "Empty Room"
async def calculate_room_name(
store,
room_state_ids,
user_id,
fallback_to_members=True,
fallback_to_single_member=True,
):
store: "DataStore",
room_state_ids: StateMap[str],
user_id: str,
fallback_to_members: bool = True,
fallback_to_single_member: bool = True,
) -> Optional[str]:
"""
Works out a user-facing name for the given room as per Matrix
spec recommendations.
Does not yet support internationalisation.
Args:
room_state: Dictionary of the room's state
store: The data store to query.
room_state_ids: Dictionary of the room's state IDs.
user_id: The ID of the user to whom the room name is being presented
fallback_to_members: If False, return None instead of generating a name
based on the room's members if the room has no
title or aliases.
fallback_to_single_member: If False, return None instead of generating a
name based on the user who invited this user to the room if the room
has no title or aliases.
Returns:
(string or None) A human readable name for the room.
A human readable name for the room, if possible.
"""
# does it have a name?
if (EventTypes.Name, "") in room_state_ids:
@@ -97,7 +107,7 @@ async def calculate_room_name(
name_from_member_event(inviter_member_event),
)
else:
return
return None
else:
return "Room Invite"
@@ -150,19 +160,19 @@ async def calculate_room_name(
else:
return ALL_ALONE
elif len(other_members) == 1 and not fallback_to_single_member:
return
else:
return descriptor_from_member_events(other_members)
return None
return descriptor_from_member_events(other_members)
def descriptor_from_member_events(member_events):
def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
"""Get a description of the room based on the member events.
Args:
member_events (Iterable[FrozenEvent])
member_events: The events of a room.
Returns:
str
The room description
"""
member_events = list(member_events)
@@ -183,7 +193,7 @@ def descriptor_from_member_events(member_events):
)
def name_from_member_event(member_event):
def name_from_member_event(member_event: EventBase) -> str:
if (
member_event.content
and "displayname" in member_event.content
@@ -193,12 +203,12 @@ def name_from_member_event(member_event):
return member_event.state_key
def _state_as_two_level_dict(state):
ret = {}
def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]:
ret = {} # type: Dict[str, Dict[str, str]]
for k, v in state.items():
ret.setdefault(k[0], {})[k[1]] = v
return ret
def _looks_like_an_alias(string):
def _looks_like_an_alias(string: str) -> bool:
return ALIAS_RE.match(string) is not None
+22 -6
View File
@@ -30,22 +30,30 @@ IS_GLOB = re.compile(r"[\?\*\[\]]")
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def _room_member_count(ev, condition, room_member_count):
def _room_member_count(
ev: EventBase, condition: Dict[str, Any], room_member_count: int
) -> bool:
return _test_ineq_condition(condition, room_member_count)
def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
def _sender_notification_permission(
ev: EventBase,
condition: Dict[str, Any],
sender_power_level: int,
power_levels: Dict[str, Union[int, Dict[str, int]]],
) -> bool:
notif_level_key = condition.get("key")
if notif_level_key is None:
return False
notif_levels = power_levels.get("notifications", {})
assert isinstance(notif_levels, dict)
room_notif_level = notif_levels.get(notif_level_key, 50)
return sender_power_level >= room_notif_level
def _test_ineq_condition(condition, number):
def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
if "is" not in condition:
return False
m = INEQUALITY_EXPR.match(condition["is"])
@@ -110,7 +118,7 @@ class PushRuleEvaluatorForEvent:
event: EventBase,
room_member_count: int,
sender_power_level: int,
power_levels: dict,
power_levels: Dict[str, Union[int, Dict[str, int]]],
):
self._event = event
self._room_member_count = room_member_count
@@ -120,7 +128,9 @@ class PushRuleEvaluatorForEvent:
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)
def matches(self, condition: dict, user_id: str, display_name: str) -> bool:
def matches(
self, condition: Dict[str, Any], user_id: str, display_name: str
) -> bool:
if condition["kind"] == "event_match":
return self._event_match(condition, user_id)
elif condition["kind"] == "contains_display_name":
@@ -261,7 +271,13 @@ def _re_word_boundary(r: str) -> str:
return r"(^|\W)%s(\W|$)" % (r,)
def _flatten_dict(d, prefix=[], result=None):
def _flatten_dict(
d: Union[EventBase, dict],
prefix: Optional[List[str]] = None,
result: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
if prefix is None:
prefix = []
if result is None:
result = {}
for key, value in d.items():
+32 -15
View File
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -25,13 +26,17 @@ from synapse.http.servlet import (
parse_list_from_args,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
)
from synapse.storage.databases.main.room import RoomSortOrder
from synapse.types import RoomAlias, RoomID, UserID, create_requester
from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -45,12 +50,14 @@ class ShutdownRoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.room_shutdown_handler = hs.get_room_shutdown_handler()
async def on_POST(self, request, room_id):
async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -86,13 +93,15 @@ class DeleteRoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.room_shutdown_handler = hs.get_room_shutdown_handler()
self.pagination_handler = hs.get_pagination_handler()
async def on_POST(self, request, room_id):
async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -146,12 +155,12 @@ class ListRoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -236,19 +245,24 @@ class RoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room_with_stats(room_id)
if not ret:
raise NotFoundError("Room not found")
return 200, ret
members = await self.store.get_users_in_room(room_id)
ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
return (200, ret)
class RoomMembersRestServlet(RestServlet):
@@ -258,12 +272,14 @@ class RoomMembersRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id)
@@ -280,14 +296,16 @@ class JoinRoomAliasServlet(RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
async def on_POST(self, request, room_identifier):
async def on_POST(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -312,7 +330,6 @@ class JoinRoomAliasServlet(RestServlet):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
+19 -6
View File
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import Awaitable, Callable, Dict, Optional
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -30,6 +30,9 @@ from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -42,7 +45,7 @@ class LoginRestServlet(RestServlet):
JWT_TYPE_DEPRECATED = "m.login.jwt"
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -105,22 +108,27 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows}
async def on_POST(self, request: SynapseRequest):
self._address_ratelimiter.ratelimit(request.getClientIP())
login_submission = parse_json_object_from_request(request)
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
appservice = self.auth.get_appservice_by_req(request)
if appservice.is_rate_limited():
self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_appservice_login(login_submission, appservice)
elif self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_token_login(login_submission)
else:
self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_other_login(login_submission)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@@ -159,7 +167,9 @@ class LoginRestServlet(RestServlet):
if not appservice.is_interested_in_user(qualified_user_id):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
return await self._complete_login(qualified_user_id, login_submission)
return await self._complete_login(
qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
)
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""Handle non-token/saml/jwt logins
@@ -194,6 +204,7 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict,
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
ratelimit: bool = True,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -208,6 +219,7 @@ class LoginRestServlet(RestServlet):
callback: Callback function to run after login.
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
ratelimit: Whether to ratelimit the login request.
Returns:
result: Dictionary of account information after successful login.
@@ -216,7 +228,8 @@ class LoginRestServlet(RestServlet):
# Before we actually log them in we check if they've already logged in
# too often. This happens here rather than before as we don't
# necessarily know the user before now.
self._account_ratelimiter.ratelimit(user_id.lower())
if ratelimit:
self._account_ratelimiter.ratelimit(user_id.lower())
if create_non_existent_users:
canonical_uid = await self.auth_handler.check_user_exists(user_id)
+6 -10
View File
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import os
import shutil
@@ -21,6 +20,7 @@ from typing import Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
from synapse.util.async_helpers import maybe_awaitable
from ._base import FileInfo, Responder
from .media_storage import FileResponder
@@ -91,16 +91,14 @@ class StorageProviderWrapper(StorageProvider):
if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
result = self.backend.store_file(path, file_info)
if inspect.isawaitable(result):
return await result
return await maybe_awaitable(self.backend.store_file(path, file_info))
else:
# TODO: Handle errors.
async def store():
try:
result = self.backend.store_file(path, file_info)
if inspect.isawaitable(result):
return await result
return await maybe_awaitable(
self.backend.store_file(path, file_info)
)
except Exception:
logger.exception("Error storing file")
@@ -110,9 +108,7 @@ class StorageProviderWrapper(StorageProvider):
async def fetch(self, path, file_info):
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
result = self.backend.fetch(path, file_info)
if inspect.isawaitable(result):
return await result
return await maybe_awaitable(self.backend.fetch(path, file_info))
class FileStorageProviderBackend(StorageProvider):
+1 -1
View File
@@ -632,7 +632,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return StatsHandler(self)
@cache_in_self
def get_spam_checker(self):
def get_spam_checker(self) -> SpamChecker:
return SpamChecker(self)
@cache_in_self
+32
View File
@@ -57,6 +57,38 @@ class DeviceWorkerStore(SQLBaseStore):
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
"""Retrieve number of all devices of given users.
Only returns number of devices that are not marked as hidden.
Args:
user_ids: The IDs of the users which owns devices
Returns:
Number of devices of this users.
"""
def count_devices_by_users_txn(txn, user_ids):
sql = """
SELECT count(*)
FROM devices
WHERE
hidden = '0' AND
"""
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", user_ids
)
txn.execute(sql + clause, args)
return txn.fetchone()[0]
if not user_ids:
return 0
return await self.db_pool.runInteraction(
"count_devices_by_users", count_devices_by_users_txn, user_ids
)
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
+5 -3
View File
@@ -15,10 +15,12 @@
# limitations under the License.
import collections
import inspect
import logging
from contextlib import contextmanager
from typing import (
Any,
Awaitable,
Callable,
Dict,
Hashable,
@@ -542,11 +544,11 @@ class DoneAwaitable:
raise StopIteration(self.value)
def maybe_awaitable(value):
def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
"""Convert a value to an awaitable if not already an awaitable.
"""
if hasattr(value, "__await__"):
if inspect.isawaitable(value):
assert isinstance(value, Awaitable)
return value
return DoneAwaitable(value)
+2 -5
View File
@@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -105,10 +105,7 @@ class Signal:
async def do(observer):
try:
result = observer(*args, **kwargs)
if inspect.isawaitable(result):
result = await result
return result
return await maybe_awaitable(observer(*args, **kwargs))
except Exception as e:
logger.warning(
"%s signal observer %s failed: %r", self.name, observer, e,
+2 -2
View File
@@ -273,7 +273,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
spam_checker = self.hs.get_spam_checker()
class AllowAll:
def check_username_for_spam(self, user_profile):
async def check_username_for_spam(self, user_profile):
# Allow all users.
return False
@@ -286,7 +286,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that filters all users.
class BlockAll:
def check_username_for_spam(self, user_profile):
async def check_username_for_spam(self, user_profile):
# All users are spammy.
return True
+46 -27
View File
@@ -18,30 +18,35 @@ import logging
from io import StringIO
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
from synapse.logging.context import LoggingContext, LoggingContextFilter
from tests.logging import LoggerCleanupMixin
from tests.unittest import TestCase
class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
def setUp(self):
self.output = StringIO()
def get_log_line(self):
# One log message, with a single trailing newline.
data = self.output.getvalue()
logs = data.splitlines()
self.assertEqual(len(logs), 1)
self.assertEqual(data.count("\n"), 1)
return json.loads(logs[0])
def test_terse_json_output(self):
"""
The Terse JSON formatter converts log messages to JSON.
"""
output = StringIO()
handler = logging.StreamHandler(output)
handler = logging.StreamHandler(self.output)
handler.setFormatter(TerseJsonFormatter())
logger = self.get_logger(handler)
logger.info("Hello there, %s!", "wally")
# One log message, with a single trailing newline.
data = output.getvalue()
logs = data.splitlines()
self.assertEqual(len(logs), 1)
self.assertEqual(data.count("\n"), 1)
log = json.loads(logs[0])
log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
@@ -57,9 +62,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"""
Additional information can be included in the structured logging.
"""
output = StringIO()
handler = logging.StreamHandler(output)
handler = logging.StreamHandler(self.output)
handler.setFormatter(TerseJsonFormatter())
logger = self.get_logger(handler)
@@ -67,12 +70,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True}
)
# One log message, with a single trailing newline.
data = output.getvalue()
logs = data.splitlines()
self.assertEqual(len(logs), 1)
self.assertEqual(data.count("\n"), 1)
log = json.loads(logs[0])
log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
@@ -96,20 +94,13 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"""
The Terse JSON formatter converts log messages to JSON.
"""
output = StringIO()
handler = logging.StreamHandler(output)
handler = logging.StreamHandler(self.output)
handler.setFormatter(JsonFormatter())
logger = self.get_logger(handler)
logger.info("Hello there, %s!", "wally")
# One log message, with a single trailing newline.
data = output.getvalue()
logs = data.splitlines()
self.assertEqual(len(logs), 1)
self.assertEqual(data.count("\n"), 1)
log = json.loads(logs[0])
log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
@@ -119,3 +110,31 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
]
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")
def test_with_context(self):
"""
The logging context should be added to the JSON response.
"""
handler = logging.StreamHandler(self.output)
handler.setFormatter(JsonFormatter())
handler.addFilter(LoggingContextFilter(request=""))
logger = self.get_logger(handler)
with LoggingContext() as context_one:
context_one.request = "test"
logger.info("Hello there, %s!", "wally")
log = self.get_log_line()
# The terse logger should give us these keys.
expected_log_keys = [
"log",
"level",
"namespace",
"request",
"scope",
]
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")
self.assertEqual(log["request"], "test")
self.assertIsNone(log["scope"])
+34
View File
@@ -1084,6 +1084,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertIn("canonical_alias", channel.json_body)
self.assertIn("joined_members", channel.json_body)
self.assertIn("joined_local_members", channel.json_body)
self.assertIn("joined_local_devices", channel.json_body)
self.assertIn("version", channel.json_body)
self.assertIn("creator", channel.json_body)
self.assertIn("encryption", channel.json_body)
@@ -1096,6 +1097,39 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_id_1, channel.json_body["room_id"])
def test_single_room_devices(self):
"""Test that `joined_local_devices` can be requested correctly"""
room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["joined_local_devices"])
# Have another user join the room
user_1 = self.register_user("foo", "pass")
user_tok_1 = self.login("foo", "pass")
self.helper.join(room_id_1, user_1, tok=user_tok_1)
url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(2, channel.json_body["joined_local_devices"])
# leave room
self.helper.leave(room_id_1, self.admin_user, tok=self.admin_user_tok)
self.helper.leave(room_id_1, user_1, tok=user_tok_1)
url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["joined_local_devices"])
def test_room_members(self):
"""Test that room members can be requested correctly"""
# Create two test rooms
+16
View File
@@ -26,6 +26,7 @@ from mock import Mock
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import account
from synapse.types import JsonDict, RoomAlias, UserID
@@ -625,6 +626,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
servlets = [
admin.register_servlets,
profile.register_servlets,
room.register_servlets,
]
@@ -703,6 +705,20 @@ class RoomJoinRatelimitTestCase(RoomBase):
request, channel = self.make_request("POST", path % room_id, {})
self.assertEquals(channel.code, 200)
@unittest.override_config(
{
"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}},
"auto_join_rooms": ["#room:red", "#room2:red", "#room3:red", "#room4:red"],
"autocreate_auto_join_rooms": True,
},
)
def test_autojoin_rooms(self):
user_id = self.register_user("testuser", "password")
# Check that the new user successfully joined the four rooms
rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 4)
class RoomMessagesTestCase(RoomBase):
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
+26
View File
@@ -79,6 +79,32 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res["device2"],
)
@defer.inlineCallbacks
def test_count_devices_by_users(self):
yield defer.ensureDeferred(
self.store.store_device("user_id", "device1", "display_name 1")
)
yield defer.ensureDeferred(
self.store.store_device("user_id", "device2", "display_name 2")
)
yield defer.ensureDeferred(
self.store.store_device("user_id2", "device3", "display_name 3")
)
res = yield defer.ensureDeferred(self.store.count_devices_by_users())
self.assertEqual(0, res)
res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
self.assertEqual(0, res)
res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
self.assertEqual(2, res)
res = yield defer.ensureDeferred(
self.store.count_devices_by_users(["user_id", "user_id2"])
)
self.assertEqual(3, res)
@defer.inlineCallbacks
def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]