1
0

Merge remote-tracking branch 'origin/rei/msc3202_otks_fbks' into anoa/e2e_as_internal_testing

This commit is contained in:
Travis Ralston
2021-12-28 20:31:10 -07:00
15 changed files with 536 additions and 56 deletions

View File

@@ -0,0 +1 @@
Add support for MSC3202: sending one-time key counts and fallback key usage states to Application Services.

View File

@@ -14,7 +14,7 @@
import logging
import re
from enum import Enum
from typing import TYPE_CHECKING, Iterable, List, Match, Optional
from typing import TYPE_CHECKING, Dict, Iterable, List, Match, Optional
from synapse.api.constants import EventTypes
from synapse.events import EventBase
@@ -27,6 +27,14 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Type for the `device_one_time_key_counts` field in an appservice transaction
# user ID -> {device ID -> {algorithm -> count}}
TransactionOneTimeKeyCounts = Dict[str, Dict[str, Dict[str, int]]]
# Type for the `device_unused_fallback_keys` field in an appservice transaction
# user ID -> {device ID -> [algorithm]}
TransactionUnusedFallbackKeys = Dict[str, Dict[str, List[str]]]
class ApplicationServiceState(Enum):
DOWN = "down"
@@ -61,6 +69,7 @@ class ApplicationService:
rate_limited=True,
ip_range_whitelist=None,
supports_ephemeral=False,
msc3202_transaction_extensions: bool = False,
):
self.token = token
self.url = (
@@ -73,6 +82,7 @@ class ApplicationService:
self.id = id
self.ip_range_whitelist = ip_range_whitelist
self.supports_ephemeral = supports_ephemeral
self.msc3202_transaction_extensions = msc3202_transaction_extensions
if "|" in self.id:
raise Exception("application service ID cannot contain '|' character")
@@ -371,6 +381,8 @@ class AppServiceTransaction:
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
device_list_summary: DeviceLists,
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
):
self.service = service
self.id = id
@@ -378,6 +390,8 @@ class AppServiceTransaction:
self.ephemeral = ephemeral
self.to_device_messages = to_device_messages
self.device_list_summary = device_list_summary
self.one_time_key_counts = one_time_key_counts
self.unused_fallback_keys = unused_fallback_keys
async def send(self, as_api: "ApplicationServiceApi") -> bool:
"""Sends this transaction using the provided AS API interface.
@@ -393,6 +407,8 @@ class AppServiceTransaction:
ephemeral=self.ephemeral,
to_device_messages=self.to_device_messages,
device_list_summary=self.device_list_summary,
one_time_key_counts=self.one_time_key_counts,
unused_fallback_keys=self.unused_fallback_keys,
txn_id=self.id,
)

View File

@@ -20,6 +20,11 @@ from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.appservice import (
ApplicationService,
TransactionOneTimeKeyCounts,
TransactionUnusedFallbackKeys,
)
from synapse.events import EventBase
from synapse.events.utils import serialize_event
from synapse.http.client import SimpleHttpClient
@@ -27,7 +32,6 @@ from synapse.types import DeviceLists, JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
from synapse.appservice import ApplicationService
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -207,6 +211,8 @@ class ApplicationServiceApi(SimpleHttpClient):
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
device_list_summary: DeviceLists,
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
txn_id: Optional[int] = None,
) -> bool:
"""
@@ -258,6 +264,16 @@ class ApplicationServiceApi(SimpleHttpClient):
}
)
if service.msc3202_transaction_extensions:
if one_time_key_counts:
body[
"org.matrix.msc3202.device_one_time_key_counts"
] = one_time_key_counts
if unused_fallback_keys:
body[
"org.matrix.msc3202.device_unused_fallback_keys"
] = unused_fallback_keys
try:
await self.put_json(
uri=uri,

View File

@@ -48,14 +48,22 @@ This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
import logging
from typing import Dict, Iterable, List, Optional
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.appservice import (
ApplicationService,
ApplicationServiceState,
TransactionOneTimeKeyCounts,
TransactionUnusedFallbackKeys,
)
from synapse.events import EventBase
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import DeviceLists, JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -81,7 +89,7 @@ class ApplicationServiceScheduler:
self.as_api = hs.get_application_service_api()
self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock, hs)
async def start(self):
logger.info("Starting appservice scheduler")
@@ -151,7 +159,7 @@ class _ServiceQueuer:
appservice at a given time.
"""
def __init__(self, txn_ctrl, clock):
def __init__(self, txn_ctrl, clock, hs: "HomeServer"):
# dict of {service_id: [events]}
self.queued_events: Dict[str, List[EventBase]] = {}
# dict of {service_id: [event_json]}
@@ -162,9 +170,13 @@ class _ServiceQueuer:
self.queued_device_list_summaries: Dict[str, List[DeviceLists]] = {}
# the appservices which currently have a transaction in flight
self.requests_in_flight = set()
self.requests_in_flight: Set[str] = set()
self.txn_ctrl = txn_ctrl
self.clock = clock
self._msc3202_transaction_extensions_enabled: bool = (
hs.config.experimental.msc3202_transaction_extensions
)
self._store = hs.get_datastore()
def start_background_request(self, service):
# start a sender for this appservice if we don't already have one
@@ -235,6 +247,26 @@ class _ServiceQueuer:
):
return
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None
if (
self._msc3202_transaction_extensions_enabled
and service.msc3202_transaction_extensions
):
# Lazily compute the one-time key counts and fallback key
# usage states for the users which are mentioned in this
# transaction, as well as the appservice's sender.
interesting_users = await self._determine_interesting_users_for_msc3202_otk_counts_and_fallback_keys(
service, events, ephemeral, to_device_messages_to_send
)
(
one_time_key_counts,
unused_fallback_keys,
) = await self._compute_msc3202_otk_counts_and_fallback_keys(
interesting_users
)
try:
await self.txn_ctrl.send(
service,
@@ -242,12 +274,66 @@ class _ServiceQueuer:
ephemeral,
to_device_messages_to_send,
device_list_summary,
one_time_key_counts,
unused_fallback_keys,
)
except Exception:
logger.exception("AS request failed")
finally:
self.requests_in_flight.discard(service.id)
async def _determine_interesting_users_for_msc3202_otk_counts_and_fallback_keys(
self,
service: ApplicationService,
events: Iterable[EventBase],
ephemerals: Iterable[JsonDict],
to_device_messages: Iterable[JsonDict],
) -> Set[str]:
"""
Given a list of the events, ephemeral messages and to-device messaeges,
compute a list of application services users that may have interesting
updates to the one-time key counts or fallback key usage.
"""
interesting_users: Set[str] = set()
# The sender is always included
interesting_users.add(service.sender)
# All AS users that would receive the PDUs or EDUs sent to these rooms
# are classed as 'interesting'.
rooms_of_interesting_users: Set[str] = set()
# PDUs
rooms_of_interesting_users.update(event.room_id for event in events)
# EDUs
rooms_of_interesting_users.update(
ephemeral["room_id"] for ephemeral in ephemerals
)
# Look up the AS users in those rooms
for room_id in rooms_of_interesting_users:
interesting_users.update(
await self._store.get_app_service_users_in_room(room_id, service)
)
# Add recipients of to-device messages.
# device_message["user_id"] is the ID of the recipient.
interesting_users.update(
device_message["user_id"] for device_message in to_device_messages
)
return interesting_users
async def _compute_msc3202_otk_counts_and_fallback_keys(
self, users: Set[str]
) -> Tuple[TransactionOneTimeKeyCounts, TransactionUnusedFallbackKeys]:
"""
Given a list of application service users that are interesting,
compute one-time key counts and fallback key usages for the users.
"""
otk_counts = await self._store.count_bulk_e2e_one_time_keys_for_as(users)
unused_fbks = await self._store.get_e2e_bulk_unused_fallback_key_types(users)
return otk_counts, unused_fbks
class _TransactionController:
"""Transaction manager.
@@ -281,6 +367,8 @@ class _TransactionController:
ephemeral: Optional[List[JsonDict]] = None,
to_device_messages: Optional[List[JsonDict]] = None,
device_list_summary: Optional[DeviceLists] = None,
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None,
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
) -> None:
"""
Create a transaction with the given data and send to the provided
@@ -292,6 +380,10 @@ class _TransactionController:
ephemeral: The ephemeral events to include in the transaction.
to_device_messages: The to-device messages to include in the transaction.
device_list_summary: The device list summary to include in the transaction.
one_time_key_counts: Counts of remaining one-time keys for relevant
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
"""
try:
txn = await self.store.create_appservice_txn(
@@ -300,6 +392,8 @@ class _TransactionController:
ephemeral=ephemeral or [],
to_device_messages=to_device_messages or [],
device_list_summary=device_list_summary or DeviceLists(),
one_time_key_counts=one_time_key_counts or {},
unused_fallback_keys=unused_fallback_keys or {},
)
service_is_up = await self._is_service_up(service)
if service_is_up:

