Compare commits

...

8 Commits

Author     SHA1        Message                                                           Date
Will Hunt  f807c7291f  Add basic support for device list updates                        2020-09-30 12:55:57 +01:00
Will Hunt  6201ea56ee  Merge branch 'develop' into hs/super-wip-edus-down-sync          2020-09-28 11:11:57 +01:00
Will Hunt  316ad09a64  Add support for device messages, start support for device lists  2020-09-22 11:31:59 +01:00
Will Hunt  4392526bf0  Last little bits                                                  2020-09-21 16:22:11 +01:00
Will Hunt  3bf1b79d3c  Add is_interested_in_presence func                                2020-09-21 16:21:22 +01:00
Will Hunt  42090bcc7c  Call appservice handler when seeing new events in the notifier   2020-09-21 15:10:37 +01:00
Will Hunt  ae724db899  Changes to handlers to support fetching events for appservices   2020-09-21 15:10:06 +01:00
Will Hunt  78911ca46a  Appservice API changes                                            2020-09-21 15:09:31 +01:00
13 changed files with 416 additions and 5 deletions

View File

@@ -91,6 +91,7 @@ class ApplicationService:
protocols=None,
rate_limited=True,
ip_range_whitelist=None,
supports_ephemeral=False,
):
self.token = token
self.url = (
@@ -102,6 +103,7 @@ class ApplicationService:
self.namespaces = self._check_namespaces(namespaces)
self.id = id
self.ip_range_whitelist = ip_range_whitelist
self.supports_ephemeral = supports_ephemeral
if "|" in self.id:
raise Exception("application service ID cannot contain '|' character")
@@ -188,11 +190,11 @@ class ApplicationService:
if not store:
return False
does_match = await self._matches_user_in_member_list(event.room_id, store)
does_match = await self.matches_user_in_member_list(event.room_id, store)
return does_match
@cached(num_args=1, cache_context=True)
async def _matches_user_in_member_list(self, room_id, store, cache_context):
async def matches_user_in_member_list(self, room_id, store, cache_context):
member_list = await store.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate
)
@@ -239,6 +241,19 @@ class ApplicationService:
return False
@cached(num_args=1, cache_context=True)
async def is_interested_in_presence(self, user_id, store, cache_context):
# If the appservice explicitly claims this user, it is interested
if self.is_interested_in_user(user_id.to_string()):
return True
# Otherwise, find all the rooms the user is in and check whether the
# appservice shares any of them
room_ids = await store.get_rooms_for_user(user_id.to_string())
for room_id in room_ids:
if await self.matches_user_in_member_list(
room_id, store, on_invalidate=cache_context.invalidate
):
return True
return False
def is_interested_in_user(self, user_id):
return (
self._matches_regex(user_id, ApplicationService.NS_USERS)
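
The two new pieces above are what the rest of this change hangs off: supports_ephemeral gates whether a service is offered ephemeral events at all, and is_interested_in_presence / matches_user_in_member_list decide which users and rooms it should hear about. A minimal usage sketch (not part of the diff; the user id is hypothetical and "store" stands for the homeserver datastore):

from synapse.types import UserID

async def should_send_presence(service, store):
    if not service.supports_ephemeral:
        return False
    user = UserID.from_string("@alice:example.org")
    # The @cached(cache_context=True) decorator injects cache_context itself,
    # so callers only pass the user and the datastore, exactly as the
    # appservice handler further down does.
    return await service.is_interested_in_presence(user, store)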

View File

@@ -201,6 +201,32 @@ class ApplicationServiceApi(SimpleHttpClient):
key = (service.id, protocol)
return await self.protocol_meta_cache.wrap(key, _get)
async def push_ephemeral(self, service, events, to_device=None, device_lists=None):
if service.url is None:
return True
if not service.supports_ephemeral:
return True
uri = service.url + (
"%s/uk.half-shot.appservice/ephemeral" % APP_SERVICE_PREFIX
)
try:
await self.put_json(
uri=uri,
json_body={
"events": events,
"device_messages": to_device,
"device_lists": device_lists,
},
args={"access_token": service.hs_token},
)
return True
except CodeMessageException as e:
logger.warning("push_ephemeral to %s received %s", uri, e.code)
except Exception as ex:
logger.warning("push_ephemeral to %s threw exception %s", uri, ex)
return False
async def push_bulk(self, service, events, txn_id=None):
if service.url is None:
return True
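
For reference, the request that push_ephemeral issues looks roughly like the sketch below. Only the three top-level keys and the access_token query parameter come from the code above; the value of APP_SERVICE_PREFIX is not shown in this diff and the ids are made up.

# PUT <service.url><APP_SERVICE_PREFIX>/uk.half-shot.appservice/ephemeral?access_token=<hs_token>
example_json_body = {
    "events": [
        # e.g. a typing EDU as built by the typing event source further down
        {
            "type": "m.typing",
            "room_id": "!room:example.org",
            "content": {"user_ids": ["@alice:example.org"]},
        },
    ],
    # to-device messages, when present: message dicts, each carrying a
    # "recipient" entry (see get_new_messages_for_as further down)
    "device_messages": None,
    # device list changes, when present: an m.device_list_update EDU with a
    # "changed" list (see _handle_device_list further down)
    "device_lists": None,
}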

View File

@@ -85,6 +85,10 @@ class ApplicationServiceScheduler:
def submit_event_for_as(self, service, event):
self.queuer.enqueue(service, event)
async def submit_ephemeral_events_for_as(self, service, events):
if await self.txn_ctrl.is_service_up(service):
await self.as_api.push_ephemeral(service, events)
class _ServiceQueuer:
"""Queue of events waiting to be sent to appservices.
@@ -161,7 +165,7 @@ class _TransactionController:
async def send(self, service, events):
try:
txn = await self.store.create_appservice_txn(service=service, events=events)
service_is_up = await self._is_service_up(service)
service_is_up = await self.is_service_up(service)
if service_is_up:
sent = await txn.send(self.as_api)
if sent:
@@ -204,7 +208,7 @@ class _TransactionController:
recoverer.recover()
logger.info("Now %i active recoverers", len(self.recoverers))
async def _is_service_up(self, service):
async def is_service_up(self, service):
state = await self.store.get_appservice_state(service)
return state == ApplicationServiceState.UP or state is None

View File

@@ -160,6 +160,8 @@ def _load_appservice(hostname, as_info, config_filename):
if as_info.get("ip_range_whitelist"):
ip_range_whitelist = IPSet(as_info.get("ip_range_whitelist"))
supports_ephemeral = as_info.get("uk.half-shot.appservice.push_ephemeral", False)
return ApplicationService(
token=as_info["as_token"],
hostname=hostname,
@@ -168,6 +170,7 @@ def _load_appservice(hostname, as_info, config_filename):
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["id"],
supports_ephemeral=supports_ephemeral,
protocols=protocols,
rate_limited=rate_limited,
ip_range_whitelist=ip_range_whitelist,
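
The new flag is read from the registration file under the unstable uk.half-shot.appservice.push_ephemeral key. A hypothetical registration, shown here as the parsed as_info dict that _load_appservice receives (all values are placeholders):

as_info = {
    "id": "halfshot-bridge",
    "url": "http://localhost:9999",
    "as_token": "<as_token>",
    "hs_token": "<hs_token>",
    "sender_localpart": "bridgebot",
    "namespaces": {
        "users": [{"regex": "@_bridge_.*:example.org", "exclusive": True}],
    },
    # Opt in to receiving EDUs (typing, receipts, presence, to-device
    # messages, device list updates) on the new /ephemeral endpoint:
    "uk.half-shot.appservice.push_ephemeral": True,
}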

View File

@@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import Collection, List, Union
from prometheus_client import Counter
@@ -21,12 +22,15 @@ from twisted.internet import defer
import synapse
from synapse.api.constants import EventTypes
from synapse.appservice import ApplicationService
from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import (
event_processing_loop_counter,
event_processing_loop_room_count,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import RoomStreamToken, UserID
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -43,6 +47,7 @@ class ApplicationServicesHandler:
self.started_scheduler = False
self.clock = hs.get_clock()
self.notify_appservices = hs.config.notify_appservices
self.event_sources = hs.get_event_sources()
self.current_max = 0
self.is_processing = False
@@ -158,6 +163,131 @@ class ApplicationServicesHandler:
finally:
self.is_processing = False
async def notify_interested_services_ephemeral(
self,
stream_key: str,
new_token: Union[int, RoomStreamToken],
users: Collection[UserID] = [],
):
services = [
service
for service in self.store.get_app_services()
if service.supports_ephemeral
]
if not services or not self.notify_appservices:
return
logger.info("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
events = []
if stream_key == "typing_key":
events = await self._handle_typing(service, new_token)
elif stream_key == "receipt_key":
events = await self._handle_receipts(service)
elif stream_key == "presence_key":
events = await self._handle_as_presence(service, users)
elif stream_key == "device_list_key":
# Check if the device lists have changed for any of the users we are interested in
events = await self._handle_device_list(service, users, new_token)
elif stream_key == "to_device_key":
# Check the inbox for any users the bridge owns
events = await self._handle_to_device(service, users, new_token)
if events:
# TODO: Do in background?
await self.scheduler.submit_ephemeral_events_for_as(
service, events
)
# We don't persist the token for typing_key
if stream_key == "presence_key":
await self.store.set_type_stream_id_for_appservice(
service, "presence", new_token
)
elif stream_key == "receipt_key":
await self.store.set_type_stream_id_for_appservice(
service, "read_receipt", new_token
)
elif stream_key == "to_device_key":
await self.store.set_type_stream_id_for_appservice(
service, "to_device", new_token
)
async def _handle_typing(self, service, new_token):
typing_source = self.event_sources.sources["typing"]
# Get the typing events from just before current
typing, _key = await typing_source.get_new_events_as(
service=service,
# For performance reasons, we don't persist the previous
# token in the DB and instead fetch the latest typing information
# for appservices.
from_key=new_token - 1,
)
return typing
async def _handle_receipts(self, service):
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
receipts_source = self.event_sources.sources["receipt"]
receipts, _ = await receipts_source.get_new_events_as(
service=service, from_key=from_key
)
return receipts
async def _handle_device_list(
self, service: ApplicationService, users: List[str], new_token: int
):
# TODO: Determine if any users have left and report those
from_token = await self.store.get_type_stream_id_for_appservice(
service, "device_list"
)
changed_user_ids = await self.store.get_device_changes_for_as(
service, from_token, new_token
)
# Return a single m.device_list_update EDU listing all the changed users
return {
"type": "m.device_list_update",
"content": {"changed": changed_user_ids},
}
async def _handle_to_device(self, service, users, token):
if not any(service.is_interested_in_user(u) for u in users):
return []
since_token = await self.store.get_type_stream_id_for_appservice(
service, "to_device"
)
messages, _ = await self.store.get_new_messages_for_as(
service, since_token, token
)
# Each message is annotated with its recipient user_id and device_id
return messages
async def _handle_as_presence(self, service, users):
events = []
presence_source = self.event_sources.sources["presence"]
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
)
for user in users:
interested = await service.is_interested_in_presence(user, self.store)
if not interested:
continue
presence_events, _key = await presence_source.get_new_events(
user=user, service=service, from_key=from_key,
)
time_now = self.clock.time_msec()
presence_events = [
{
"type": "m.presence",
"sender": event.user_id,
"content": format_user_presence_state(
event, time_now, include_user_id=False
),
}
for event in presence_events
]
events.extend(presence_events)
# After checking every user, hand the collected presence events back
return events
async def query_user_exists(self, user_id):
"""Check if any application service knows this user_id exists.

View File

@@ -140,5 +140,27 @@ class ReceiptEventSource:
return (events, to_key)
async def get_new_events_as(self, from_key, service, **kwargs):
from_key = int(from_key)
to_key = self.get_current_key()
if from_key == to_key:
return [], to_key
# We first need to fetch all new receipts
rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms(
from_key=from_key, to_key=to_key
)
# Then filter down to rooms that the AS can read
events = []
for room_id, event in rooms_to_events.items():
if not await service.matches_user_in_member_list(room_id, self.store):
continue
events.append(event)
return (events, to_key)
def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id()
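
Each entry of rooms_to_events above is one batched m.receipt event per room, in the shape assembled by _get_linearized_receipts_for_all_rooms further down (ids here are hypothetical):

example_receipt_event = {
    "type": "m.receipt",
    "room_id": "!room:example.org",
    "content": {
        # event_id -> receipt type -> user_id -> receipt data
        "$someevent:example.org": {
            "m.read": {"@alice:example.org": {"ts": 1601462400000}},
        },
    },
}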

View File

@@ -19,6 +19,7 @@ from collections import namedtuple
from typing import TYPE_CHECKING, List, Set, Tuple
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import TypingStream
from synapse.types import UserID, get_domain_from_id
@@ -430,6 +431,27 @@ class TypingNotificationEventSource:
"content": {"user_ids": list(typing)},
}
async def get_new_events_as(
self, from_key: int, service: ApplicationService, **kwargs
):
with Measure(self.clock, "typing.get_new_events_as"):
from_key = int(from_key)
handler = self.get_typing_handler()
events = []
for room_id in handler._room_serials.keys():
if handler._room_serials[room_id] <= from_key:
print("Key too old")
continue
if not await service.matches_user_in_member_list(
room_id, handler.store
):
continue
events.append(self._make_event_for(room_id))
return (events, handler._latest_room_serial)
async def get_new_events(self, from_key, room_ids, **kwargs):
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)

View File

@@ -329,6 +329,19 @@ class Notifier:
except Exception:
logger.exception("Error notifying application services of event")
async def _notify_app_services_ephemeral(
self,
stream_key: str,
new_token: Union[int, RoomStreamToken],
users: Collection[UserID] = [],
):
try:
await self.appservice_handler.notify_interested_services_ephemeral(
stream_key, new_token, users
)
except Exception:
logger.exception("Error notifying application services of event")
async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
try:
await self._pusher_pool.on_new_notifications(max_room_stream_token.stream)
@@ -367,6 +380,15 @@ class Notifier:
self.notify_replication()
# Notify appservices
run_as_background_process(
"_notify_app_services_ephemeral",
self._notify_app_services_ephemeral,
stream_key,
new_token,
users,
)
def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happend
without waking up any of the normal user event streams"""
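
Putting the notifier change together with the handler, scheduler and API changes above, the path an ephemeral update takes is roughly the following (a sketch; the entry point that supplies stream_key, new_token and users is assumed, everything else appears in these diffs):

# stream writer (typing / receipts / presence / device messages) notifies the Notifier
#   -> Notifier: run_as_background_process("_notify_app_services_ephemeral",
#          self._notify_app_services_ephemeral, stream_key, new_token, users)
#   -> ApplicationServicesHandler.notify_interested_services_ephemeral(...)
#        filters to services with supports_ephemeral and builds EDUs per stream
#   -> ApplicationServiceScheduler.submit_ephemeral_events_for_as(service, events)
#        skipped if the service is not currently up
#   -> ApplicationServiceApi.push_ephemeral(service, events)
#        PUT .../uk.half-shot.appservice/ephemeral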

View File

@@ -320,7 +320,7 @@ class ApplicationServiceTransactionWorkerStore(
)
async def get_new_events_for_appservice(self, current_id, limit):
"""Get all new evnets"""
"""Get all new events for an appservice"""
def get_new_events_for_appservice_txn(txn):
sql = (
@@ -351,6 +351,37 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, events
async def get_type_stream_id_for_appservice(self, service, type: str) -> int:
def get_type_stream_id_for_appservice_txn(txn):
stream_id_type = "%s_stream_id" % type
# The column name cannot be a bound parameter, so interpolate the
# internally-built name into the query directly
txn.execute(
"SELECT %s FROM application_services_state WHERE as_id=?" % (stream_id_type,),
(service.id,),
)
last_txn_id = txn.fetchone()
if last_txn_id is None or last_txn_id[0] is None: # no row exists
return 0
else:
return int(last_txn_id[0])
return await self.db_pool.runInteraction(
"get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
)
async def set_type_stream_id_for_appservice(
self, service, type: str, pos: int
) -> None:
def set_type_stream_id_for_appservice_txn(txn):
stream_id_type = "%s_stream_id" % type
# As above, the column name is interpolated rather than bound
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?" % (stream_id_type,),
(pos, service.id),
)
await self.db_pool.runInteraction(
"set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn
)
class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
# This is currently empty due to there not being any AS storage functions
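
A usage sketch for the two new storage methods above, matching how the handler persists stream positions; the type string selects the <type>_stream_id column on application_services_state:

async def advance_read_receipt_position(store, service, new_token):
    # Reads application_services_state.read_receipt_stream_id for this AS...
    from_key = await store.get_type_stream_id_for_appservice(service, "read_receipt")
    if new_token > from_key:
        # ...and writes it back once events up to new_token have been sent.
        await store.set_type_stream_id_for_appservice(service, "read_receipt", new_token)
    return from_key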

View File

@@ -16,6 +16,7 @@
import logging
from typing import List, Tuple
from synapse.appservice import ApplicationService
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
@@ -29,6 +30,40 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
async def get_new_messages_for_as(
self,
service: ApplicationService,
last_stream_id: int,
current_stream_id: int,
limit: int = 100,
) -> Tuple[List[dict], int]:
def get_new_messages_for_as_txn(txn):
sql = (
"SELECT stream_id, message_json, device_id, user_id FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (last_stream_id, current_stream_id, limit))
messages = []
for row in txn:
# Each row is (stream_id, message_json, device_id, user_id)
stream_pos = row[0]
if service.is_interested_in_user(row[3]):
msg = db_to_json(row[1])
msg["recipient"] = {
"device_id": row[2],
"user_id": row[3],
}
messages.append(msg)
if len(messages) < limit:
stream_pos = current_stream_id
return messages, stream_pos
return await self.db_pool.runInteraction(
"get_new_messages_for_as", get_new_messages_for_as_txn
)
async def get_new_messages_for_device(
self,
user_id: str,
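
get_new_messages_for_as hands back the decoded to-device messages, each annotated with its recipient, together with the stream position reached. Roughly (hypothetical values):

async def fetch_as_inbox(store, service):
    messages, stream_pos = await store.get_new_messages_for_as(
        service, last_stream_id=0, current_stream_id=50
    )
    # messages == [
    #     {
    #         ...,  # whatever the sender put in message_json
    #         "recipient": {"user_id": "@_bridge_alice:example.org", "device_id": "ABCDEFGH"},
    #     },
    # ]
    # stream_pos == 50 when fewer than `limit` rows were returned
    return messages, stream_pos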

View File

@@ -19,6 +19,7 @@ import logging
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import Codes, StoreError
from synapse.appservice import ApplicationService
from synapse.logging.opentracing import (
get_active_span_text_map,
set_tag,
@@ -525,6 +526,31 @@ class DeviceWorkerStore(SQLBaseStore):
"get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
)
async def get_device_changes_for_as(
self,
service: ApplicationService,
last_stream_id: int,
current_stream_id: int,
limit: int = 100,
) -> List[str]:
def get_device_changes_for_as_txn(txn):
sql = (
"SELECT DISTINCT user_id FROM device_lists_stream"
" WHERE ? < stream_id AND stream_id <= ?"
" LIMIT ?"
)
txn.execute(sql, (last_stream_id, current_stream_id, limit))
# The transaction function runs synchronously, so filter with the
# regex-based interest check rather than an async room-membership lookup
return [row[0] for row in txn if service.is_interested_in_user(row[0])]
return await self.db_pool.runInteraction(
"get_device_changes_for_as", get_device_changes_for_as_txn
)
async def get_users_whose_signatures_changed(
self, user_id: str, from_key: int
) -> Set[str]:

View File

@@ -123,6 +123,15 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
for row in rows
}
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
) -> dict:
results = await self._get_linearized_receipts_for_all_rooms(
to_key, from_key=from_key
)
return results
async def get_linearized_receipts_for_rooms(
self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
) -> List[dict]:
@@ -274,6 +283,47 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
}
return results
@cached(num_args=2,)
async def _get_linearized_receipts_for_all_rooms(self, to_key, from_key=None):
def f(txn):
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
"""
txn.execute(sql, [from_key, to_key])
else:
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id <= ?
"""
txn.execute(sql, [to_key])
return self.db_pool.cursor_to_dict(txn)
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_all_rooms", f
)
results = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
row["room_id"],
{"type": "m.receipt", "room_id": row["room_id"], "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
event_entry = room_event["content"].setdefault(row["event_id"], {})
receipt_type = event_entry.setdefault(row["receipt_type"], {})
receipt_type[row["user_id"]] = db_to_json(row["data"])
return results
async def get_users_sent_receipts_between(
self, last_id: int, current_id: int
) -> List[str]:

View File

@@ -0,0 +1,25 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Add per-appservice stream position columns so that each application service
 * can track how far through the device list and device message streams it has
 * been sent updates.
 */
ALTER TABLE application_services_state
ADD COLUMN device_list_stream_id INT;
ALTER TABLE application_services_state
ADD COLUMN device_message_stream_id INT;
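
The two columns follow the "%s_stream_id" % type convention used by get_type_stream_id_for_appservice and set_type_stream_id_for_appservice above, so they would be addressed as sketched below; note that the "device_message" type is implied by the column name rather than used by the handler in this revision:

async def read_as_device_positions(store, service):
    device_list_pos = await store.get_type_stream_id_for_appservice(service, "device_list")
    device_msg_pos = await store.get_type_stream_id_for_appservice(service, "device_message")
    return device_list_pos, device_msg_pos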