Merge commit 'f14428b25' into anoa/dinsic_release_1_31_0
This commit is contained in:
@@ -7,6 +7,8 @@ jobs:
|
||||
- checkout
|
||||
- docker_prepare
|
||||
- run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
|
||||
# for release builds, we want to get the amd64 image out asap, so first
|
||||
# we do an amd64-only build, before following up with a multiarch build.
|
||||
- docker_build:
|
||||
tag: -t matrixdotorg/synapse:${CIRCLE_TAG}
|
||||
platforms: linux/amd64
|
||||
@@ -21,9 +23,8 @@ jobs:
|
||||
- checkout
|
||||
- docker_prepare
|
||||
- run: docker login --username $DOCKER_HUB_USERNAME --password $DOCKER_HUB_PASSWORD
|
||||
- docker_build:
|
||||
tag: -t matrixdotorg/synapse:latest
|
||||
platforms: linux/amd64
|
||||
# for `latest`, we don't want the arm images to disappear, so don't update the tag
|
||||
# until all of the platforms are built.
|
||||
- docker_build:
|
||||
tag: -t matrixdotorg/synapse:latest
|
||||
platforms: linux/amd64,linux/arm/v7,linux/arm64
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
Add number of local devices to Room Details Admin API. Contributed by @dklimpel.
|
||||
@@ -0,0 +1 @@
|
||||
Spam-checkers may now define their methods as `async`.
|
||||
@@ -0,0 +1 @@
|
||||
Add type hints to push module.
|
||||
@@ -0,0 +1 @@
|
||||
Don't publish `latest` docker image until all archs are built.
|
||||
@@ -0,0 +1 @@
|
||||
Improve structured logging tests.
|
||||
@@ -0,0 +1 @@
|
||||
Fix occasional deadlock when handling SIGHUP.
|
||||
@@ -0,0 +1 @@
|
||||
Fix login API to not ratelimit application services that have ratelimiting disabled.
|
||||
@@ -0,0 +1 @@
|
||||
Fix bug where we ratelimited auto joining of rooms on registration (using `auto_join_rooms` config).
|
||||
+13
-11
@@ -87,7 +87,7 @@ GET /_synapse/admin/v1/rooms
|
||||
|
||||
Response:
|
||||
|
||||
```
|
||||
```jsonc
|
||||
{
|
||||
"rooms": [
|
||||
{
|
||||
@@ -139,7 +139,7 @@ GET /_synapse/admin/v1/rooms?search_term=TWIM
|
||||
|
||||
Response:
|
||||
|
||||
```
|
||||
```json
|
||||
{
|
||||
"rooms": [
|
||||
{
|
||||
@@ -174,7 +174,7 @@ GET /_synapse/admin/v1/rooms?order_by=size
|
||||
|
||||
Response:
|
||||
|
||||
```
|
||||
```jsonc
|
||||
{
|
||||
"rooms": [
|
||||
{
|
||||
@@ -230,14 +230,14 @@ GET /_synapse/admin/v1/rooms?order_by=size&from=100
|
||||
|
||||
Response:
|
||||
|
||||
```
|
||||
```jsonc
|
||||
{
|
||||
"rooms": [
|
||||
{
|
||||
"room_id": "!mscvqgqpHYjBGDxNym:matrix.org",
|
||||
"name": "Music Theory",
|
||||
"canonical_alias": "#musictheory:matrix.org",
|
||||
"joined_members": 127
|
||||
"joined_members": 127,
|
||||
"joined_local_members": 2,
|
||||
"version": "1",
|
||||
"creator": "@foo:matrix.org",
|
||||
@@ -254,7 +254,7 @@ Response:
|
||||
"room_id": "!twcBhHVdZlQWuuxBhN:termina.org.uk",
|
||||
"name": "weechat-matrix",
|
||||
"canonical_alias": "#weechat-matrix:termina.org.uk",
|
||||
"joined_members": 137
|
||||
"joined_members": 137,
|
||||
"joined_local_members": 20,
|
||||
"version": "4",
|
||||
"creator": "@foo:termina.org.uk",
|
||||
@@ -289,6 +289,7 @@ The following fields are possible in the JSON response body:
|
||||
* `canonical_alias` - The canonical (main) alias address of the room.
|
||||
* `joined_members` - How many users are currently in the room.
|
||||
* `joined_local_members` - How many local users are currently in the room.
|
||||
* `joined_local_devices` - How many local devices are currently in the room.
|
||||
* `version` - The version of the room as a string.
|
||||
* `creator` - The `user_id` of the room creator.
|
||||
* `encryption` - Algorithm of end-to-end encryption of messages. Is `null` if encryption is not active.
|
||||
@@ -311,15 +312,16 @@ GET /_synapse/admin/v1/rooms/<room_id>
|
||||
|
||||
Response:
|
||||
|
||||
```
|
||||
```json
|
||||
{
|
||||
"room_id": "!mscvqgqpHYjBGDxNym:matrix.org",
|
||||
"name": "Music Theory",
|
||||
"avatar": "mxc://matrix.org/AQDaVFlbkQoErdOgqWRgiGSV",
|
||||
"topic": "Theory, Composition, Notation, Analysis",
|
||||
"canonical_alias": "#musictheory:matrix.org",
|
||||
"joined_members": 127
|
||||
"joined_members": 127,
|
||||
"joined_local_members": 2,
|
||||
"joined_local_devices": 2,
|
||||
"version": "1",
|
||||
"creator": "@foo:matrix.org",
|
||||
"encryption": null,
|
||||
@@ -353,13 +355,13 @@ GET /_synapse/admin/v1/rooms/<room_id>/members
|
||||
|
||||
Response:
|
||||
|
||||
```
|
||||
```json
|
||||
{
|
||||
"members": [
|
||||
"@foo:matrix.org",
|
||||
"@bar:matrix.org",
|
||||
"@foobar:matrix.org
|
||||
],
|
||||
"@foobar:matrix.org"
|
||||
],
|
||||
"total": 3
|
||||
}
|
||||
```
|
||||
|
||||
+13
-6
@@ -22,6 +22,8 @@ well as some specific methods:
|
||||
* `user_may_create_room`
|
||||
* `user_may_create_room_alias`
|
||||
* `user_may_publish_room`
|
||||
* `check_username_for_spam`
|
||||
* `check_registration_for_spam`
|
||||
|
||||
The details of the each of these methods (as well as their inputs and outputs)
|
||||
are documented in the `synapse.events.spamcheck.SpamChecker` class.
|
||||
@@ -32,28 +34,33 @@ call back into the homeserver internals.
|
||||
### Example
|
||||
|
||||
```python
|
||||
from synapse.spam_checker_api import RegistrationBehaviour
|
||||
|
||||
class ExampleSpamChecker:
|
||||
def __init__(self, config, api):
|
||||
self.config = config
|
||||
self.api = api
|
||||
|
||||
def check_event_for_spam(self, foo):
|
||||
async def check_event_for_spam(self, foo):
|
||||
return False # allow all events
|
||||
|
||||
def user_may_invite(self, inviter_userid, invitee_userid, room_id):
|
||||
async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
|
||||
return True # allow all invites
|
||||
|
||||
def user_may_create_room(self, userid):
|
||||
async def user_may_create_room(self, userid):
|
||||
return True # allow all room creations
|
||||
|
||||
def user_may_create_room_alias(self, userid, room_alias):
|
||||
async def user_may_create_room_alias(self, userid, room_alias):
|
||||
return True # allow all room aliases
|
||||
|
||||
def user_may_publish_room(self, userid, room_id):
|
||||
async def user_may_publish_room(self, userid, room_id):
|
||||
return True # allow publishing of all rooms
|
||||
|
||||
def check_username_for_spam(self, user_profile):
|
||||
async def check_username_for_spam(self, user_profile):
|
||||
return False # allow all usernames
|
||||
|
||||
async def check_registration_for_spam(self, email_threepid, username, request_info):
|
||||
return RegistrationBehaviour.ALLOW # allow all registrations
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
@@ -56,12 +56,7 @@ files =
|
||||
synapse/metrics,
|
||||
synapse/module_api,
|
||||
synapse/notifier.py,
|
||||
synapse/push/emailpusher.py,
|
||||
synapse/push/httppusher.py,
|
||||
synapse/push/mailer.py,
|
||||
synapse/push/pusher.py,
|
||||
synapse/push/pusherpool.py,
|
||||
synapse/push/push_rule_evaluator.py,
|
||||
synapse/push,
|
||||
synapse/replication,
|
||||
synapse/rest,
|
||||
synapse/server.py,
|
||||
|
||||
@@ -31,6 +31,8 @@ class SynapsePlugin(Plugin):
|
||||
) -> Optional[Callable[[MethodSigContext], CallableType]]:
|
||||
if fullname.startswith(
|
||||
"synapse.util.caches.descriptors._CachedFunction.__call__"
|
||||
) or fullname.startswith(
|
||||
"synapse.util.caches.descriptors._LruCachedFunction.__call__"
|
||||
):
|
||||
return cached_function_method_signature
|
||||
return None
|
||||
|
||||
+3
-1
@@ -31,7 +31,9 @@ from synapse.api.errors import (
|
||||
MissingClientTokenError,
|
||||
)
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.events import EventBase
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging import opentracing as opentracing
|
||||
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||
from synapse.types import StateMap, UserID
|
||||
@@ -479,7 +481,7 @@ class Auth:
|
||||
now = self.hs.get_clock().time_msec()
|
||||
return now < expiry
|
||||
|
||||
def get_appservice_by_req(self, request):
|
||||
def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
|
||||
token = self.get_access_token_from_request(request)
|
||||
service = self.store.get_app_service_by_token(token)
|
||||
if not service:
|
||||
|
||||
@@ -245,6 +245,8 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
|
||||
# Set up the SIGHUP machinery.
|
||||
if hasattr(signal, "SIGHUP"):
|
||||
|
||||
reactor = hs.get_reactor()
|
||||
|
||||
@wrap_as_background_process("sighup")
|
||||
def handle_sighup(*args, **kwargs):
|
||||
# Tell systemd our state, if we're using it. This will silently fail if
|
||||
@@ -260,7 +262,9 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
|
||||
# is so that we're in a sane state, e.g. flushing the logs may fail
|
||||
# if the sighup happens in the middle of writing a log entry.
|
||||
def run_sighup(*args, **kwargs):
|
||||
hs.get_clock().call_later(0, handle_sighup, *args, **kwargs)
|
||||
# `callFromThread` should be "signal safe" as well as thread
|
||||
# safe.
|
||||
reactor.callFromThread(handle_sighup, *args, **kwargs)
|
||||
|
||||
signal.signal(signal.SIGHUP, run_sighup)
|
||||
|
||||
|
||||
+45
-27
@@ -15,10 +15,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from synapse.spam_checker_api import RegistrationBehaviour
|
||||
from synapse.types import Collection
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import synapse.events
|
||||
@@ -39,7 +40,9 @@ class SpamChecker:
|
||||
else:
|
||||
self.spam_checkers.append(module(config=config))
|
||||
|
||||
def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
|
||||
async def check_event_for_spam(
|
||||
self, event: "synapse.events.EventBase"
|
||||
) -> Union[bool, str]:
|
||||
"""Checks if a given event is considered "spammy" by this server.
|
||||
|
||||
If the server considers an event spammy, then it will be rejected if
|
||||
@@ -50,15 +53,16 @@ class SpamChecker:
|
||||
event: the event to be checked
|
||||
|
||||
Returns:
|
||||
True if the event is spammy.
|
||||
True or a string if the event is spammy. If a string is returned it
|
||||
will be used as the error message returned to the user.
|
||||
"""
|
||||
for spam_checker in self.spam_checkers:
|
||||
if spam_checker.check_event_for_spam(event):
|
||||
if await maybe_awaitable(spam_checker.check_event_for_spam(event)):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def user_may_invite(
|
||||
async def user_may_invite(
|
||||
self,
|
||||
inviter_userid: str,
|
||||
invitee_userid: str,
|
||||
@@ -91,21 +95,22 @@ class SpamChecker:
|
||||
"""
|
||||
for spam_checker in self.spam_checkers:
|
||||
if (
|
||||
spam_checker.user_may_invite(
|
||||
inviter_userid,
|
||||
invitee_userid,
|
||||
third_party_invite,
|
||||
room_id,
|
||||
new_room,
|
||||
published_room,
|
||||
)
|
||||
is False
|
||||
await maybe_awaitable(
|
||||
spam_checker.user_may_invite(
|
||||
inviter_userid,
|
||||
invitee_userid,
|
||||
third_party_invite,
|
||||
room_id,
|
||||
new_room,
|
||||
published_room,
|
||||
)
|
||||
) is False
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def user_may_create_room(
|
||||
async def user_may_create_room(
|
||||
self,
|
||||
userid: str,
|
||||
invite_list: List[str],
|
||||
@@ -130,16 +135,17 @@ class SpamChecker:
|
||||
"""
|
||||
for spam_checker in self.spam_checkers:
|
||||
if (
|
||||
spam_checker.user_may_create_room(
|
||||
userid, invite_list, third_party_invite_list, cloning
|
||||
)
|
||||
is False
|
||||
await maybe_awaitable(
|
||||
spam_checker.user_may_create_room(
|
||||
userid, invite_list, third_party_invite_list, cloning
|
||||
)
|
||||
) is False
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
|
||||
async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
|
||||
"""Checks if a given user may create a room alias
|
||||
|
||||
If this method returns false, the association request will be rejected.
|
||||
@@ -152,12 +158,17 @@ class SpamChecker:
|
||||
True if the user may create a room alias, otherwise False
|
||||
"""
|
||||
for spam_checker in self.spam_checkers:
|
||||
if spam_checker.user_may_create_room_alias(userid, room_alias) is False:
|
||||
if (
|
||||
await maybe_awaitable(
|
||||
spam_checker.user_may_create_room_alias(userid, room_alias)
|
||||
)
|
||||
is False
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def user_may_publish_room(self, userid: str, room_id: str) -> bool:
|
||||
async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
|
||||
"""Checks if a given user may publish a room to the directory
|
||||
|
||||
If this method returns false, the publish request will be rejected.
|
||||
@@ -170,7 +181,12 @@ class SpamChecker:
|
||||
True if the user may publish the room, otherwise False
|
||||
"""
|
||||
for spam_checker in self.spam_checkers:
|
||||
if spam_checker.user_may_publish_room(userid, room_id) is False:
|
||||
if (
|
||||
await maybe_awaitable(
|
||||
spam_checker.user_may_publish_room(userid, room_id)
|
||||
)
|
||||
is False
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -194,7 +210,7 @@ class SpamChecker:
|
||||
|
||||
return True
|
||||
|
||||
def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
|
||||
async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
|
||||
"""Checks if a user ID or display name are considered "spammy" by this server.
|
||||
|
||||
If the server considers a username spammy, then it will not be included in
|
||||
@@ -216,12 +232,12 @@ class SpamChecker:
|
||||
if checker:
|
||||
# Make a copy of the user profile object to ensure the spam checker
|
||||
# cannot modify it.
|
||||
if checker(user_profile.copy()):
|
||||
if await maybe_awaitable(checker(user_profile.copy())):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def check_registration_for_spam(
|
||||
async def check_registration_for_spam(
|
||||
self,
|
||||
email_threepid: Optional[dict],
|
||||
username: Optional[str],
|
||||
@@ -244,7 +260,9 @@ class SpamChecker:
|
||||
# spam checker
|
||||
checker = getattr(spam_checker, "check_registration_for_spam", None)
|
||||
if checker:
|
||||
behaviour = checker(email_threepid, username, request_info)
|
||||
behaviour = await maybe_awaitable(
|
||||
checker(email_threepid, username, request_info)
|
||||
)
|
||||
assert isinstance(behaviour, RegistrationBehaviour)
|
||||
if behaviour != RegistrationBehaviour.ALLOW:
|
||||
return behaviour
|
||||
|
||||
@@ -78,6 +78,7 @@ class FederationBase:
|
||||
|
||||
ctx = current_context()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def callback(_, pdu: EventBase):
|
||||
with PreserveLoggingContext(ctx):
|
||||
if not check_event_content_hash(pdu):
|
||||
@@ -105,7 +106,11 @@ class FederationBase:
|
||||
)
|
||||
return redacted_event
|
||||
|
||||
if self.spam_checker.check_event_for_spam(pdu):
|
||||
result = yield defer.ensureDeferred(
|
||||
self.spam_checker.check_event_for_spam(pdu)
|
||||
)
|
||||
|
||||
if result:
|
||||
logger.warning(
|
||||
"Event contains spam, redacting %s: %s",
|
||||
pdu.event_id,
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
import unicodedata
|
||||
@@ -22,6 +21,7 @@ import urllib.parse
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
@@ -58,6 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.types import JsonDict, Requester, UserID
|
||||
from synapse.util import stringutils as stringutils
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.threepids import canonicalise_email
|
||||
|
||||
@@ -851,7 +852,7 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
async def validate_login(
|
||||
self, login_submission: Dict[str, Any], ratelimit: bool = False,
|
||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
|
||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
||||
"""Authenticates the user for the /login API
|
||||
|
||||
Also used by the user-interactive auth flow to validate auth types which don't
|
||||
@@ -994,7 +995,7 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
async def _validate_userid_login(
|
||||
self, username: str, login_submission: Dict[str, Any],
|
||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
|
||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
||||
"""Helper for validate_login
|
||||
|
||||
Handles login, once we've mapped 3pids onto userids
|
||||
@@ -1072,7 +1073,7 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
async def check_password_provider_3pid(
|
||||
self, medium: str, address: str, password: str
|
||||
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
|
||||
) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
||||
"""Check if a password provider is able to validate a thirdparty login
|
||||
|
||||
Args:
|
||||
@@ -1628,6 +1629,6 @@ class PasswordProvider:
|
||||
|
||||
# This might return an awaitable, if it does block the log out
|
||||
# until it completes.
|
||||
result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
await maybe_awaitable(
|
||||
g(user_id=user_id, device_id=device_id, access_token=access_token,)
|
||||
)
|
||||
|
||||
@@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler):
|
||||
403, "You must be in the room to create an alias for it"
|
||||
)
|
||||
|
||||
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
|
||||
if not await self.spam_checker.user_may_create_room_alias(
|
||||
user_id, room_alias
|
||||
):
|
||||
raise AuthError(403, "This user is not permitted to create this alias")
|
||||
|
||||
if not self.config.is_alias_creation_allowed(
|
||||
@@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler):
|
||||
"""
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
if not self.spam_checker.user_may_publish_room(user_id, room_id):
|
||||
if not await self.spam_checker.user_may_publish_room(user_id, room_id):
|
||||
raise AuthError(
|
||||
403, "This user is not permitted to publish rooms to the room list"
|
||||
)
|
||||
|
||||
@@ -1666,7 +1666,7 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
is_published = await self.store.is_room_published(event.room_id)
|
||||
|
||||
if not self.spam_checker.user_may_invite(
|
||||
if not await self.spam_checker.user_may_invite(
|
||||
event.sender,
|
||||
event.state_key,
|
||||
None,
|
||||
|
||||
@@ -746,7 +746,7 @@ class EventCreationHandler:
|
||||
event.sender,
|
||||
)
|
||||
|
||||
spam_error = self.spam_checker.check_event_for_spam(event)
|
||||
spam_error = await self.spam_checker.check_event_for_spam(event)
|
||||
if spam_error:
|
||||
if not isinstance(spam_error, str):
|
||||
spam_error = "Spam is not permitted here"
|
||||
|
||||
@@ -18,7 +18,6 @@ from typing import List, Tuple
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.handlers._base import BaseHandler
|
||||
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -98,10 +97,8 @@ class ReceiptsHandler(BaseHandler):
|
||||
|
||||
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
|
||||
# Note that the min here shouldn't be relied upon to be accurate.
|
||||
await maybe_awaitable(
|
||||
self.hs.get_pusherpool().on_new_receipts(
|
||||
min_batch_id, max_batch_id, affected_room_ids
|
||||
)
|
||||
await self.hs.get_pusherpool().on_new_receipts(
|
||||
min_batch_id, max_batch_id, affected_room_ids
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@@ -202,7 +202,7 @@ class RegistrationHandler(BaseHandler):
|
||||
"""
|
||||
self.check_registration_ratelimit(address)
|
||||
|
||||
result = self.spam_checker.check_registration_for_spam(
|
||||
result = await self.spam_checker.check_registration_for_spam(
|
||||
threepid, localpart, user_agent_ips or [],
|
||||
)
|
||||
|
||||
|
||||
@@ -368,7 +368,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
else:
|
||||
is_requester_admin = await self.auth.is_server_admin(requester.user)
|
||||
|
||||
if not is_requester_admin and not self.spam_checker.user_may_create_room(
|
||||
if not is_requester_admin and not await self.spam_checker.user_may_create_room(
|
||||
user_id, invite_list=[], third_party_invite_list=[], cloning=True
|
||||
):
|
||||
raise SynapseError(403, "You are not permitted to create rooms")
|
||||
@@ -452,6 +452,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
invite_list=[],
|
||||
initial_state=initial_state,
|
||||
creation_content=creation_content,
|
||||
ratelimit=False,
|
||||
)
|
||||
|
||||
# Transfer membership events
|
||||
@@ -623,7 +624,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
invite_list = config.get("invite", [])
|
||||
invite_3pid_list = config.get("invite_3pid", [])
|
||||
|
||||
if not is_requester_admin and not self.spam_checker.user_may_create_room(
|
||||
if not is_requester_admin and not await self.spam_checker.user_may_create_room(
|
||||
user_id,
|
||||
invite_list=invite_list,
|
||||
third_party_invite_list=invite_3pid_list,
|
||||
@@ -753,6 +754,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
room_alias=room_alias,
|
||||
power_level_content_override=power_level_content_override,
|
||||
creator_join_profile=creator_join_profile,
|
||||
ratelimit=ratelimit,
|
||||
)
|
||||
|
||||
if "name" in config:
|
||||
@@ -858,6 +860,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
room_alias: Optional[RoomAlias] = None,
|
||||
power_level_content_override: Optional[JsonDict] = None,
|
||||
creator_join_profile: Optional[JsonDict] = None,
|
||||
ratelimit: bool = True,
|
||||
) -> int:
|
||||
"""Sends the initial events into a new room.
|
||||
|
||||
@@ -904,7 +907,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
creator.user,
|
||||
room_id,
|
||||
"join",
|
||||
ratelimit=False,
|
||||
ratelimit=ratelimit,
|
||||
content=creator_join_profile,
|
||||
new_room=True,
|
||||
)
|
||||
|
||||
@@ -246,7 +246,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||
|
||||
# Only rate-limit if the user actually joined the room, otherwise we'll end
|
||||
# up blocking profile updates.
|
||||
if newly_joined:
|
||||
if newly_joined and ratelimit:
|
||||
time_now_s = self.clock.time()
|
||||
(
|
||||
allowed,
|
||||
@@ -456,7 +456,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||
|
||||
is_published = await self.store.is_room_published(room_id)
|
||||
|
||||
if not self.spam_checker.user_may_invite(
|
||||
if not await self.spam_checker.user_may_invite(
|
||||
requester.user.to_string(),
|
||||
target.to_string(),
|
||||
third_party_invite=None,
|
||||
@@ -560,17 +560,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||
raise SynapseError(403, "Not allowed to join this room")
|
||||
|
||||
if not is_host_in_room:
|
||||
time_now_s = self.clock.time()
|
||||
(
|
||||
allowed,
|
||||
time_allowed,
|
||||
) = self._join_rate_limiter_remote.can_requester_do_action(requester,)
|
||||
|
||||
if not allowed:
|
||||
raise LimitExceededError(
|
||||
retry_after_ms=int(1000 * (time_allowed - time_now_s))
|
||||
if ratelimit:
|
||||
time_now_s = self.clock.time()
|
||||
(
|
||||
allowed,
|
||||
time_allowed,
|
||||
) = self._join_rate_limiter_remote.can_requester_do_action(
|
||||
requester,
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise LimitExceededError(
|
||||
retry_after_ms=int(1000 * (time_allowed - time_now_s))
|
||||
)
|
||||
|
||||
inviter = await self._get_inviter(target.to_string(), room_id)
|
||||
if inviter and not self.hs.is_mine(inviter):
|
||||
remote_room_hosts.append(inviter.domain)
|
||||
@@ -947,7 +950,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||
|
||||
is_published = await self.store.is_room_published(room_id)
|
||||
|
||||
if not self.spam_checker.user_may_invite(
|
||||
if not await self.spam_checker.user_may_invite(
|
||||
requester.user.to_string(),
|
||||
invitee,
|
||||
third_party_invite={"medium": medium, "address": address},
|
||||
|
||||
@@ -81,11 +81,11 @@ class UserDirectoryHandler(StateDeltasHandler):
|
||||
results = await self.store.search_user_dir(user_id, search_term, limit)
|
||||
|
||||
# Remove any spammy users from the results.
|
||||
results["results"] = [
|
||||
user
|
||||
for user in results["results"]
|
||||
if not self.spam_checker.check_username_for_spam(user)
|
||||
]
|
||||
non_spammy_users = []
|
||||
for user in results["results"]:
|
||||
if not await self.spam_checker.check_username_for_spam(user):
|
||||
non_spammy_users.append(user)
|
||||
results["results"] = non_spammy_users
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
from functools import wraps
|
||||
@@ -25,6 +24,7 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.logging.context import LoggingContext, PreserveLoggingContext
|
||||
from synapse.logging.opentracing import noop_context_manager, start_active_span
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import resource
|
||||
@@ -206,12 +206,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
|
||||
if bg_start_span:
|
||||
ctx = start_active_span(desc, tags={"request_id": context.request})
|
||||
with ctx:
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
|
||||
return result
|
||||
return await maybe_awaitable(func(*args, **kwargs))
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Background process '%s' threw an exception", desc,
|
||||
|
||||
@@ -14,19 +14,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActionGenerator:
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.bulk_evaluator = BulkPushRuleEvaluator(hs)
|
||||
# really we want to get all user ids and all profile tags too,
|
||||
# since we want the actions for each profile tag for every user and
|
||||
@@ -35,6 +38,8 @@ class ActionGenerator:
|
||||
# event stream, so we just run the rules for a client with no profile
|
||||
# tag (ie. we just need all the users).
|
||||
|
||||
async def handle_push_actions_for_event(self, event, context):
|
||||
async def handle_push_actions_for_event(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> None:
|
||||
with Measure(self.clock, "action_for_event_by_user"):
|
||||
await self.bulk_evaluator.action_for_event_by_user(event, context)
|
||||
|
||||
@@ -15,16 +15,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
|
||||
|
||||
|
||||
def list_with_base_rules(rawrules, use_new_defaults=False):
|
||||
def list_with_base_rules(
|
||||
rawrules: List[Dict[str, Any]], use_new_defaults: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Combine the list of rules set by the user with the default push rules
|
||||
|
||||
Args:
|
||||
rawrules(list): The rules the user has modified or set.
|
||||
use_new_defaults(bool): Whether to use the new experimental default rules when
|
||||
rawrules: The rules the user has modified or set.
|
||||
use_new_defaults: Whether to use the new experimental default rules when
|
||||
appending or prepending default rules.
|
||||
|
||||
Returns:
|
||||
@@ -94,7 +97,11 @@ def list_with_base_rules(rawrules, use_new_defaults=False):
|
||||
return ruleslist
|
||||
|
||||
|
||||
def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
|
||||
def make_base_append_rules(
|
||||
kind: str,
|
||||
modified_base_rules: Dict[str, Dict[str, Any]],
|
||||
use_new_defaults: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
rules = []
|
||||
|
||||
if kind == "override":
|
||||
@@ -116,6 +123,7 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
|
||||
rules = copy.deepcopy(rules)
|
||||
for r in rules:
|
||||
# Only modify the actions, keep the conditions the same.
|
||||
assert isinstance(r["rule_id"], str)
|
||||
modified = modified_base_rules.get(r["rule_id"])
|
||||
if modified:
|
||||
r["actions"] = modified["actions"]
|
||||
@@ -123,7 +131,11 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
|
||||
return rules
|
||||
|
||||
|
||||
def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
|
||||
def make_base_prepend_rules(
|
||||
kind: str,
|
||||
modified_base_rules: Dict[str, Dict[str, Any]],
|
||||
use_new_defaults: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
rules = []
|
||||
|
||||
if kind == "override":
|
||||
@@ -133,6 +145,7 @@ def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
|
||||
rules = copy.deepcopy(rules)
|
||||
for r in rules:
|
||||
# Only modify the actions, keep the conditions the same.
|
||||
assert isinstance(r["rule_id"], str)
|
||||
modified = modified_base_rules.get(r["rule_id"])
|
||||
if modified:
|
||||
r["actions"] = modified["actions"]
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
@@ -25,18 +26,18 @@ from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.state import POWER_KEY
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches import register_cache
|
||||
from synapse.util.caches import CacheMetric, register_cache
|
||||
from synapse.util.caches.descriptors import lru_cache
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
|
||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
rules_by_room = {}
|
||||
|
||||
|
||||
push_rules_invalidation_counter = Counter(
|
||||
"synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
|
||||
)
|
||||
@@ -101,7 +102,7 @@ class BulkPushRuleEvaluator:
|
||||
room at once.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
@@ -113,7 +114,9 @@ class BulkPushRuleEvaluator:
|
||||
resizable=False,
|
||||
)
|
||||
|
||||
async def _get_rules_for_event(self, event, context):
|
||||
async def _get_rules_for_event(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""This gets the rules for all users in the room at the time of the event,
|
||||
as well as the push rules for the invitee if the event is an invite.
|
||||
|
||||
@@ -140,11 +143,8 @@ class BulkPushRuleEvaluator:
|
||||
return rules_by_user
|
||||
|
||||
@lru_cache()
|
||||
def _get_rules_for_room(self, room_id):
|
||||
def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
|
||||
"""Get the current RulesForRoom object for the given room id
|
||||
|
||||
Returns:
|
||||
RulesForRoom
|
||||
"""
|
||||
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
|
||||
# before any lookup methods get called on it as otherwise there may be
|
||||
@@ -156,20 +156,21 @@ class BulkPushRuleEvaluator:
|
||||
self.room_push_rule_cache_metrics,
|
||||
)
|
||||
|
||||
async def _get_power_levels_and_sender_level(self, event, context):
|
||||
async def _get_power_levels_and_sender_level(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> Tuple[dict, int]:
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
pl_event_id = prev_state_ids.get(POWER_KEY)
|
||||
if pl_event_id:
|
||||
# fastpath: if there's a power level event, that's all we need, and
|
||||
# not having a power level event is an extreme edge case
|
||||
pl_event = await self.store.get_event(pl_event_id)
|
||||
auth_events = {POWER_KEY: pl_event}
|
||||
auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
|
||||
else:
|
||||
auth_events_ids = self.auth.compute_auth_events(
|
||||
event, prev_state_ids, for_verification=False
|
||||
)
|
||||
auth_events = await self.store.get_events(auth_events_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
|
||||
auth_events_dict = await self.store.get_events(auth_events_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
|
||||
|
||||
sender_level = get_user_power_level(event.sender, auth_events)
|
||||
|
||||
@@ -177,7 +178,9 @@ class BulkPushRuleEvaluator:
|
||||
|
||||
return pl_event.content if pl_event else {}, sender_level
|
||||
|
||||
async def action_for_event_by_user(self, event, context) -> None:
|
||||
async def action_for_event_by_user(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> None:
|
||||
"""Given an event and context, evaluate the push rules, check if the message
|
||||
should increment the unread count, and insert the results into the
|
||||
event_push_actions_staging table.
|
||||
@@ -185,7 +188,7 @@ class BulkPushRuleEvaluator:
|
||||
count_as_unread = _should_count_as_unread(event, context)
|
||||
|
||||
rules_by_user = await self._get_rules_for_event(event, context)
|
||||
actions_by_user = {}
|
||||
actions_by_user = {} # type: Dict[str, List[Union[dict, str]]]
|
||||
|
||||
room_members = await self.store.get_joined_users_from_context(event, context)
|
||||
|
||||
@@ -198,7 +201,7 @@ class BulkPushRuleEvaluator:
|
||||
event, len(room_members), sender_power_level, power_levels
|
||||
)
|
||||
|
||||
condition_cache = {}
|
||||
condition_cache = {} # type: Dict[str, bool]
|
||||
|
||||
for uid, rules in rules_by_user.items():
|
||||
if event.sender == uid:
|
||||
@@ -249,7 +252,13 @@ class BulkPushRuleEvaluator:
|
||||
)
|
||||
|
||||
|
||||
def _condition_checker(evaluator, conditions, uid, display_name, cache):
|
||||
def _condition_checker(
|
||||
evaluator: PushRuleEvaluatorForEvent,
|
||||
conditions: List[dict],
|
||||
uid: str,
|
||||
display_name: str,
|
||||
cache: Dict[str, bool],
|
||||
) -> bool:
|
||||
for cond in conditions:
|
||||
_id = cond.get("_id", None)
|
||||
if _id:
|
||||
@@ -277,15 +286,19 @@ class RulesForRoom:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
|
||||
self,
|
||||
hs: "HomeServer",
|
||||
room_id: str,
|
||||
rules_for_room_cache: LruCache,
|
||||
room_push_rule_cache_metrics: CacheMetric,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hs (HomeServer)
|
||||
room_id (str)
|
||||
hs: The HomeServer object.
|
||||
room_id: The room ID.
|
||||
rules_for_room_cache: The cache object that caches these
|
||||
RoomsForUser objects.
|
||||
room_push_rule_cache_metrics (CacheMetric)
|
||||
room_push_rule_cache_metrics: The metrics object
|
||||
"""
|
||||
self.room_id = room_id
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
@@ -294,8 +307,10 @@ class RulesForRoom:
|
||||
|
||||
self.linearizer = Linearizer(name="rules_for_room")
|
||||
|
||||
self.member_map = {} # event_id -> (user_id, state)
|
||||
self.rules_by_user = {} # user_id -> rules
|
||||
# event_id -> (user_id, state)
|
||||
self.member_map = {} # type: Dict[str, Tuple[str, str]]
|
||||
# user_id -> rules
|
||||
self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]]
|
||||
|
||||
# The last state group we updated the caches for. If the state_group of
|
||||
# a new event comes along, we know that we can just return the cached
|
||||
@@ -315,7 +330,7 @@ class RulesForRoom:
|
||||
# calculate push for)
|
||||
# These never need to be invalidated as we will never set up push for
|
||||
# them.
|
||||
self.uninteresting_user_set = set()
|
||||
self.uninteresting_user_set = set() # type: Set[str]
|
||||
|
||||
# We need to be clever on the invalidating caches callbacks, as
|
||||
# otherwise the invalidation callback holds a reference to the object,
|
||||
@@ -325,7 +340,9 @@ class RulesForRoom:
|
||||
# to self around in the callback.
|
||||
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
|
||||
|
||||
async def get_rules(self, event, context):
|
||||
async def get_rules(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> Dict[str, List[Dict[str, dict]]]:
|
||||
"""Given an event context return the rules for all users who are
|
||||
currently in the room.
|
||||
"""
|
||||
@@ -356,6 +373,8 @@ class RulesForRoom:
|
||||
else:
|
||||
current_state_ids = await context.get_current_state_ids()
|
||||
push_rules_delta_state_cache_metric.inc_misses()
|
||||
# Ensure the state IDs exist.
|
||||
assert current_state_ids is not None
|
||||
|
||||
push_rules_state_size_counter.inc(len(current_state_ids))
|
||||
|
||||
@@ -420,18 +439,23 @@ class RulesForRoom:
|
||||
return ret_rules_by_user
|
||||
|
||||
async def _update_rules_with_member_event_ids(
|
||||
self, ret_rules_by_user, member_event_ids, state_group, event
|
||||
):
|
||||
self,
|
||||
ret_rules_by_user: Dict[str, list],
|
||||
member_event_ids: Dict[str, str],
|
||||
state_group: Optional[int],
|
||||
event: EventBase,
|
||||
) -> None:
|
||||
"""Update the partially filled rules_by_user dict by fetching rules for
|
||||
any newly joined users in the `member_event_ids` list.
|
||||
|
||||
Args:
|
||||
ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
|
||||
ret_rules_by_user: Partially filled dict of push rules. Gets
|
||||
updated with any new rules.
|
||||
member_event_ids (dict): Dict of user id to event id for membership events
|
||||
member_event_ids: Dict of user id to event id for membership events
|
||||
that have happened since the last time we filled rules_by_user
|
||||
state_group: The state group we are currently computing push rules
|
||||
for. Used when updating the cache.
|
||||
event: The event we are currently computing push rules for.
|
||||
"""
|
||||
sequence = self.sequence
|
||||
|
||||
@@ -449,19 +473,19 @@ class RulesForRoom:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug("Found members %r: %r", self.room_id, members.values())
|
||||
|
||||
user_ids = {
|
||||
joined_user_ids = {
|
||||
user_id
|
||||
for user_id, membership in members.values()
|
||||
if membership == Membership.JOIN
|
||||
}
|
||||
|
||||
logger.debug("Joined: %r", user_ids)
|
||||
logger.debug("Joined: %r", joined_user_ids)
|
||||
|
||||
# Previously we only considered users with pushers or read receipts in that
|
||||
# room. We can't do this anymore because we use push actions to calculate unread
|
||||
# counts, which don't rely on the user having pushers or sent a read receipt into
|
||||
# the room. Therefore we just need to filter for local users here.
|
||||
user_ids = list(filter(self.is_mine_id, user_ids))
|
||||
user_ids = list(filter(self.is_mine_id, joined_user_ids))
|
||||
|
||||
rules_by_user = await self.store.bulk_get_push_rules(
|
||||
user_ids, on_invalidate=self.invalidate_all_cb
|
||||
@@ -473,7 +497,7 @@ class RulesForRoom:
|
||||
|
||||
self.update_cache(sequence, members, ret_rules_by_user, state_group)
|
||||
|
||||
def invalidate_all(self):
|
||||
def invalidate_all(self) -> None:
|
||||
# Note: Don't hand this function directly to an invalidation callback
|
||||
# as it keeps a reference to self and will stop this instance from being
|
||||
# GC'd if it gets dropped from the rules_to_user cache. Instead use
|
||||
@@ -485,7 +509,7 @@ class RulesForRoom:
|
||||
self.rules_by_user = {}
|
||||
push_rules_invalidation_counter.inc()
|
||||
|
||||
def update_cache(self, sequence, members, rules_by_user, state_group):
|
||||
def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
|
||||
if sequence == self.sequence:
|
||||
self.member_map.update(members)
|
||||
self.rules_by_user = rules_by_user
|
||||
@@ -506,7 +530,7 @@ class _Invalidation:
|
||||
cache = attr.ib(type=LruCache)
|
||||
room_id = attr.ib(type=str)
|
||||
|
||||
def __call__(self):
|
||||
def __call__(self) -> None:
|
||||
rules = self.cache.get(self.room_id, None, update_metrics=False)
|
||||
if rules:
|
||||
rules.invalidate_all()
|
||||
|
||||
@@ -14,24 +14,27 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
|
||||
from synapse.types import UserID
|
||||
|
||||
|
||||
def format_push_rules_for_user(user, ruleslist):
|
||||
def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, list]]:
|
||||
"""Converts a list of rawrules and a enabled map into nested dictionaries
|
||||
to match the Matrix client-server format for push rules"""
|
||||
|
||||
# We're going to be mutating this a lot, so do a deep copy
|
||||
ruleslist = copy.deepcopy(ruleslist)
|
||||
|
||||
rules = {"global": {}, "device": {}}
|
||||
rules = {
|
||||
"global": {},
|
||||
"device": {},
|
||||
} # type: Dict[str, Dict[str, List[Dict[str, Any]]]]
|
||||
|
||||
rules["global"] = _add_empty_priority_class_arrays(rules["global"])
|
||||
|
||||
for r in ruleslist:
|
||||
rulearray = None
|
||||
|
||||
template_name = _priority_class_to_template_name(r["priority_class"])
|
||||
|
||||
# Remove internal stuff.
|
||||
@@ -57,13 +60,13 @@ def format_push_rules_for_user(user, ruleslist):
|
||||
return rules
|
||||
|
||||
|
||||
def _add_empty_priority_class_arrays(d):
|
||||
def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]:
|
||||
for pc in PRIORITY_CLASS_MAP.keys():
|
||||
d[pc] = []
|
||||
return d
|
||||
|
||||
|
||||
def _rule_to_template(rule):
|
||||
def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
unscoped_rule_id = None
|
||||
if "rule_id" in rule:
|
||||
unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
|
||||
@@ -82,6 +85,10 @@ def _rule_to_template(rule):
|
||||
return None
|
||||
templaterule = {"actions": rule["actions"]}
|
||||
templaterule["pattern"] = thecond["pattern"]
|
||||
else:
|
||||
# This should not be reached unless this function is not kept in sync
|
||||
# with PRIORITY_CLASS_INVERSE_MAP.
|
||||
raise ValueError("Unexpected template_name: %s" % (template_name,))
|
||||
|
||||
if unscoped_rule_id:
|
||||
templaterule["rule_id"] = unscoped_rule_id
|
||||
@@ -90,9 +97,9 @@ def _rule_to_template(rule):
|
||||
return templaterule
|
||||
|
||||
|
||||
def _rule_id_from_namespaced(in_rule_id):
|
||||
def _rule_id_from_namespaced(in_rule_id: str) -> str:
|
||||
return in_rule_id.split("/")[-1]
|
||||
|
||||
|
||||
def _priority_class_to_template_name(pc):
|
||||
def _priority_class_to_template_name(pc: int) -> str:
|
||||
return PRIORITY_CLASS_INVERSE_MAP[pc]
|
||||
|
||||
@@ -15,8 +15,14 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, Optional
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import StateMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,25 +34,29 @@ ALL_ALONE = "Empty Room"
|
||||
|
||||
|
||||
async def calculate_room_name(
|
||||
store,
|
||||
room_state_ids,
|
||||
user_id,
|
||||
fallback_to_members=True,
|
||||
fallback_to_single_member=True,
|
||||
):
|
||||
store: "DataStore",
|
||||
room_state_ids: StateMap[str],
|
||||
user_id: str,
|
||||
fallback_to_members: bool = True,
|
||||
fallback_to_single_member: bool = True,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Works out a user-facing name for the given room as per Matrix
|
||||
spec recommendations.
|
||||
Does not yet support internationalisation.
|
||||
Args:
|
||||
room_state: Dictionary of the room's state
|
||||
store: The data store to query.
|
||||
room_state_ids: Dictionary of the room's state IDs.
|
||||
user_id: The ID of the user to whom the room name is being presented
|
||||
fallback_to_members: If False, return None instead of generating a name
|
||||
based on the room's members if the room has no
|
||||
title or aliases.
|
||||
fallback_to_single_member: If False, return None instead of generating a
|
||||
name based on the user who invited this user to the room if the room
|
||||
has no title or aliases.
|
||||
|
||||
Returns:
|
||||
(string or None) A human readable name for the room.
|
||||
A human readable name for the room, if possible.
|
||||
"""
|
||||
# does it have a name?
|
||||
if (EventTypes.Name, "") in room_state_ids:
|
||||
@@ -97,7 +107,7 @@ async def calculate_room_name(
|
||||
name_from_member_event(inviter_member_event),
|
||||
)
|
||||
else:
|
||||
return
|
||||
return None
|
||||
else:
|
||||
return "Room Invite"
|
||||
|
||||
@@ -150,19 +160,19 @@ async def calculate_room_name(
|
||||
else:
|
||||
return ALL_ALONE
|
||||
elif len(other_members) == 1 and not fallback_to_single_member:
|
||||
return
|
||||
else:
|
||||
return descriptor_from_member_events(other_members)
|
||||
return None
|
||||
|
||||
return descriptor_from_member_events(other_members)
|
||||
|
||||
|
||||
def descriptor_from_member_events(member_events):
|
||||
def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
|
||||
"""Get a description of the room based on the member events.
|
||||
|
||||
Args:
|
||||
member_events (Iterable[FrozenEvent])
|
||||
member_events: The events of a room.
|
||||
|
||||
Returns:
|
||||
str
|
||||
The room description
|
||||
"""
|
||||
|
||||
member_events = list(member_events)
|
||||
@@ -183,7 +193,7 @@ def descriptor_from_member_events(member_events):
|
||||
)
|
||||
|
||||
|
||||
def name_from_member_event(member_event):
|
||||
def name_from_member_event(member_event: EventBase) -> str:
|
||||
if (
|
||||
member_event.content
|
||||
and "displayname" in member_event.content
|
||||
@@ -193,12 +203,12 @@ def name_from_member_event(member_event):
|
||||
return member_event.state_key
|
||||
|
||||
|
||||
def _state_as_two_level_dict(state):
|
||||
ret = {}
|
||||
def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]:
|
||||
ret = {} # type: Dict[str, Dict[str, str]]
|
||||
for k, v in state.items():
|
||||
ret.setdefault(k[0], {})[k[1]] = v
|
||||
return ret
|
||||
|
||||
|
||||
def _looks_like_an_alias(string):
|
||||
def _looks_like_an_alias(string: str) -> bool:
|
||||
return ALIAS_RE.match(string) is not None
|
||||
|
||||
@@ -30,22 +30,30 @@ IS_GLOB = re.compile(r"[\?\*\[\]]")
|
||||
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
|
||||
|
||||
|
||||
def _room_member_count(ev, condition, room_member_count):
|
||||
def _room_member_count(
|
||||
ev: EventBase, condition: Dict[str, Any], room_member_count: int
|
||||
) -> bool:
|
||||
return _test_ineq_condition(condition, room_member_count)
|
||||
|
||||
|
||||
def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
|
||||
def _sender_notification_permission(
|
||||
ev: EventBase,
|
||||
condition: Dict[str, Any],
|
||||
sender_power_level: int,
|
||||
power_levels: Dict[str, Union[int, Dict[str, int]]],
|
||||
) -> bool:
|
||||
notif_level_key = condition.get("key")
|
||||
if notif_level_key is None:
|
||||
return False
|
||||
|
||||
notif_levels = power_levels.get("notifications", {})
|
||||
assert isinstance(notif_levels, dict)
|
||||
room_notif_level = notif_levels.get(notif_level_key, 50)
|
||||
|
||||
return sender_power_level >= room_notif_level
|
||||
|
||||
|
||||
def _test_ineq_condition(condition, number):
|
||||
def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
|
||||
if "is" not in condition:
|
||||
return False
|
||||
m = INEQUALITY_EXPR.match(condition["is"])
|
||||
@@ -110,7 +118,7 @@ class PushRuleEvaluatorForEvent:
|
||||
event: EventBase,
|
||||
room_member_count: int,
|
||||
sender_power_level: int,
|
||||
power_levels: dict,
|
||||
power_levels: Dict[str, Union[int, Dict[str, int]]],
|
||||
):
|
||||
self._event = event
|
||||
self._room_member_count = room_member_count
|
||||
@@ -120,7 +128,9 @@ class PushRuleEvaluatorForEvent:
|
||||
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
|
||||
self._value_cache = _flatten_dict(event)
|
||||
|
||||
def matches(self, condition: dict, user_id: str, display_name: str) -> bool:
|
||||
def matches(
|
||||
self, condition: Dict[str, Any], user_id: str, display_name: str
|
||||
) -> bool:
|
||||
if condition["kind"] == "event_match":
|
||||
return self._event_match(condition, user_id)
|
||||
elif condition["kind"] == "contains_display_name":
|
||||
@@ -261,7 +271,13 @@ def _re_word_boundary(r: str) -> str:
|
||||
return r"(^|\W)%s(\W|$)" % (r,)
|
||||
|
||||
|
||||
def _flatten_dict(d, prefix=[], result=None):
|
||||
def _flatten_dict(
|
||||
d: Union[EventBase, dict],
|
||||
prefix: Optional[List[str]] = None,
|
||||
result: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, str]:
|
||||
if prefix is None:
|
||||
prefix = []
|
||||
if result is None:
|
||||
result = {}
|
||||
for key, value in d.items():
|
||||
|
||||
+32
-15
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.constants import EventTypes, JoinRules
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
@@ -25,13 +26,17 @@ from synapse.http.servlet import (
|
||||
parse_list_from_args,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.admin._base import (
|
||||
admin_patterns,
|
||||
assert_requester_is_admin,
|
||||
assert_user_is_admin,
|
||||
)
|
||||
from synapse.storage.databases.main.room import RoomSortOrder
|
||||
from synapse.types import RoomAlias, RoomID, UserID, create_requester
|
||||
from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -45,12 +50,14 @@ class ShutdownRoomRestServlet(RestServlet):
|
||||
|
||||
PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.room_shutdown_handler = hs.get_room_shutdown_handler()
|
||||
|
||||
async def on_POST(self, request, room_id):
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
@@ -86,13 +93,15 @@ class DeleteRoomRestServlet(RestServlet):
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.room_shutdown_handler = hs.get_room_shutdown_handler()
|
||||
self.pagination_handler = hs.get_pagination_handler()
|
||||
|
||||
async def on_POST(self, request, room_id):
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
@@ -146,12 +155,12 @@ class ListRoomRestServlet(RestServlet):
|
||||
|
||||
PATTERNS = admin_patterns("/rooms$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
|
||||
async def on_GET(self, request):
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
@@ -236,19 +245,24 @@ class RoomRestServlet(RestServlet):
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request, room_id):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
ret = await self.store.get_room_with_stats(room_id)
|
||||
if not ret:
|
||||
raise NotFoundError("Room not found")
|
||||
|
||||
return 200, ret
|
||||
members = await self.store.get_users_in_room(room_id)
|
||||
ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
|
||||
|
||||
return (200, ret)
|
||||
|
||||
|
||||
class RoomMembersRestServlet(RestServlet):
|
||||
@@ -258,12 +272,14 @@ class RoomMembersRestServlet(RestServlet):
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_GET(self, request, room_id):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
ret = await self.store.get_room(room_id)
|
||||
@@ -280,14 +296,16 @@ class JoinRoomAliasServlet(RestServlet):
|
||||
|
||||
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.room_member_handler = hs.get_room_member_handler()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
|
||||
async def on_POST(self, request, room_identifier):
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, room_identifier: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
@@ -312,7 +330,6 @@ class JoinRoomAliasServlet(RestServlet):
|
||||
handler = self.room_member_handler
|
||||
room_alias = RoomAlias.from_string(room_identifier)
|
||||
room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
|
||||
room_id = room_id.to_string()
|
||||
else:
|
||||
raise SynapseError(
|
||||
400, "%s was not legal room ID or room alias" % (room_identifier,)
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Awaitable, Callable, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
|
||||
|
||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
@@ -30,6 +30,9 @@ from synapse.rest.client.v2_alpha._base import client_patterns
|
||||
from synapse.rest.well_known import WellKnownBuilder
|
||||
from synapse.types import JsonDict, UserID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -42,7 +45,7 @@ class LoginRestServlet(RestServlet):
|
||||
JWT_TYPE_DEPRECATED = "m.login.jwt"
|
||||
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
|
||||
@@ -105,22 +108,27 @@ class LoginRestServlet(RestServlet):
|
||||
return 200, {"flows": flows}
|
||||
|
||||
async def on_POST(self, request: SynapseRequest):
|
||||
self._address_ratelimiter.ratelimit(request.getClientIP())
|
||||
|
||||
login_submission = parse_json_object_from_request(request)
|
||||
|
||||
try:
|
||||
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
|
||||
appservice = self.auth.get_appservice_by_req(request)
|
||||
|
||||
if appservice.is_rate_limited():
|
||||
self._address_ratelimiter.ratelimit(request.getClientIP())
|
||||
|
||||
result = await self._do_appservice_login(login_submission, appservice)
|
||||
elif self.jwt_enabled and (
|
||||
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
|
||||
):
|
||||
self._address_ratelimiter.ratelimit(request.getClientIP())
|
||||
result = await self._do_jwt_login(login_submission)
|
||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||
self._address_ratelimiter.ratelimit(request.getClientIP())
|
||||
result = await self._do_token_login(login_submission)
|
||||
else:
|
||||
self._address_ratelimiter.ratelimit(request.getClientIP())
|
||||
result = await self._do_other_login(login_submission)
|
||||
except KeyError:
|
||||
raise SynapseError(400, "Missing JSON keys.")
|
||||
@@ -159,7 +167,9 @@ class LoginRestServlet(RestServlet):
|
||||
if not appservice.is_interested_in_user(qualified_user_id):
|
||||
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
|
||||
|
||||
return await self._complete_login(qualified_user_id, login_submission)
|
||||
return await self._complete_login(
|
||||
qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
|
||||
)
|
||||
|
||||
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
||||
"""Handle non-token/saml/jwt logins
|
||||
@@ -194,6 +204,7 @@ class LoginRestServlet(RestServlet):
|
||||
login_submission: JsonDict,
|
||||
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
|
||||
create_non_existent_users: bool = False,
|
||||
ratelimit: bool = True,
|
||||
) -> Dict[str, str]:
|
||||
"""Called when we've successfully authed the user and now need to
|
||||
actually login them in (e.g. create devices). This gets called on
|
||||
@@ -208,6 +219,7 @@ class LoginRestServlet(RestServlet):
|
||||
callback: Callback function to run after login.
|
||||
create_non_existent_users: Whether to create the user if they don't
|
||||
exist. Defaults to False.
|
||||
ratelimit: Whether to ratelimit the login request.
|
||||
|
||||
Returns:
|
||||
result: Dictionary of account information after successful login.
|
||||
@@ -216,7 +228,8 @@ class LoginRestServlet(RestServlet):
|
||||
# Before we actually log them in we check if they've already logged in
|
||||
# too often. This happens here rather than before as we don't
|
||||
# necessarily know the user before now.
|
||||
self._account_ratelimiter.ratelimit(user_id.lower())
|
||||
if ratelimit:
|
||||
self._account_ratelimiter.ratelimit(user_id.lower())
|
||||
|
||||
if create_non_existent_users:
|
||||
canonical_uid = await self.auth_handler.check_user_exists(user_id)
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
@@ -21,6 +20,7 @@ from typing import Optional
|
||||
|
||||
from synapse.config._base import Config
|
||||
from synapse.logging.context import defer_to_thread, run_in_background
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
|
||||
from ._base import FileInfo, Responder
|
||||
from .media_storage import FileResponder
|
||||
@@ -91,16 +91,14 @@ class StorageProviderWrapper(StorageProvider):
|
||||
if self.store_synchronous:
|
||||
# store_file is supposed to return an Awaitable, but guard
|
||||
# against improper implementations.
|
||||
result = self.backend.store_file(path, file_info)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return await maybe_awaitable(self.backend.store_file(path, file_info))
|
||||
else:
|
||||
# TODO: Handle errors.
|
||||
async def store():
|
||||
try:
|
||||
result = self.backend.store_file(path, file_info)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return await maybe_awaitable(
|
||||
self.backend.store_file(path, file_info)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error storing file")
|
||||
|
||||
@@ -110,9 +108,7 @@ class StorageProviderWrapper(StorageProvider):
|
||||
async def fetch(self, path, file_info):
|
||||
# store_file is supposed to return an Awaitable, but guard
|
||||
# against improper implementations.
|
||||
result = self.backend.fetch(path, file_info)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return await maybe_awaitable(self.backend.fetch(path, file_info))
|
||||
|
||||
|
||||
class FileStorageProviderBackend(StorageProvider):
|
||||
|
||||
+1
-1
@@ -632,7 +632,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
return StatsHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_spam_checker(self):
|
||||
def get_spam_checker(self) -> SpamChecker:
|
||||
return SpamChecker(self)
|
||||
|
||||
@cache_in_self
|
||||
|
||||
@@ -57,6 +57,38 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
|
||||
)
|
||||
|
||||
async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
|
||||
"""Retrieve number of all devices of given users.
|
||||
Only returns number of devices that are not marked as hidden.
|
||||
|
||||
Args:
|
||||
user_ids: The IDs of the users which owns devices
|
||||
Returns:
|
||||
Number of devices of this users.
|
||||
"""
|
||||
|
||||
def count_devices_by_users_txn(txn, user_ids):
|
||||
sql = """
|
||||
SELECT count(*)
|
||||
FROM devices
|
||||
WHERE
|
||||
hidden = '0' AND
|
||||
"""
|
||||
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "user_id", user_ids
|
||||
)
|
||||
|
||||
txn.execute(sql + clause, args)
|
||||
return txn.fetchone()[0]
|
||||
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"count_devices_by_users", count_devices_by_users_txn, user_ids
|
||||
)
|
||||
|
||||
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
|
||||
"""Retrieve a device. Only returns devices that are not marked as
|
||||
hidden.
|
||||
|
||||
@@ -15,10 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import inspect
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Hashable,
|
||||
@@ -542,11 +544,11 @@ class DoneAwaitable:
|
||||
raise StopIteration(self.value)
|
||||
|
||||
|
||||
def maybe_awaitable(value):
|
||||
def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
|
||||
"""Convert a value to an awaitable if not already an awaitable.
|
||||
"""
|
||||
|
||||
if hasattr(value, "__await__"):
|
||||
if inspect.isawaitable(value):
|
||||
assert isinstance(value, Awaitable)
|
||||
return value
|
||||
|
||||
return DoneAwaitable(value)
|
||||
|
||||
@@ -12,13 +12,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -105,10 +105,7 @@ class Signal:
|
||||
|
||||
async def do(observer):
|
||||
try:
|
||||
result = observer(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
return result
|
||||
return await maybe_awaitable(observer(*args, **kwargs))
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"%s signal observer %s failed: %r", self.name, observer, e,
|
||||
|
||||
@@ -273,7 +273,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||
spam_checker = self.hs.get_spam_checker()
|
||||
|
||||
class AllowAll:
|
||||
def check_username_for_spam(self, user_profile):
|
||||
async def check_username_for_spam(self, user_profile):
|
||||
# Allow all users.
|
||||
return False
|
||||
|
||||
@@ -286,7 +286,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Configure a spam checker that filters all users.
|
||||
class BlockAll:
|
||||
def check_username_for_spam(self, user_profile):
|
||||
async def check_username_for_spam(self, user_profile):
|
||||
# All users are spammy.
|
||||
return True
|
||||
|
||||
|
||||
@@ -18,30 +18,35 @@ import logging
|
||||
from io import StringIO
|
||||
|
||||
from synapse.logging._terse_json import JsonFormatter, TerseJsonFormatter
|
||||
from synapse.logging.context import LoggingContext, LoggingContextFilter
|
||||
|
||||
from tests.logging import LoggerCleanupMixin
|
||||
from tests.unittest import TestCase
|
||||
|
||||
|
||||
class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
|
||||
def setUp(self):
|
||||
self.output = StringIO()
|
||||
|
||||
def get_log_line(self):
|
||||
# One log message, with a single trailing newline.
|
||||
data = self.output.getvalue()
|
||||
logs = data.splitlines()
|
||||
self.assertEqual(len(logs), 1)
|
||||
self.assertEqual(data.count("\n"), 1)
|
||||
return json.loads(logs[0])
|
||||
|
||||
def test_terse_json_output(self):
|
||||
"""
|
||||
The Terse JSON formatter converts log messages to JSON.
|
||||
"""
|
||||
output = StringIO()
|
||||
|
||||
handler = logging.StreamHandler(output)
|
||||
handler = logging.StreamHandler(self.output)
|
||||
handler.setFormatter(TerseJsonFormatter())
|
||||
logger = self.get_logger(handler)
|
||||
|
||||
logger.info("Hello there, %s!", "wally")
|
||||
|
||||
# One log message, with a single trailing newline.
|
||||
data = output.getvalue()
|
||||
logs = data.splitlines()
|
||||
self.assertEqual(len(logs), 1)
|
||||
self.assertEqual(data.count("\n"), 1)
|
||||
log = json.loads(logs[0])
|
||||
log = self.get_log_line()
|
||||
|
||||
# The terse logger should give us these keys.
|
||||
expected_log_keys = [
|
||||
@@ -57,9 +62,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
|
||||
"""
|
||||
Additional information can be included in the structured logging.
|
||||
"""
|
||||
output = StringIO()
|
||||
|
||||
handler = logging.StreamHandler(output)
|
||||
handler = logging.StreamHandler(self.output)
|
||||
handler.setFormatter(TerseJsonFormatter())
|
||||
logger = self.get_logger(handler)
|
||||
|
||||
@@ -67,12 +70,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
|
||||
"Hello there, %s!", "wally", extra={"foo": "bar", "int": 3, "bool": True}
|
||||
)
|
||||
|
||||
# One log message, with a single trailing newline.
|
||||
data = output.getvalue()
|
||||
logs = data.splitlines()
|
||||
self.assertEqual(len(logs), 1)
|
||||
self.assertEqual(data.count("\n"), 1)
|
||||
log = json.loads(logs[0])
|
||||
log = self.get_log_line()
|
||||
|
||||
# The terse logger should give us these keys.
|
||||
expected_log_keys = [
|
||||
@@ -96,20 +94,13 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
|
||||
"""
|
||||
The Terse JSON formatter converts log messages to JSON.
|
||||
"""
|
||||
output = StringIO()
|
||||
|
||||
handler = logging.StreamHandler(output)
|
||||
handler = logging.StreamHandler(self.output)
|
||||
handler.setFormatter(JsonFormatter())
|
||||
logger = self.get_logger(handler)
|
||||
|
||||
logger.info("Hello there, %s!", "wally")
|
||||
|
||||
# One log message, with a single trailing newline.
|
||||
data = output.getvalue()
|
||||
logs = data.splitlines()
|
||||
self.assertEqual(len(logs), 1)
|
||||
self.assertEqual(data.count("\n"), 1)
|
||||
log = json.loads(logs[0])
|
||||
log = self.get_log_line()
|
||||
|
||||
# The terse logger should give us these keys.
|
||||
expected_log_keys = [
|
||||
@@ -119,3 +110,31 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
|
||||
]
|
||||
self.assertCountEqual(log.keys(), expected_log_keys)
|
||||
self.assertEqual(log["log"], "Hello there, wally!")
|
||||
|
||||
def test_with_context(self):
|
||||
"""
|
||||
The logging context should be added to the JSON response.
|
||||
"""
|
||||
handler = logging.StreamHandler(self.output)
|
||||
handler.setFormatter(JsonFormatter())
|
||||
handler.addFilter(LoggingContextFilter(request=""))
|
||||
logger = self.get_logger(handler)
|
||||
|
||||
with LoggingContext() as context_one:
|
||||
context_one.request = "test"
|
||||
logger.info("Hello there, %s!", "wally")
|
||||
|
||||
log = self.get_log_line()
|
||||
|
||||
# The terse logger should give us these keys.
|
||||
expected_log_keys = [
|
||||
"log",
|
||||
"level",
|
||||
"namespace",
|
||||
"request",
|
||||
"scope",
|
||||
]
|
||||
self.assertCountEqual(log.keys(), expected_log_keys)
|
||||
self.assertEqual(log["log"], "Hello there, wally!")
|
||||
self.assertEqual(log["request"], "test")
|
||||
self.assertIsNone(log["scope"])
|
||||
|
||||
@@ -1084,6 +1084,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
|
||||
self.assertIn("canonical_alias", channel.json_body)
|
||||
self.assertIn("joined_members", channel.json_body)
|
||||
self.assertIn("joined_local_members", channel.json_body)
|
||||
self.assertIn("joined_local_devices", channel.json_body)
|
||||
self.assertIn("version", channel.json_body)
|
||||
self.assertIn("creator", channel.json_body)
|
||||
self.assertIn("encryption", channel.json_body)
|
||||
@@ -1096,6 +1097,39 @@ class RoomTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual(room_id_1, channel.json_body["room_id"])
|
||||
|
||||
def test_single_room_devices(self):
|
||||
"""Test that `joined_local_devices` can be requested correctly"""
|
||||
room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
|
||||
|
||||
url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
|
||||
request, channel = self.make_request(
|
||||
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(1, channel.json_body["joined_local_devices"])
|
||||
|
||||
# Have another user join the room
|
||||
user_1 = self.register_user("foo", "pass")
|
||||
user_tok_1 = self.login("foo", "pass")
|
||||
self.helper.join(room_id_1, user_1, tok=user_tok_1)
|
||||
|
||||
url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
|
||||
request, channel = self.make_request(
|
||||
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(2, channel.json_body["joined_local_devices"])
|
||||
|
||||
# leave room
|
||||
self.helper.leave(room_id_1, self.admin_user, tok=self.admin_user_tok)
|
||||
self.helper.leave(room_id_1, user_1, tok=user_tok_1)
|
||||
url = "/_synapse/admin/v1/rooms/%s" % (room_id_1,)
|
||||
request, channel = self.make_request(
|
||||
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
|
||||
)
|
||||
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(0, channel.json_body["joined_local_devices"])
|
||||
|
||||
def test_room_members(self):
|
||||
"""Test that room members can be requested correctly"""
|
||||
# Create two test rooms
|
||||
|
||||
@@ -26,6 +26,7 @@ from mock import Mock
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
||||
from synapse.handlers.pagination import PurgeStatus
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client.v1 import directory, login, profile, room
|
||||
from synapse.rest.client.v2_alpha import account
|
||||
from synapse.types import JsonDict, RoomAlias, UserID
|
||||
@@ -625,6 +626,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
|
||||
user_id = "@sid1:red"
|
||||
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
profile.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
@@ -703,6 +705,20 @@ class RoomJoinRatelimitTestCase(RoomBase):
|
||||
request, channel = self.make_request("POST", path % room_id, {})
|
||||
self.assertEquals(channel.code, 200)
|
||||
|
||||
@unittest.override_config(
|
||||
{
|
||||
"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}},
|
||||
"auto_join_rooms": ["#room:red", "#room2:red", "#room3:red", "#room4:red"],
|
||||
"autocreate_auto_join_rooms": True,
|
||||
},
|
||||
)
|
||||
def test_autojoin_rooms(self):
|
||||
user_id = self.register_user("testuser", "password")
|
||||
|
||||
# Check that the new user successfully joined the four rooms
|
||||
rooms = self.get_success(self.hs.get_datastore().get_rooms_for_user(user_id))
|
||||
self.assertEqual(len(rooms), 4)
|
||||
|
||||
|
||||
class RoomMessagesTestCase(RoomBase):
|
||||
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
|
||||
|
||||
@@ -79,6 +79,32 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||
res["device2"],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_count_devices_by_users(self):
|
||||
yield defer.ensureDeferred(
|
||||
self.store.store_device("user_id", "device1", "display_name 1")
|
||||
)
|
||||
yield defer.ensureDeferred(
|
||||
self.store.store_device("user_id", "device2", "display_name 2")
|
||||
)
|
||||
yield defer.ensureDeferred(
|
||||
self.store.store_device("user_id2", "device3", "display_name 3")
|
||||
)
|
||||
|
||||
res = yield defer.ensureDeferred(self.store.count_devices_by_users())
|
||||
self.assertEqual(0, res)
|
||||
|
||||
res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
|
||||
self.assertEqual(0, res)
|
||||
|
||||
res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
|
||||
self.assertEqual(2, res)
|
||||
|
||||
res = yield defer.ensureDeferred(
|
||||
self.store.count_devices_by_users(["user_id", "user_id2"])
|
||||
)
|
||||
self.assertEqual(3, res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_device_updates_by_remote(self):
|
||||
device_ids = ["device_id1", "device_id2"]
|
||||
|
||||
Reference in New Issue
Block a user