View File

@@ -167,6 +167,16 @@ def _load_appservice(
supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False)
# Opt-in flag for the MSC3202-specific transactional behaviour.
# When enabled, appservice transactions contain the following information:
# - device One-Time Key counts
# - device unused fallback key usage states
msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False)
if not isinstance(msc3202_transaction_extensions, bool):
raise ValueError(
"The `org.matrix.msc3202` option should be true or false if specified."
)
return ApplicationService(
token=as_info["as_token"],
hostname=hostname,
@@ -175,8 +185,9 @@ def _load_appservice(
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["id"],
supports_ephemeral=supports_ephemeral,
protocols=protocols,
rate_limited=rate_limited,
ip_range_whitelist=ip_range_whitelist,
supports_ephemeral=supports_ephemeral,
msc3202_transaction_extensions=msc3202_transaction_extensions,
)

View File

@@ -66,3 +66,9 @@ class ExperimentalConfig(Config):
self.msc3202_device_masquerading_enabled: bool = experimental.get(
"msc3202_device_masquerading", False
)
# Portion of MSC3202 related to transaction extensions:
# sending one-time key counts and fallback key usage to application services.
self.msc3202_transaction_extensions: bool = experimental.get(
"msc3202_transaction_extensions", False
)

View File

@@ -20,15 +20,19 @@ from synapse.appservice import (
ApplicationService,
ApplicationServiceState,
AppServiceTransaction,
TransactionOneTimeKeyCounts,
TransactionUnusedFallbackKeys,
)
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.types import Connection
from synapse.types import DeviceLists, JsonDict
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -57,8 +61,13 @@ def _make_exclusive_regex(
return exclusive_user_pattern
class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
self.services_cache = load_appservices(
hs.hostname, hs.config.appservice.app_service_config_files
)
@@ -120,6 +129,14 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
return service
return None
# OSTD cache invalidation
@cached(iterable=True, prune_unread_entries=False)
async def get_app_service_users_in_room(
self, room_id: str, app_service: "ApplicationService"
) -> List[str]:
users_in_room = await self.get_users_in_room(room_id)
return list(filter(app_service.is_interested_in_user, users_in_room))
class ApplicationServiceStore(ApplicationServiceWorkerStore):
# This is currently empty due to there not being any AS storage functions
@@ -196,6 +213,8 @@ class ApplicationServiceTransactionWorkerStore(
ephemeral: List[JsonDict],
to_device_messages: List[JsonDict],
device_list_summary: DeviceLists,
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the
@@ -207,6 +226,10 @@ class ApplicationServiceTransactionWorkerStore(
ephemeral: A list of ephemeral events to put in the transaction.
to_device_messages: A list of to-device messages to put in the transaction.
device_list_summary: The device list summary to include in the transaction.
one_time_key_counts: Counts of remaining one-time keys for relevant
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
Returns:
A new transaction.
@@ -243,6 +266,8 @@ class ApplicationServiceTransactionWorkerStore(
ephemeral=ephemeral,
to_device_messages=to_device_messages,
device_list_summary=device_list_summary,
one_time_key_counts=one_time_key_counts,
unused_fallback_keys=unused_fallback_keys,
)
return await self.db_pool.runInteraction(
@@ -334,6 +359,8 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
# TODO: should we recalculate one-time key counts and unused fallback
# key counts here?
return AppServiceTransaction(
service=service,
id=entry["txn_id"],
@@ -341,6 +368,8 @@ class ApplicationServiceTransactionWorkerStore(
ephemeral=[],
to_device_messages=[],
device_list_summary=DeviceLists(),
one_time_key_counts={},
unused_fallback_keys={},
)
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:

View File

@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@@ -22,9 +22,17 @@ from canonicaljson import encode_canonical_json
from twisted.enterprise.adbapi import Connection
from synapse.api.constants import DeviceKeyAlgorithms
from synapse.appservice import (
TransactionOneTimeKeyCounts,
TransactionUnusedFallbackKeys,
)
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.types import JsonDict
@@ -397,6 +405,104 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
async def count_bulk_e2e_one_time_keys_for_as(
self, user_ids: Collection[str]
) -> TransactionOneTimeKeyCounts:
"""
Counts, in bulk, the one-time keys for all the users specified.
Intended to be used by application services for populating OTK counts in
transactions.
Return structure is of the shape:
user_id -> device_id -> algorithm -> count
"""
def _count_bulk_e2e_one_time_keys_txn(
txn: LoggingTransaction,
) -> TransactionOneTimeKeyCounts:
user_in_where_clause, user_parameters = make_in_list_sql_clause(
self.database_engine, "user_id", user_ids
)
sql = f"""
SELECT user_id, device_id, algorithm, COUNT(key_id)
FROM devices
LEFT JOIN e2e_one_time_keys_json USING (user_id, device_id)
WHERE {user_in_where_clause}
GROUP BY user_id, device_id, algorithm
"""
txn.execute(sql, user_parameters)
result = {}
for user_id, device_id, algorithm, count in txn:
device_count_by_algo = result.setdefault(user_id, {}).setdefault(
device_id, {}
)
if algorithm is not None:
# algorithm will be None if this device has no keys.
device_count_by_algo[algorithm] = count
return result
return await self.db_pool.runInteraction(
"count_bulk_e2e_one_time_keys", _count_bulk_e2e_one_time_keys_txn
)
async def get_e2e_bulk_unused_fallback_key_types(
self, user_ids: Collection[str]
) -> TransactionUnusedFallbackKeys:
"""
Finds, in bulk, the types of unused fallback keys for all the users specified.
Intended to be used by application services for populating unused fallback
keys in transactions.
Return structure is of the shape:
user_id -> device_id -> algorithms
"""
if len(user_ids) == 0:
return {}
def _get_bulk_e2e_unused_fallback_keys_txn(
txn: LoggingTransaction,
) -> TransactionUnusedFallbackKeys:
user_in_where_clause, user_parameters = make_in_list_sql_clause(
self.database_engine, "devices.user_id", user_ids
)
# We can't use USING here because we require the `.used` condition
# to be part of the JOIN condition so that we generate empty lists
# when all keys are used (as opposed to just when there are no keys at all).
sql = f"""
SELECT devices.user_id, devices.device_id, algorithm
FROM devices
LEFT JOIN e2e_fallback_keys_json AS fallback_keys
/* We can't use USING here because we require the `.used`
condition to be part of the JOIN condition so that we
generate empty lists when all keys are used (as opposed
to just when there are no keys at all). */
ON devices.user_id = fallback_keys.user_id
AND devices.device_id = fallback_keys.device_id
AND NOT fallback_keys.used
WHERE
{user_in_where_clause}
"""
txn.execute(sql, user_parameters)
result = {}
for user_id, device_id, algorithm in txn:
device_unused_keys = result.setdefault(user_id, {}).setdefault(
device_id, []
)
if algorithm is not None:
# algorithm will be None if this device has no keys.
device_unused_keys.append(algorithm)
return result
return await self.db_pool.runInteraction(
"_get_bulk_e2e_unused_fallback_keys", _get_bulk_e2e_unused_fallback_keys_txn
)
async def set_e2e_fallback_keys(
self, user_id: str, device_id: str, fallback_keys: JsonDict
) -> None:

View File

@@ -68,6 +68,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
ephemeral=[],
to_device_messages=[],
device_list_summary=DeviceLists(),
one_time_key_counts={},
unused_fallback_keys={},
)
self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed
@@ -94,6 +96,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
ephemeral=[],
to_device_messages=[],
device_list_summary=DeviceLists(),
one_time_key_counts={},
unused_fallback_keys={},
)
self.assertEquals(0, txn.send.call_count) # txn not sent though
self.assertEquals(0, txn.complete.call_count) # or completed
@@ -121,6 +125,8 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
ephemeral=[],
to_device_messages=[],
device_list_summary=DeviceLists(),
one_time_key_counts={},
unused_fallback_keys={},
)
self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made
self.assertEquals(1, self.recoverer.recover.call_count) # and invoked
@@ -222,9 +228,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
service = Mock(id=4)
event = Mock()
self.scheduler.enqueue_for_appservice(service, events=[event])
self.txn_ctrl.send.assert_called_once_with(
service, [event], [], [], DeviceLists()
)
self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], DeviceLists(), None, None)
def test_send_single_event_with_queue(self):
d = defer.Deferred()
@@ -239,12 +243,12 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# (call enqueue_for_appservice multiple times deliberately)
self.scheduler.enqueue_for_appservice(service, events=[event2])
self.scheduler.enqueue_for_appservice(service, events=[event3])
self.txn_ctrl.send.assert_called_with(service, [event], [], [], DeviceLists())
self.txn_ctrl.send.assert_called_with(service, [event], [], [], DeviceLists(), None, None)
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve the send event: expect the queued events to be sent
d.callback(service)
self.txn_ctrl.send.assert_called_with(
service, [event2, event3], [], [], DeviceLists()
service, [event2, event3], [], [], DeviceLists(), None, None
)
self.assertEquals(2, self.txn_ctrl.send.call_count)
@@ -271,21 +275,15 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# send events for different ASes and make sure they are sent
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event])
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2])
self.txn_ctrl.send.assert_called_with(
srv1, [srv_1_event], [], [], DeviceLists()
)
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], DeviceLists(), None, None)
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event])
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2])
self.txn_ctrl.send.assert_called_with(
srv2, [srv_2_event], [], [], DeviceLists()
)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], DeviceLists(), None, None)
# make sure callbacks for a service only send queued events for THAT
# service
srv_2_defer.callback(srv2)
self.txn_ctrl.send.assert_called_with(
srv2, [srv_2_event2], [], [], DeviceLists()
)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], DeviceLists(), None, None)
self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_large_txns(self):
@@ -305,17 +303,17 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# Expect the first event to be sent immediately.
self.txn_ctrl.send.assert_called_with(
service, [event_list[0]], [], [], DeviceLists()
service, [event_list[0]], [], [], DeviceLists(), None, None
)
srv_1_defer.callback(service)
# Then send the next 100 events
self.txn_ctrl.send.assert_called_with(
service, event_list[1:101], [], [], DeviceLists()
service, event_list[1:101], [], [], DeviceLists(), None, None
)
srv_2_defer.callback(service)
# Then the final 99 events
self.txn_ctrl.send.assert_called_with(
service, event_list[101:], [], [], DeviceLists()
service, event_list[101:], [], [], DeviceLists(), None, None
)
self.assertEquals(3, self.txn_ctrl.send.call_count)
@@ -325,7 +323,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
event_list = [Mock(name="event")]
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(
service, [], event_list, [], DeviceLists()
service, [], event_list, [], DeviceLists(), None, None
)
def test_send_multiple_ephemeral_no_queue(self):
@@ -334,7 +332,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(
service, [], event_list, [], DeviceLists()
service, [], event_list, [], DeviceLists(), None, None
)
def test_send_single_ephemeral_with_queue(self):
@@ -350,15 +348,13 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# Send more events: expect send() to NOT be called multiple times.
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2)
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3)
self.txn_ctrl.send.assert_called_with(
service, [], event_list_1, [], DeviceLists()
)
self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], DeviceLists(), None, None)
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve txn_ctrl.send
d.callback(service)
# Expect the queued events to be sent
self.txn_ctrl.send.assert_called_with(
service, [], event_list_2 + event_list_3, [], DeviceLists()
service, [], event_list_2 + event_list_3, [], DeviceLists(), None, None
)
self.assertEquals(2, self.txn_ctrl.send.call_count)
@@ -372,10 +368,8 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
event_list = first_chunk + second_chunk
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(
service, [], first_chunk, [], DeviceLists()
service, [], first_chunk, [], DeviceLists(), None, None
)
d.callback(service)
self.txn_ctrl.send.assert_called_with(
service, [], second_chunk, [], DeviceLists()
)
self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], DeviceLists(), None, None)
self.assertEquals(2, self.txn_ctrl.send.call_count)

