1
0

Merge commit 'c97da1e45' into anoa/dinsic_release_1_23_1

This commit is contained in:
Andrew Morgan
2020-12-31 13:40:57 +00:00
38 changed files with 521 additions and 173 deletions

View File

@@ -46,7 +46,7 @@ locally. You'll need python 3.6 or later, and to install a number of tools:
```
# Install the dependencies
pip install -e ".[lint]"
pip install -e ".[lint,mypy]"
# Run the linter script
./scripts-dev/lint.sh
@@ -63,7 +63,7 @@ run-time:
./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder
```
You can also provided the `-d` option, which will lint the files that have been
You can also provide the `-d` option, which will lint the files that have been
changed since the last git commit. This will often be significantly faster than
linting the whole codebase.

View File

@@ -256,9 +256,9 @@ directory of your choice::
Synapse has a number of external dependencies, that are easiest
to install using pip and a virtualenv::
virtualenv -p python3 env
source env/bin/activate
python -m pip install --no-use-pep517 -e ".[all]"
python3 -m venv ./env
source ./env/bin/activate
pip install -e ".[all,test]"
This will run a process of downloading and installing all the needed
dependencies into a virtual env.
@@ -270,9 +270,9 @@ check that everything is installed as it should be::
This should end with a 'PASSED' result::
Ran 143 tests in 0.601s
Ran 1266 tests in 643.930s
PASSED (successes=143)
PASSED (skips=15, successes=1251)
Running the Integration Tests
=============================

1
changelog.d/8610.feature Normal file
View File

@@ -0,0 +1 @@
Add an admin APIs to allow server admins to list users' pushers. Contributed by @dklimpel.

1
changelog.d/8633.misc Normal file
View File

@@ -0,0 +1 @@
Run `mypy` as part of the lint.sh script.

1
changelog.d/8655.misc Normal file
View File

@@ -0,0 +1 @@
Add more type hints to the application services code.

1
changelog.d/8664.misc Normal file
View File

@@ -0,0 +1 @@
Tell Black to format code for Python 3.5.

1
changelog.d/8666.doc Normal file
View File

@@ -0,0 +1 @@
Minor updates to docs on running tests.

1
changelog.d/8669.misc Normal file
View File

@@ -0,0 +1 @@
Don't pull event from DB when handling replication traffic.

1
changelog.d/8678.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix `Object of type frozendict is not JSON serializable` exceptions when using third-party event rules.

View File

@@ -611,3 +611,82 @@ The following parameters should be set in the URL:
- ``user_id`` - fully qualified: for example, ``@user:server.com``.
- ``device_id`` - The device to delete.
List all pushers
================
Gets information about all pushers for a specific ``user_id``.
The API is::
GET /_synapse/admin/v1/users/<user_id>/pushers
To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see `README.rst <README.rst>`_.
A response body like the following is returned:
.. code:: json
{
"pushers": [
{
"app_display_name":"HTTP Push Notifications",
"app_id":"m.http",
"data": {
"url":"example.com"
},
"device_display_name":"pushy push",
"kind":"http",
"lang":"None",
"profile_tag":"",
"pushkey":"a@example.com"
}
],
"total": 1
}
**Parameters**
The following parameters should be set in the URL:
- ``user_id`` - fully qualified: for example, ``@user:server.com``.
**Response**
The following fields are returned in the JSON response body:
- ``pushers`` - An array containing the current pushers for the user
- ``app_display_name`` - string - A string that will allow the user to identify
what application owns this pusher.
- ``app_id`` - string - This is a reverse-DNS style identifier for the application.
Max length, 64 chars.
- ``data`` - A dictionary of information for the pusher implementation itself.
- ``url`` - string - Required if ``kind`` is ``http``. The URL to use to send
notifications to.
- ``format`` - string - The format to use when sending notifications to the
Push Gateway.
- ``device_display_name`` - string - A string that will allow the user to identify
what device owns this pusher.
- ``profile_tag`` - string - This string determines which set of device specific rules
this pusher executes.
- ``kind`` - string - The kind of pusher. "http" is a pusher that sends HTTP pokes.
- ``lang`` - string - The preferred language for receiving notifications
(e.g. 'en' or 'en-US')
- ``profile_tag`` - string - This string determines which set of device specific rules
this pusher executes.
- ``pushkey`` - string - This is a unique identifier for this pusher.
Max length, 512 bytes.
- ``total`` - integer - Number of pushers.
See also `Client-Server API Spec <https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers>`_

