1
0

Merge commit 'a7a379006' into anoa/dinsic_release_1_31_0

This commit is contained in:
Andrew Morgan
2021-04-23 17:28:34 +01:00
15 changed files with 307 additions and 64 deletions

1
changelog.d/9573.feature Normal file
View File

@@ -0,0 +1 @@
Add prometheus metrics for number of users successfully registering and logging in.

1
changelog.d/9576.misc Normal file
View File

@@ -0,0 +1 @@
Improve efficiency of calculating the auth chain in large rooms.

1
changelog.d/9580.doc Normal file
View File

@@ -0,0 +1 @@
Clarify the spam checker modules documentation example to mention that `parse_config` is a required method.

1
changelog.d/9586.misc Normal file
View File

@@ -0,0 +1 @@
Convert `synapse.types.Requester` to an `attrs` class.

View File

@@ -14,6 +14,7 @@ The Python class is instantiated with two objects:
* An instance of `synapse.module_api.ModuleApi`.
It then implements methods which return a boolean to alter behavior in Synapse.
All the methods must be defined.
There's a generic method for checking every event (`check_event_for_spam`), as
well as some specific methods:
@@ -24,6 +25,7 @@ well as some specific methods:
* `user_may_publish_room`
* `check_username_for_spam`
* `check_registration_for_spam`
* `check_media_file_for_spam`
The details of each of these methods (as well as their inputs and outputs)
are documented in the `synapse.events.spamcheck.SpamChecker` class.
@@ -31,6 +33,10 @@ are documented in the `synapse.events.spamcheck.SpamChecker` class.
The `ModuleApi` class provides a way for the custom spam checker class to
call back into the homeserver internals.
Additionally, a `parse_config` method is mandatory and receives the plugin config
dictionary. After parsing, it must return an object which will be
passed to `__init__` later.
### Example
```python
@@ -41,6 +47,10 @@ class ExampleSpamChecker:
self.config = config
self.api = api
@staticmethod
def parse_config(config):
return config
async def check_event_for_spam(self, foo):
return False # allow all events

View File

@@ -448,7 +448,7 @@ class FederationServer(FederationBase):
async def _on_state_ids_request_compute(self, room_id, event_id):
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(
@@ -461,7 +461,9 @@ class FederationServer(FederationBase):
else:
pdus = (await self.state.get_current_state(room_id)).values()
auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
auth_chain = await self.store.get_auth_chain(
room_id, [pdu.event_id for pdu in pdus]
)
return {
"pdus": [pdu.get_pdu_json() for pdu in pdus],

View File

@@ -337,7 +337,8 @@ class AuthHandler(BaseHandler):
user is too high to proceed
"""
if not requester.access_token_id:
raise ValueError("Cannot validate a user without an access token")
if self._ui_auth_session_timeout:
last_validated = await self.store.get_access_token_last_validated(
requester.access_token_id
@@ -1213,7 +1214,7 @@ class AuthHandler(BaseHandler):
async def delete_access_tokens_for_user(
self,
user_id: str,
except_token_id: Optional[str] = None,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
):
"""Invalidate access tokens belonging to a user

View File

@@ -1319,7 +1319,7 @@ class FederationHandler(BaseHandler):
async def on_event_auth(self, event_id: str) -> List[EventBase]:
event = await self.store.get_event(event_id)
auth = await self.store.get_auth_chain(
list(event.auth_event_ids()), include_given=True
event.room_id, list(event.auth_event_ids()), include_given=True
)
return list(auth)
@@ -1653,7 +1653,7 @@ class FederationHandler(BaseHandler):
prev_state_ids = await context.get_prev_state_ids()
state_ids = list(prev_state_ids.values())
auth_chain = await self.store.get_auth_chain(state_ids)
auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)
state = await self.store.get_events(list(prev_state_ids.values()))
@@ -2413,7 +2413,7 @@ class FederationHandler(BaseHandler):
# Now get the current auth_chain for the event.
local_auth_chain = await self.store.get_auth_chain(
list(event.auth_event_ids()), include_given=True
room_id, list(event.auth_event_ids()), include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell

View File

@@ -16,7 +16,7 @@
"""Contains functions for registering clients."""
import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from prometheus_client import Counter
@@ -85,6 +85,7 @@ class RegistrationHandler(BaseHandler):
)
else:
self.device_handler = hs.get_device_handler()
self._register_device_client = self.register_device_inner
self.pusher_pool = hs.get_pusherpool()
self.session_lifetime = hs.config.session_lifetime
@@ -758,17 +759,35 @@ class RegistrationHandler(BaseHandler):
Returns:
Tuple of device ID and access token
"""
res = await self._register_device_client(
user_id=user_id,
device_id=device_id,
initial_display_name=initial_display_name,
is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
)
if self.hs.config.worker_app:
r = await self._register_device_client(
user_id=user_id,
device_id=device_id,
initial_display_name=initial_display_name,
is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
)
return r["device_id"], r["access_token"]
login_counter.labels(
guest=is_guest,
auth_provider=(auth_provider_id or ""),
).inc()
return res["device_id"], res["access_token"]
async def register_device_inner(
self,
user_id: str,
device_id: Optional[str],
initial_display_name: Optional[str],
is_guest: bool = False,
is_appservice_ghost: bool = False,
) -> Dict[str, str]:
"""Helper for register_device
Does the bits that need doing on the main process. Not for use outside this
class and RegisterDeviceReplicationServlet.
"""
assert not self.hs.config.worker_app
valid_until_ms = None
if self.session_lifetime is not None:
if is_guest:
@@ -793,12 +812,7 @@ class RegistrationHandler(BaseHandler):
is_appservice_ghost=is_appservice_ghost,
)
login_counter.labels(
guest=is_guest,
auth_provider=(auth_provider_id or ""),
).inc()
return (registered_device_id, access_token)
return {"device_id": registered_device_id, "access_token": access_token}
async def post_registration_actions(
self, user_id: str, auth_result: dict, access_token: Optional[str]

View File

@@ -61,7 +61,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest = content["is_guest"]
is_appservice_ghost = content["is_appservice_ghost"]
device_id, access_token = await self.registration_handler.register_device(
res = await self.registration_handler.register_device_inner(
user_id,
device_id,
initial_display_name,
@@ -69,7 +69,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_appservice_ghost=is_appservice_ghost,
)
return 200, {"device_id": device_id, "access_token": access_token}
return 200, res
def register_servlets(hs, http_server):

View File

@@ -35,6 +35,7 @@ from synapse.api.errors import (
from synapse.config._base import ConfigError
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import random_string
@@ -145,7 +146,7 @@ class MediaRepository:
upload_name: Optional[str],
content: IO,
content_length: int,
auth_user: str,
auth_user: UserID,
) -> str:
"""Store uploaded content for a local user and return the mxc URL

View File

@@ -54,11 +54,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
) # type: LruCache[str, List[Tuple[str, int]]]
async def get_auth_chain(
self, event_ids: Collection[str], include_given: bool = False
self, room_id: str, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result
@@ -66,24 +67,44 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
list of events
"""
event_ids = await self.get_auth_chain_ids(
event_ids, include_given=include_given
room_id, event_ids, include_given=include_given
)
return await self.get_events_as_list(event_ids)
async def get_auth_chain_ids(
self,
room_id: str,
event_ids: Collection[str],
include_given: bool = False,
) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.
Args:
room_id: The room the event is in.
event_ids: state events
include_given: include the given events in result
Returns:
An awaitable which resolve to a list of event_ids
list of event_ids
"""
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id)
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_ids_chains",
self._get_auth_chain_ids_using_cover_index_txn,
room_id,
event_ids,
include_given,
)
except _NoChainCoverIndex:
# For whatever reason we don't actually have a chain cover index
# for the events in question, so we fall back to the old method.
pass
return await self.db_pool.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
@@ -91,9 +112,130 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
include_given,
)
def _get_auth_chain_ids_using_cover_index_txn(
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
) -> List[str]:
"""Calculates the auth chain IDs using the chain index.
Args:
txn: Database cursor used to run the queries.
room_id: The room the events are in. Only used here for logging and
when raising `_NoChainCoverIndex`.
event_ids: The events to calculate the auth chain of.
include_given: If true, the given events are always included in the
result.
Returns:
The event IDs in the auth chain, in no particular order.
Raises:
_NoChainCoverIndex: if any of the given events has no row in
`event_auth_chains`.
"""
# First we look up the chain ID/sequence numbers for the given events.
initial_events = set(event_ids)
# All the events that we've found that are reachable from the events.
seen_events = set() # type: Set[str]
# A map from chain ID to max sequence number of the given events.
event_chains = {} # type: Dict[int, int]
sql = """
SELECT event_id, chain_id, sequence_number
FROM event_auth_chains
WHERE %s
"""
for batch in batch_iter(initial_events, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", batch
)
txn.execute(sql % (clause,), args)
for event_id, chain_id, sequence_number in txn:
seen_events.add(event_id)
event_chains[chain_id] = max(
sequence_number, event_chains.get(chain_id, 0)
)
# Check that we actually have a chain ID for all the events.
events_missing_chain_info = initial_events.difference(seen_events)
if events_missing_chain_info:
# This can happen due to e.g. downgrade/upgrade of the server. We
# raise an exception and fall back to the previous algorithm.
logger.info(
"Unexpectedly found that events don't have chain IDs in room %s: %s",
room_id,
events_missing_chain_info,
)
raise _NoChainCoverIndex(room_id)
# Now we look up all links for the chains we have, adding chains that
# are reachable from any event.
sql = """
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM event_auth_chain_links
WHERE %s
"""
# A map from chain ID to max sequence number *reachable* from any event ID.
chains = {} # type: Dict[int, int]
# Add all linked chains reachable from initial set of chains.
for batch in batch_iter(event_chains, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch
)
txn.execute(sql % (clause,), args)
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
# chains are only reachable if the origin sequence number of
# the link is less than or equal to the max sequence number in
# the origin chain.
if origin_sequence_number <= event_chains.get(origin_chain_id, 0):
chains[target_chain_id] = max(
target_sequence_number,
chains.get(target_chain_id, 0),
)
# Add the initial set of chains, excluding the sequence corresponding to
# initial event.
for chain_id, seq_no in event_chains.items():
chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0))
# Now for each chain we figure out the maximum sequence number reachable
# from *any* event ID. Events with a sequence number less than or equal
# to that are in the auth chain.
if include_given:
# Note: this aliases `initial_events` rather than copying it, so the
# updates below also mutate that set. It is not read again afterwards.
results = initial_events
else:
results = set()
if isinstance(self.database_engine, PostgresEngine):
# We can use `execute_values` to efficiently fetch the events for all
# chains in one query when using postgres.
sql = """
SELECT event_id
FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq)
WHERE
c.chain_id = l.chain_id
AND sequence_number <= max_seq
"""
rows = txn.execute_values(sql, chains.items())
results.update(r for r, in rows)
else:
# For SQLite we just fall back to doing a noddy for loop, i.e. one
# query per chain.
sql = """
SELECT event_id FROM event_auth_chains
WHERE chain_id = ? AND sequence_number <= ?
"""
for chain_id, max_no in chains.items():
txn.execute(sql, (chain_id, max_no))
results.update(r for r, in txn)
return list(results)
def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> List[str]:
"""Calculates the auth chain IDs.
This is used when we don't have a cover index for the room.
"""
if include_given:
results = set(event_ids)
else:

View File

@@ -16,7 +16,7 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import attr
@@ -1614,7 +1614,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
async def user_delete_access_tokens(
self,
user_id: str,
except_token_id: Optional[str] = None,
except_token_id: Optional[int] = None,
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
"""
@@ -1637,7 +1637,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
values = [v for _, v in items]
values = [v for _, v in items] # type: List[Union[str, int]]
if except_token_id:
where_clause += " AND id != ?"
values.append(except_token_id)

View File

@@ -84,33 +84,32 @@ class ISynapseReactor(
"""The interfaces necessary for Synapse to function."""
class Requester(
namedtuple(
"Requester",
[
"user",
"access_token_id",
"is_guest",
"shadow_banned",
"device_id",
"app_service",
"authenticated_entity",
],
)
):
@attr.s(frozen=True, slots=True)
class Requester:
"""
Represents the user making a request
Attributes:
user (UserID): id of the user making the request
access_token_id (int|None): *ID* of the access token used for this
user: id of the user making the request
access_token_id: *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
shadow_banned (bool): True if the user making this request has been shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
is_guest: True if the user making this request is a guest user
shadow_banned: True if the user making this request has been shadow-banned.
device_id: device_id which was set at authentication time
app_service: the AS requesting on behalf of the user
authenticated_entity: The entity that authenticated when making the request.
This is different to the user_id when an admin user or the server is
"puppeting" the user.
"""
user = attr.ib(type="UserID")
access_token_id = attr.ib(type=Optional[int])
is_guest = attr.ib(type=bool)
shadow_banned = attr.ib(type=bool)
device_id = attr.ib(type=Optional[str])
app_service = attr.ib(type=Optional["ApplicationService"])
authenticated_entity = attr.ib(type=str)
def serialize(self):
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
@@ -158,23 +157,23 @@ class Requester(
def create_requester(
user_id: Union[str, "UserID"],
access_token_id: Optional[int] = None,
is_guest: Optional[bool] = False,
shadow_banned: Optional[bool] = False,
is_guest: bool = False,
shadow_banned: bool = False,
device_id: Optional[str] = None,
app_service: Optional["ApplicationService"] = None,
authenticated_entity: Optional[str] = None,
):
) -> Requester:
"""
Create a new ``Requester`` object
Args:
user_id (str|UserID): id of the user making the request
access_token_id (int|None): *ID* of the access token used for this
user_id: id of the user making the request
access_token_id: *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
shadow_banned (bool): True if the user making this request is shadow-banned.
device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user
is_guest: True if the user making this request is a guest user
shadow_banned: True if the user making this request is shadow-banned.
device_id: device_id which was set at authentication time
app_service: the AS requesting on behalf of the user
authenticated_entity: The entity that authenticated when making the request.
This is different to the user_id when an admin user or the server is
"puppeting" the user.

View File

@@ -118,8 +118,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
self.assertTrue(r == [room2] or r == [room3])
@parameterized.expand([(True,), (False,)])
def test_auth_difference(self, use_chain_cover_index: bool):
def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
room_id = "@ROOM:local"
# The silly auth graph we use to test the auth difference algorithm,
@@ -165,7 +164,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
"j": 1,
}
# Mark the room as not having a cover index
# Mark the room as maybe having a cover index.
def store_room(txn):
self.store.db_pool.simple_insert_txn(
@@ -222,6 +221,77 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
)
return room_id
@parameterized.expand([(True,), (False,)])
def test_auth_chain_ids(self, use_chain_cover_index: bool):
# Checks `get_auth_chain_ids` against the auth graph built by
# `_setup_auth_chain`. Parameterized so that it runs once with
# `use_chain_cover_index` set to True and once with it set to False.
room_id = self._setup_auth_chain(use_chain_cover_index)
# a and b have the same auth chain.
# (The order of the returned IDs is not checked, hence assertCountEqual.)
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"]))
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"]))
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(
self.store.get_auth_chain_ids(room_id, ["a", "b"])
)
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"]))
self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
# d and e have the same auth chain.
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"]))
self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"]))
self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"]))
self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"]))
self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
self.assertEqual(auth_chain_ids, ["k"])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
self.assertEqual(auth_chain_ids, ["j"])
# j and k have no parents.
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
self.assertEqual(auth_chain_ids, [])
auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
self.assertEqual(auth_chain_ids, [])
# More complex input sequences.
auth_chain_ids = self.get_success(
self.store.get_auth_chain_ids(room_id, ["b", "c", "d"])
)
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
auth_chain_ids = self.get_success(
self.store.get_auth_chain_ids(room_id, ["h", "i"])
)
self.assertCountEqual(auth_chain_ids, ["k", "j"])
# e gets returned even though include_given is false, because it is in
# the auth chain of b.
auth_chain_ids = self.get_success(
self.store.get_auth_chain_ids(room_id, ["b", "e"])
)
self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"])
# Test include_given.
auth_chain_ids = self.get_success(
self.store.get_auth_chain_ids(room_id, ["i"], include_given=True)
)
self.assertCountEqual(auth_chain_ids, ["i", "j"])
@parameterized.expand([(True,), (False,)])
def test_auth_difference(self, use_chain_cover_index: bool):
room_id = self._setup_auth_chain(use_chain_cover_index)
# Now actually test that various combinations give the right result:
difference = self.get_success(