View File

@@ -16,17 +16,25 @@ from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock
from twisted.internet import defer
from twisted.internet.testing import MemoryReactor
import synapse.rest.admin
import synapse.storage
from synapse.appservice import ApplicationService
from synapse.appservice import (
ApplicationService,
TransactionOneTimeKeyCounts,
TransactionUnusedFallbackKeys,
)
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.rest.client import login, receipts, room, sendtodevice
from synapse.rest.client import login, receipts, register, room, sendtodevice
from synapse.server import HomeServer
from synapse.types import RoomStreamToken
from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
from tests.test_utils import make_awaitable, simple_async_mock
from tests.unittest import override_config
from tests.utils import MockClock
@@ -437,6 +445,8 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
_ephemeral,
to_device_messages,
_device_list_summary,
_otks,
_fbks,
) = self.send_mock.call_args[0]
# Assert that this was the same to-device message that local_user sent
@@ -554,6 +564,8 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
_ephemeral,
to_device_messages,
_device_list_summary,
_otks,
_fbks,
) = call[0]
# Check that this was made to an interested service
@@ -596,3 +608,176 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self._services.append(appservice)
return appservice
class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
# Argument indices for pulling out arguments from a `send_mock`.
ARG_OTK_COUNTS = 4
ARG_FALLBACK_KEYS = 5
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
register.register_servlets,
room.register_servlets,
sendtodevice.register_servlets,
receipts.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track what's going out
self.send_mock = simple_async_mock()
hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method.
# Define an application service for the tests
self._service_token = "VERYSECRET"
self._service = ApplicationService(
self._service_token,
"as1.invalid",
"as1",
"@as.sender:test",
namespaces={
"users": [
{"regex": "@_as_.*:test", "exclusive": True},
{"regex": "@as.sender:test", "exclusive": True},
]
},
msc3202_transaction_extensions=True,
)
self.hs.get_datastore().services_cache = [self._service]
# Register some appservice users
self._sender_user, self._sender_device = self.register_appservice_user(
"as.sender", self._service_token
)
self._namespaced_user, self._namespaced_device = self.register_appservice_user(
"_as_user1", self._service_token
)
# Register a real user as well.
self._real_user = self.register_user("real.user", "meow")
self._real_user_token = self.login("real.user", "meow")
async def _add_otks_for_device(
self, user_id: str, device_id: str, otk_count: int
) -> None:
"""
Add some dummy keys. It doesn't matter if they're not a real algorithm;
that should be opaque to the server anyway.
"""
await self.hs.get_datastore().add_e2e_one_time_keys(
user_id,
device_id,
self.clock.time_msec(),
[("algo", f"k{i}", "{}") for i in range(otk_count)],
)
async def _add_fallback_key_for_device(
self, user_id: str, device_id: str, used: bool
) -> None:
"""
Adds a fake fallback key to a device, optionally marking it as used
right away.
"""
store = self.hs.get_datastore()
await store.set_e2e_fallback_keys(user_id, device_id, {"algo:fk": "fall back!"})
if used is True:
# Mark the key as used
await store.db_pool.simple_update_one(
table="e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": "algo",
"key_id": "fk",
},
updatevalues={"used": True},
desc="_get_fallback_key_set_used",
)
def _set_up_devices_and_a_room(self) -> str:
"""
Helper to set up devices for all the users
and a room for the users to talk in.
"""
async def preparation():
await self._add_otks_for_device(self._sender_user, self._sender_device, 42)
await self._add_fallback_key_for_device(
self._sender_user, self._sender_device, used=True
)
await self._add_otks_for_device(
self._namespaced_user, self._namespaced_device, 36
)
await self._add_fallback_key_for_device(
self._namespaced_user, self._namespaced_device, used=False
)
# Register a device for the real user, too, so that we can later ensure
# that we don't leak information to the AS about the non-AS user.
await self.hs.get_datastore().store_device(
self._real_user, "REALDEV", "UltraMatrix 3000"
)
await self._add_otks_for_device(self._real_user, "REALDEV", 50)
self.get_success(preparation())
room_id = self.helper.create_room_as(
self._real_user, is_public=True, tok=self._real_user_token
)
self.helper.join(
room_id,
self._namespaced_user,
tok=self._service_token,
appservice_user_id=self._namespaced_user,
)
# Check it was called for sanity. (This was to send the join event to the AS.)
self.send_mock.assert_called()
self.send_mock.reset_mock()
return room_id
@override_config(
{"experimental_features": {"msc3202_transaction_extensions": True}}
)
def test_application_services_receive_otk_counts_and_fallback_key_usages_with_pdus(
self,
) -> None:
"""
Tests that:
- the AS receives one-time key counts and unused fallback keys for:
- the specified sender; and
- any user who is in receipt of the PDUs
"""
room_id = self._set_up_devices_and_a_room()
# Send a message into the AS's room
self.helper.send(room_id, "woof woof", tok=self._real_user_token)
# Capture what was sent as an AS transaction.
self.send_mock.assert_called()
last_args = self.send_mock.call_args
otks: Optional[TransactionOneTimeKeyCounts] = last_args.args[
self.ARG_OTK_COUNTS
]
unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args.args[
self.ARG_FALLBACK_KEYS
]
self.assertEqual(
otks,
{
"@as.sender:test": {self._sender_device: {"algo": 42}},
"@_as_user1:test": {self._namespaced_device: {"algo": 36}},
},
)
self.assertEqual(
unused_fallbacks,
{
"@as.sender:test": {self._sender_device: []},
"@_as_user1:test": {self._namespaced_device: ["algo"]},
},
)