View File

@@ -57,6 +57,7 @@ files =
synapse/server_notices,
synapse/spam_checker_api,
synapse/state,
synapse/storage/databases/main/appservice.py,
synapse/storage/databases/main/events.py,
synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/stream.py,
@@ -82,6 +83,9 @@ ignore_missing_imports = True
[mypy-zope]
ignore_missing_imports = True
[mypy-bcrypt]
ignore_missing_imports = True
[mypy-constantly]
ignore_missing_imports = True

View File

@@ -35,7 +35,7 @@
showcontent = true
[tool.black]
target-version = ['py34']
target-version = ['py35']
exclude = '''
(

View File

@@ -94,3 +94,4 @@ isort "${files[@]}"
python3 -m black "${files[@]}"
./scripts-dev/config-lint.sh
flake8 "${files[@]}"
mypy

View File

@@ -12,9 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from prometheus_client import Counter
@@ -34,16 +33,20 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.types import Collection, JsonDict, RoomStreamToken, UserID
from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
class ApplicationServicesHandler:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
self.appservice_api = hs.get_application_service_api()
@@ -247,7 +250,9 @@ class ApplicationServicesHandler:
service, "presence", new_token
)
async def _handle_typing(self, service: ApplicationService, new_token: int):
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
typing_source = self.event_sources.sources["typing"]
# Get the typing events from just before current
typing, _ = await typing_source.get_new_events_as(
@@ -259,7 +264,7 @@ class ApplicationServicesHandler:
)
return typing
async def _handle_receipts(self, service: ApplicationService):
async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
@@ -301,11 +306,11 @@ class ApplicationServicesHandler:
return events
async def query_user_exists(self, user_id):
async def query_user_exists(self, user_id: str) -> bool:
"""Check if any application service knows this user_id exists.
Args:
user_id(str): The user to query if they exist on any AS.
user_id: The user to query if they exist on any AS.
Returns:
True if this user exists on at least one application service.
"""
@@ -316,11 +321,13 @@ class ApplicationServicesHandler:
return True
return False
async def query_room_alias_exists(self, room_alias):
async def query_room_alias_exists(
self, room_alias: RoomAlias
) -> Optional[RoomAliasMapping]:
"""Check if an application service knows this room alias exists.
Args:
room_alias(RoomAlias): The room alias to query.
room_alias: The room alias to query.
Returns:
namedtuple: with keys "room_id" and "servers" or None if no
association can be found.
@@ -336,10 +343,13 @@ class ApplicationServicesHandler:
)
if is_known_alias:
# the alias exists now so don't query more ASes.
result = await self.store.get_association_from_room_alias(room_alias)
return result
return await self.store.get_association_from_room_alias(room_alias)
async def query_3pe(self, kind, protocol, fields):
return None
async def query_3pe(
self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]]
) -> List[JsonDict]:
services = self._get_services_for_3pn(protocol)
results = await make_deferred_yieldable(
@@ -361,7 +371,9 @@ class ApplicationServicesHandler:
return ret
async def get_3pe_protocols(self, only_protocol=None):
async def get_3pe_protocols(
self, only_protocol: Optional[str] = None
) -> Dict[str, JsonDict]:
services = self.store.get_app_services()
protocols = {} # type: Dict[str, List[JsonDict]]
@@ -379,7 +391,7 @@ class ApplicationServicesHandler:
if info is not None:
protocols[p].append(info)
def _merge_instances(infos):
def _merge_instances(infos: List[JsonDict]) -> JsonDict:
if not infos:
return {}
@@ -394,19 +406,17 @@ class ApplicationServicesHandler:
return combined
for p in protocols.keys():
protocols[p] = _merge_instances(protocols[p])
return {p: _merge_instances(protocols[p]) for p in protocols.keys()}
return protocols
async def _get_services_for_event(self, event):
async def _get_services_for_event(
self, event: EventBase
) -> List[ApplicationService]:
"""Retrieve a list of application services interested in this event.
Args:
event(Event): The event to check. Can be None if alias_list is not.
event: The event to check. Can be None if alias_list is not.
Returns:
list<ApplicationService>: A list of services interested in this
event based on the service regex.
A list of services interested in this event based on the service regex.
"""
services = self.store.get_app_services()
@@ -420,17 +430,15 @@ class ApplicationServicesHandler:
return interested_list
def _get_services_for_user(self, user_id):
def _get_services_for_user(self, user_id: str) -> List[ApplicationService]:
services = self.store.get_app_services()
interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
return interested_list
return [s for s in services if (s.is_interested_in_user(user_id))]
def _get_services_for_3pn(self, protocol):
def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]:
services = self.store.get_app_services()
interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
return interested_list
return [s for s in services if s.is_interested_in_protocol(protocol)]
async def _is_unknown_user(self, user_id):
async def _is_unknown_user(self, user_id: str) -> bool:
if not self.is_mine_id(user_id):
# we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes.
@@ -445,9 +453,8 @@ class ApplicationServicesHandler:
service_list = [s for s in services if s.sender == user_id]
return len(service_list) == 0
async def _check_user_exists(self, user_id):
async def _check_user_exists(self, user_id: str) -> bool:
unknown_user = await self._is_unknown_user(user_id)
if unknown_user:
exists = await self.query_user_exists(user_id)
return exists
return await self.query_user_exists(user_id)
return True

View File

@@ -18,10 +18,20 @@ import logging
import time
import unicodedata
import urllib.parse
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)
import attr
import bcrypt # type: ignore[import]
import bcrypt
import pymacaroons
from synapse.api.constants import LoginType
@@ -49,6 +59,9 @@ from synapse.util.threepids import canonicalise_email
from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -149,11 +162,7 @@ class SsoLoginExtraAttributes:
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]

View File

@@ -50,9 +50,8 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
from synapse.util import json_decoder
from synapse.util import json_decoder, json_encoder
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
@@ -928,7 +927,7 @@ class EventCreationHandler:
# Ensure that we can round trip before trying to persist in db
try:
dump = frozendict_json_encoder.encode(event.content)
dump = json_encoder.encode(event.content)
json_decoder.decode(dump)
except Exception:
logger.exception("Failed to encode content: %r", event.content)

View File

@@ -359,7 +359,7 @@ class SimpleHttpClient:
agent=self.agent,
data=body_producer,
headers=headers,
**self._extra_treq_args
**self._extra_treq_args,
) # type: defer.Deferred
# we use our own timeout mechanism rather than treq's as a workaround

View File

@@ -35,8 +35,6 @@ from twisted.web.server import NOT_DONE_YET, Request
from twisted.web.static import File, NoRangeStaticProducer
from twisted.web.util import redirectTo
import synapse.events
import synapse.metrics
from synapse.api.errors import (
CodeMessageException,
Codes,
@@ -620,7 +618,7 @@ def respond_with_json(
if pretty_print:
encoder = iterencode_pretty_printed_json
else:
if canonical_json or synapse.events.USE_FROZEN_DICTS:
if canonical_json:
encoder = iterencode_canonical_json
else:
encoder = _encode_json_bytes

View File

@@ -28,6 +28,7 @@ from typing import (
Union,
)
import attr
from prometheus_client import Counter
from twisted.internet import defer
@@ -173,6 +174,17 @@ class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
return bool(self.events)
@attr.s(slots=True, frozen=True)
class _PendingRoomEventEntry:
event_pos = attr.ib(type=PersistedEventPosition)
extra_users = attr.ib(type=Collection[UserID])
room_id = attr.ib(type=str)
type = attr.ib(type=str)
state_key = attr.ib(type=Optional[str])
membership = attr.ib(type=Optional[str])
class Notifier:
""" This class is responsible for notifying any listeners when there are
new events available for it.
@@ -190,9 +202,7 @@ class Notifier:
self.storage = hs.get_storage()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
self.pending_new_room_events = (
[]
) # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]]
self.pending_new_room_events = [] # type: List[_PendingRoomEventEntry]
# Called when there are new things to stream over replication
self.replication_callbacks = [] # type: List[Callable[[], None]]
@@ -255,7 +265,29 @@ class Notifier:
max_room_stream_token: RoomStreamToken,
extra_users: Collection[UserID] = [],
):
""" Used by handlers to inform the notifier something has happened
"""Unwraps event and calls `on_new_room_event_args`.
"""
self.on_new_room_event_args(
event_pos=event_pos,
room_id=event.room_id,
event_type=event.type,
state_key=event.get("state_key"),
membership=event.content.get("membership"),
max_room_stream_token=max_room_stream_token,
extra_users=extra_users,
)
def on_new_room_event_args(
self,
room_id: str,
event_type: str,
state_key: Optional[str],
membership: Optional[str],
event_pos: PersistedEventPosition,
max_room_stream_token: RoomStreamToken,
extra_users: Collection[UserID] = [],
):
"""Used by handlers to inform the notifier something has happened
in the room, room event wise.
This triggers the notifier to wake up any listeners that are
@@ -266,7 +298,16 @@ class Notifier:
until all previous events have been persisted before notifying
the client streams.
"""
self.pending_new_room_events.append((event_pos, event, extra_users))
self.pending_new_room_events.append(
_PendingRoomEventEntry(
event_pos=event_pos,
extra_users=extra_users,
room_id=room_id,
type=event_type,
state_key=state_key,
membership=membership,
)
)
self._notify_pending_new_room_events(max_room_stream_token)
self.notify_replication()
@@ -284,18 +325,19 @@ class Notifier:
users = set() # type: Set[UserID]
rooms = set() # type: Set[str]
for event_pos, event, extra_users in pending:
if event_pos.persisted_after(max_room_stream_token):
self.pending_new_room_events.append((event_pos, event, extra_users))
for entry in pending:
if entry.event_pos.persisted_after(max_room_stream_token):
self.pending_new_room_events.append(entry)
else:
if (
event.type == EventTypes.Member
and event.membership == Membership.JOIN
entry.type == EventTypes.Member
and entry.membership == Membership.JOIN
and entry.state_key
):
self._user_joined_room(event.state_key, event.room_id)
self._user_joined_room(entry.state_key, entry.room_id)
users.update(extra_users)
rooms.add(event.room_id)
users.update(entry.extra_users)
rooms.add(entry.room_id)
if users or rooms:
self.on_new_event(

View File

@@ -141,21 +141,25 @@ class ReplicationDataHandler:
if row.type != EventsStreamEventRow.TypeId:
continue
assert isinstance(row, EventsStreamRow)
assert isinstance(row.data, EventsStreamEventRow)
event = await self.store.get_event(
row.data.event_id, allow_rejected=True
)
if event.rejected_reason:
if row.data.rejected:
continue
extra_users = () # type: Tuple[UserID, ...]
if event.type == EventTypes.Member:
extra_users = (UserID.from_string(event.state_key),)
if row.data.type == EventTypes.Member and row.data.state_key:
extra_users = (UserID.from_string(row.data.state_key),)
max_token = self.store.get_room_max_token()
event_pos = PersistedEventPosition(instance_name, token)
self.notifier.on_new_room_event(
event, event_pos, max_token, extra_users
self.notifier.on_new_room_event_args(
event_pos=event_pos,
max_room_stream_token=max_token,
extra_users=extra_users,
room_id=row.data.room_id,
event_type=row.data.type,
state_key=row.data.state_key,
membership=row.data.membership,
)
# Notify any waiting deferreds. The list is ordered by position so we

View File

@@ -15,12 +15,15 @@
# limitations under the License.
import heapq
from collections.abc import Iterable
from typing import List, Tuple, Type
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
import attr
from ._base import Stream, StreamUpdateResult, Token
if TYPE_CHECKING:
from synapse.server import HomeServer
"""Handling of the 'events' replication stream
This stream contains rows of various types. Each row therefore contains a 'type'
@@ -81,12 +84,14 @@ class BaseEventsStreamRow:
class EventsStreamEventRow(BaseEventsStreamRow):
TypeId = "ev"
event_id = attr.ib() # str
room_id = attr.ib() # str
type = attr.ib() # str
state_key = attr.ib() # str, optional
redacts = attr.ib() # str, optional
relates_to = attr.ib() # str, optional
event_id = attr.ib(type=str)
room_id = attr.ib(type=str)
type = attr.ib(type=str)
state_key = attr.ib(type=Optional[str])
redacts = attr.ib(type=Optional[str])
relates_to = attr.ib(type=Optional[str])
membership = attr.ib(type=Optional[str])
rejected = attr.ib(type=bool)
@attr.s(slots=True, frozen=True)
@@ -113,7 +118,7 @@ class EventsStream(Stream):
NAME = "events"
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),

View File

@@ -50,6 +50,7 @@ from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.users import (
AccountValidityRenewServlet,
DeactivateAccountRestServlet,
PushersRestServlet,
ResetPasswordRestServlet,
SearchUsersRestServlet,
UserAdminServlet,
@@ -226,8 +227,9 @@ def register_servlets(hs, http_server):
DeviceRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeleteDevicesRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):

View File

@@ -39,6 +39,17 @@ from synapse.types import JsonDict, UserID
logger = logging.getLogger(__name__)
_GET_PUSHERS_ALLOWED_KEYS = {
"app_display_name",
"app_id",
"data",
"device_display_name",
"kind",
"lang",
"profile_tag",
"pushkey",
}
class UsersRestServlet(RestServlet):
PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
@@ -713,6 +724,47 @@ class UserMembershipRestServlet(RestServlet):
return 200, ret
class PushersRestServlet(RestServlet):
"""
Gets information about all pushers for a specific `user_id`.
Example:
http://localhost:8008/_synapse/admin/v1/users/
@user:server/pushers
Returns:
pushers: Dictionary containing pushers information.
total: Number of pushers in dictonary `pushers`.
"""
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
def __init__(self, hs):
self.is_mine = hs.is_mine
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if not self.is_mine(UserID.from_string(user_id)):
raise SynapseError(400, "Can only lookup local users")
if not await self.store.get_user_by_id(user_id):
raise NotFoundError("User not found")
pushers = await self.store.get_pushers_by_user_id(user_id)
filtered_pushers = [
{k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
for p in pushers
]
return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
class UserMediaRestServlet(RestServlet):
"""
Gets information about all uploaded local media for a specific `user_id`.

View File

@@ -94,7 +94,7 @@ def make_pool(
cp_openfun=lambda conn: engine.on_new_connection(
LoggingDatabaseConnection(conn, engine, "on_new_connection")
),
**db_config.config.get("args", {})
**db_config.config.get("args", {}),
)
@@ -632,7 +632,7 @@ class DatabasePool:
func,
*args,
db_autocommit=db_autocommit,
**kwargs
**kwargs,
)
for after_callback, after_args, after_kwargs in after_callbacks:

View File

@@ -15,21 +15,31 @@
# limitations under the License.
import logging
import re
from typing import List
from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
from synapse.appservice import ApplicationService, AppServiceTransaction
from synapse.appservice import (
ApplicationService,
ApplicationServiceState,
AppServiceTransaction,
)
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.types import Connection
from synapse.types import JsonDict
from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
def _make_exclusive_regex(services_cache):
def _make_exclusive_regex(
services_cache: List[ApplicationService],
) -> Optional[Pattern]:
# We precompile a regex constructed from all the regexes that the AS's
# have registered for exclusive users.
exclusive_user_regexes = [
@@ -39,17 +49,19 @@ def _make_exclusive_regex(services_cache):
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
exclusive_user_regex = re.compile(exclusive_user_regex)
exclusive_user_pattern = re.compile(
exclusive_user_regex
) # type: Optional[Pattern]
else:
# We handle this case specially otherwise the constructed regex
# will always match
exclusive_user_regex = None
exclusive_user_pattern = None
return exclusive_user_regex
return exclusive_user_pattern
class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
self.services_cache = load_appservices(
hs.hostname, hs.config.app_service_config_files
)
@@ -60,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
def get_app_services(self):
return self.services_cache
def get_if_app_services_interested_in_user(self, user_id):
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
"""Check if the user is one associated with an app service (exclusively)
"""
if self.exclusive_user_regex:
@@ -68,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
else:
return False
def get_app_service_by_user_id(self, user_id):
def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]:
"""Retrieve an application service from their user ID.
All application services have associated with them a particular user ID.
@@ -77,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
a user ID to an application service.
Args:
user_id(str): The user ID to see if it is an application service.
user_id: The user ID to see if it is an application service.
Returns:
synapse.appservice.ApplicationService or None.
The application service or None.
"""
for service in self.services_cache:
if service.sender == user_id:
return service
return None
def get_app_service_by_token(self, token):
def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]:
"""Get the application service with the given appservice token.
Args:
token (str): The application service token.
token: The application service token.
Returns:
synapse.appservice.ApplicationService or None.
The application service or None.
"""
for service in self.services_cache:
if service.token == token:
return service
return None
def get_app_service_by_id(self, as_id):
def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]:
"""Get the application service with the given appservice ID.
Args:
as_id (str): The application service ID.
as_id: The application service ID.
Returns:
synapse.appservice.ApplicationService or None.
The application service or None.
"""
for service in self.services_cache:
if service.id == as_id:
@@ -124,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
class ApplicationServiceTransactionWorkerStore(
ApplicationServiceWorkerStore, EventsWorkerStore
):
async def get_appservices_by_state(self, state):
async def get_appservices_by_state(
self, state: ApplicationServiceState
) -> List[ApplicationService]:
"""Get a list of application services based on their state.
Args:
state(ApplicationServiceState): The state to filter on.
state: The state to filter on.
Returns:
A list of ApplicationServices, which may be empty.
"""
@@ -145,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore(
services.append(service)
return services
async def get_appservice_state(self, service):
async def get_appservice_state(
self, service: ApplicationService
) -> Optional[ApplicationServiceState]:
"""Get the application service state.
Args:
service(ApplicationService): The service whose state to set.
service: The service whose state to set.
Returns:
An ApplicationServiceState.
An ApplicationServiceState or none.
"""
result = await self.db_pool.simple_select_one(
"application_services_state",
@@ -164,12 +180,14 @@ class ApplicationServiceTransactionWorkerStore(
return result.get("state")
return None
async def set_appservice_state(self, service, state) -> None:
async def set_appservice_state(
self, service: ApplicationService, state: ApplicationServiceState
) -> None:
"""Set the application service state.
Args:
service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply.
service: The service whose state to set.
state: The connectivity state to apply.
"""
await self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
@@ -226,13 +244,14 @@ class ApplicationServiceTransactionWorkerStore(
"create_appservice_txn", _create_appservice_txn
)
async def complete_appservice_txn(self, txn_id, service) -> None:
async def complete_appservice_txn(
self, txn_id: int, service: ApplicationService
) -> None:
"""Completes an application service transaction.
Args:
txn_id(str): The transaction ID being completed.
service(ApplicationService): The application service which was sent
this transaction.
txn_id: The transaction ID being completed.
service: The application service which was sent this transaction.
"""
txn_id = int(txn_id)
@@ -242,7 +261,7 @@ class ApplicationServiceTransactionWorkerStore(
# has probably missed some events), so whine loudly but still continue,
# since it shouldn't fail completion of the transaction.
last_txn_id = self._get_last_txn(txn, service.id)
if (last_txn_id + 1) != txn_id:
if (txn_id + 1) != txn_id:
logger.error(
"appservice: Completing a transaction which has an ID > 1 from "
"the last ID sent to this AS. We've either dropped events or "
@@ -272,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore(
"complete_appservice_txn", _complete_appservice_txn
)
async def get_oldest_unsent_txn(self, service):
"""Get the oldest transaction which has not been sent for this
service.
async def get_oldest_unsent_txn(
self, service: ApplicationService
) -> Optional[AppServiceTransaction]:
"""Get the oldest transaction which has not been sent for this service.
Args:
service(ApplicationService): The app service to get the oldest txn.
service: The app service to get the oldest txn.
Returns:
An AppServiceTransaction or None.
"""
@@ -313,7 +333,7 @@ class ApplicationServiceTransactionWorkerStore(
service=service, id=entry["txn_id"], events=events, ephemeral=[]
)
def _get_last_txn(self, txn, service_id):
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?",
(service_id,),
@@ -324,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore(
else:
return int(last_txn_id[0]) # select 'last_txn' col
async def set_appservice_last_pos(self, pos) -> None:
async def set_appservice_last_pos(self, pos: int) -> None:
def set_appservice_last_pos_txn(txn):
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
@@ -334,7 +354,9 @@ class ApplicationServiceTransactionWorkerStore(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
async def get_new_events_for_appservice(self, current_id, limit):
async def get_new_events_for_appservice(
self, current_id: int, limit: int
) -> Tuple[int, List[EventBase]]:
"""Get all new events for an appservice"""
def get_new_events_for_appservice_txn(txn):
@@ -394,7 +416,7 @@ class ApplicationServiceTransactionWorkerStore(
)
async def set_type_stream_id_for_appservice(
self, service: ApplicationService, type: str, pos: int
self, service: ApplicationService, type: str, pos: Optional[int]
) -> None:
if type not in ("read_receipt", "presence"):
raise ValueError(

View File

@@ -22,7 +22,7 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -104,7 +104,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
and original_event.internal_metadata.is_redacted()
):
# Redaction was allowed
pruned_json = frozendict_json_encoder.encode(
pruned_json = json_encoder.encode(
prune_event_dict(
original_event.room_version, original_event.get_dict()
)
@@ -170,7 +170,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
return
# Prune the event's dict then convert it to JSON.
pruned_json = frozendict_json_encoder.encode(
pruned_json = json_encoder.encode(
prune_event_dict(event.room_version, event.get_dict())
)

View File

@@ -34,7 +34,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -769,9 +769,7 @@ class PersistEventsStore:
logger.exception("")
raise
metadata_json = frozendict_json_encoder.encode(
event.internal_metadata.get_dict()
)
metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
txn.execute(sql, (metadata_json, event.event_id))
@@ -826,10 +824,10 @@ class PersistEventsStore:
{
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": frozendict_json_encoder.encode(
"internal_metadata": json_encoder.encode(
event.internal_metadata.get_dict()
),
"json": frozendict_json_encoder.encode(event_dict(event)),
"json": json_encoder.encode(event_dict(event)),
"format_version": event.format_version,
}
for event, _ in events_and_contexts

View File

@@ -1117,11 +1117,13 @@ class EventsWorkerStore(SQLBaseStore):
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" AND instance_name = ?"
" ORDER BY stream_ordering ASC"
@@ -1152,12 +1154,14 @@ class EventsWorkerStore(SQLBaseStore):
def get_ex_outlier_stream_rows_txn(txn):
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" AND out.instance_name = ?"

View File

@@ -18,6 +18,7 @@ import logging
import re
import attr
from frozendict import frozendict
from twisted.internet import defer, task
@@ -31,9 +32,26 @@ def _reject_invalid_json(val):
raise ValueError("Invalid JSON value: '%s'" % val)
# Create a custom encoder to reduce the whitespace produced by JSON encoding and
# ensure that valid JSON is produced.
json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))
def _handle_frozendict(obj):
"""Helper for json_encoder. Makes frozendicts serializable by returning
the underlying dict
"""
if type(obj) is frozendict:
# fishing the protected dict out of the object is a bit nasty,
# but we don't really want the overhead of copying the dict.
return obj._dict
raise TypeError(
"Object of type %s is not JSON serializable" % obj.__class__.__name__
)
# A custom JSON encoder which:
# * handles frozendicts
# * produces valid JSON (no NaNs etc)
# * reduces redundant whitespace
json_encoder = json.JSONEncoder(
allow_nan=False, separators=(",", ":"), default=_handle_frozendict
)
# Create a custom decoder to reject Python extensions to JSON.
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)

View File

@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from frozendict import frozendict
@@ -49,23 +47,3 @@ def unfreeze(o):
pass
return o
def _handle_frozendict(obj):
"""Helper for EventEncoder. Makes frozendicts serializable by returning
the underlying dict
"""
if type(obj) is frozendict:
# fishing the protected dict out of the object is a bit nasty,
# but we don't really want the overhead of copying the dict.
return obj._dict
raise TypeError(
"Object of type %s is not JSON serializable" % obj.__class__.__name__
)
# A JSONEncoder which is capable of encoding frozendicts without barfing.
# Additionally reduce the whitespace produced by JSON encoding.
frozendict_json_encoder = json.JSONEncoder(
allow_nan=False, separators=(",", ":"), default=_handle_frozendict,
)

View File

@@ -110,7 +110,7 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
failure_ts,
retry_interval,
backoff_on_failure=backoff_on_failure,
**kwargs
**kwargs,
)

View File

@@ -269,7 +269,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
homeserver_to_use=GenericWorkerServer,
config=config,
reactor=self.reactor,
**kwargs
**kwargs,
)
# If the instance is in the `instance_map` config then workers may try

View File

@@ -449,7 +449,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
sender=sender,
type="test_event",
content={"body": body},
**kwargs
**kwargs,
)
)

View File

@@ -1118,6 +1118,130 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
class PushersRestTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
self.url = "/_synapse/admin/v1/users/%s/pushers" % urllib.parse.quote(
self.other_user
)
def test_no_auth(self):
"""
Try to list pushers of an user without authentication.
"""
request, channel = self.make_request("GET", self.url, b"{}")
self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
"""
If the user is not a server admin, an error is returned.
"""
other_user_token = self.login("user", "pass")
request, channel = self.make_request(
"GET", self.url, access_token=other_user_token,
)
self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self):
"""
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/pushers"
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self):
"""
Tests that a lookup for a user that is not a local returns a 400
"""
url = "/_synapse/admin/v1/users/@unknown_person:unknown_domain/pushers"
request, channel = self.make_request(
"GET", url, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_get_pushers(self):
"""
Tests that a normal lookup for pushers is successfully
"""
# Get pushers
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["total"])
# Register the pusher
other_user_token = self.login("user", "pass")
user_tuple = self.get_success(
self.store.get_user_by_access_token(other_user_token)
)
token_id = user_tuple["token_id"]
self.get_success(
self.hs.get_pusherpool().add_pusher(
user_id=self.other_user,
access_token=token_id,
kind="http",
app_id="m.http",
app_display_name="HTTP Push Notifications",
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
data={"url": "example.com"},
)
)
# Get pushers
request, channel = self.make_request(
"GET", self.url, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(1, channel.json_body["total"])
for p in channel.json_body["pushers"]:
self.assertIn("pushkey", p)
self.assertIn("kind", p)
self.assertIn("app_id", p)
self.assertIn("app_display_name", p)
self.assertIn("device_display_name", p)
self.assertIn("profile_tag", p)
self.assertIn("lang", p)
self.assertIn("url", p["data"])
class UserMediaRestTestCase(unittest.HomeserverTestCase):
servlets = [

View File

@@ -380,7 +380,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
pool._runWithConnection,
func,
*args,
**kwargs
**kwargs,
)
def runInteraction(interaction, *args, **kwargs):
@@ -390,7 +390,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
pool._runInteraction,
interaction,
*args,
**kwargs
**kwargs,
)
pool.runWithConnection = runWithConnection

View File

@@ -412,7 +412,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
"GET",
"/_matrix/client/r0/admin/users/" + self.user_id,
access_token=access_token,
**make_request_args
**make_request_args,
)
request.requestHeaders.addRawHeader(b"User-Agent", b"Mozzila pizza")

View File

@@ -50,7 +50,7 @@ async def inject_member_event(
sender=sender,
state_key=target,
content=content,
**kwargs
**kwargs,
)

View File

@@ -24,11 +24,6 @@ deps =
pip>=10
setenv =
# we have a pyproject.toml, but don't want pip to use it for building.
# (otherwise we get an error about 'editable mode is not supported for
# pyproject.toml-style projects').
PIP_USE_PEP517 = false
PYTHONDONTWRITEBYTECODE = no_byte_code
COVERAGE_PROCESS_START = {toxinidir}/.coveragerc