View File

@@ -169,7 +169,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Register an AS user.
user = self.register_user("user", "pass")
token = self.login(user, "pass")
as_user = self.register_appservice_user("as_user_potato", self.appservice.token)
as_user, _ = self.register_appservice_user(
"as_user_potato", self.appservice.token
)
# Join the AS user to rooms owned by the normal user.
public, private = self._create_rooms_and_inject_memberships(
@@ -388,7 +390,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def test_handle_local_profile_change_with_appservice_user(self) -> None:
# create user
as_user_id = self.register_appservice_user(
as_user_id, _ = self.register_appservice_user(
"as_user_alice", self.appservice.token
)

View File

@@ -147,12 +147,20 @@ class RestHelper:
expect_code=expect_code,
)
def join(self, room=None, user=None, expect_code=200, tok=None):
def join(
self,
room=None,
user=None,
expect_code=200,
tok=None,
appservice_user_id: Optional[str] = None,
):
self.change_membership(
room=room,
src=user,
targ=user,
tok=tok,
appservice_user_id=appservice_user_id,
membership=Membership.JOIN,
expect_code=expect_code,
)
@@ -204,6 +212,7 @@ class RestHelper:
membership: str,
extra_data: Optional[dict] = None,
tok: Optional[str] = None,
appservice_user_id: Optional[str] = None,
expect_code: int = 200,
expect_errcode: Optional[str] = None,
) -> None:
@@ -217,6 +226,9 @@ class RestHelper:
membership: The type of membership event
extra_data: Extra information to include in the content of the event
tok: The user access token to use
appservice_user_id: The `user_id` URL parameter to pass.
This allows driving an application service user
using an application service access token in `tok`.
expect_code: The expected HTTP response code
expect_errcode: The expected Matrix error code
"""
@@ -224,8 +236,14 @@ class RestHelper:
self.auth_user_id = src
path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ)
next_arg_char = "?"
if tok:
path = path + "?access_token=%s" % tok
path += "?access_token=%s" % tok
next_arg_char = "&"
if appservice_user_id:
path += f"{next_arg_char}user_id={appservice_user_id}"
data = {"membership": membership}
data.update(extra_data or {})

View File

@@ -268,7 +268,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
txn = self.get_success(
defer.ensureDeferred(
self.store.create_appservice_txn(service, events, [], [], DeviceLists())
self.store.create_appservice_txn(service, events, [], [], DeviceLists(), {}, {})
)
)
self.assertEquals(txn.id, 1)
@@ -284,7 +284,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(service.id, 9644, events))
self.get_success(self._insert_txn(service.id, 9645, events))
txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [], DeviceLists())
self.store.create_appservice_txn(service, events, [], [], DeviceLists(), {}, {})
)
self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events)
@@ -297,7 +297,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
self.get_success(self._set_last_txn(service.id, 9643))
txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [], DeviceLists())
self.store.create_appservice_txn(service, events, [], [], DeviceLists(), {}, {})
)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
@@ -321,7 +321,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [], DeviceLists())
self.store.create_appservice_txn(service, events, [], [], DeviceLists(), {}, {})
)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)

View File

@@ -341,7 +341,9 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
# Register an AS user.
user = self.register_user("user", "pass")
token = self.login(user, "pass")
as_user = self.register_appservice_user("as_user_potato", self.appservice.token)
as_user, _ = self.register_appservice_user(
"as_user_potato", self.appservice.token
)
# Join the AS user to rooms owned by the normal user.
public, private = self._create_rooms_and_inject_memberships(

View File

@@ -620,18 +620,18 @@ class HomeserverTestCase(TestCase):
self,
username: str,
appservice_token: str,
) -> str:
) -> Tuple[str, str]:
"""Register an appservice user as an application service.
Requires the client-facing registration API be registered.
Args:
username: the user to be registered by an application service.
Should be a full username, i.e. ""@localpart:hostname" as opposed to just "localpart"
Should NOT be a full username, i.e. just "localpart" as opposed to "@localpart:hostname"
appservice_token: the acccess token for that application service.
Raises: if the request to '/register' does not return 200 OK.
Returns: the MXID of the new user.
Returns: the MXID of the new user, the device ID of the new user's first device.
"""
channel = self.make_request(
"POST",
@@ -643,7 +643,7 @@ class HomeserverTestCase(TestCase):
access_token=appservice_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
return channel.json_body["user_id"]
return channel.json_body["user_id"], channel.json_body["device_id"]
def login(
self,