1
0

Compare commits

..

9 Commits

Author SHA1 Message Date
Mark Haines
3a676b8ee3 More merging 2016-04-21 16:28:05 +01:00
Mark Haines
0d5622b088 Merge branch 'markjh/slave_event_push_actions' into markjh/split_pusher 2016-04-21 16:24:38 +01:00
Mark Haines
712030aeef Merge branch 'develop' into markjh/split_pusher 2016-04-21 16:21:49 +01:00
Mark Haines
0b282d33af Add an HTTP API for removing rejected pushers.
When a push is rejected by the push gateway then synapse needs to
remove the pusher from the database. However we probably don't want
to do that directly from the slave, so we add an HTTP API to synapse
to remove the pusher from the database.
2016-04-19 14:43:47 +01:00
Mark Haines
03c8df54f0 Invalidate the receipt cache correctly 2016-04-14 17:25:27 +01:00
Mark Haines
c214d3e36e Merge branch 'develop' into markjh/split_pusher 2016-04-14 17:00:40 +01:00
Mark Haines
1c1b2de975 Poke the slaved pushers on new receipts 2016-04-14 16:59:56 +01:00
Mark Haines
f41b1a8723 Make push sort of work 2016-04-14 13:30:57 +01:00
Mark Haines
1209d3174e Optionally split out the pusher into a separate process 2016-04-14 11:20:48 +01:00
137 changed files with 1973 additions and 6503 deletions

View File

@@ -11,7 +11,6 @@ recursive-include synapse/storage/schema *.sql
recursive-include synapse/storage/schema *.py
recursive-include docs *
recursive-include res *
recursive-include scripts *
recursive-include scripts-dev *
recursive-include tests *.py

View File

@@ -32,4 +32,5 @@ The format of the AS configuration file is as follows:
See the spec_ for further details on how application services work.
.. _spec: https://matrix.org/docs/spec/application_service/unstable.html
.. _spec: https://github.com/matrix-org/matrix-doc/blob/master/specification/25_application_service_api.rst#application-service-api

View File

@@ -1,10 +0,0 @@
What do I do about "Unexpected logging context" debug log-lines everywhere?
<Mjark> The logging context lives in thread local storage
<Mjark> Sometimes it gets out of sync with what it should actually be, usually because something scheduled something to run on the reactor without preserving the logging context.
<Matthew> what is the impact of it getting out of sync? and how and when should we preserve log context?
<Mjark> The impact is that some of the CPU and database metrics will be under-reported, and some log lines will be mis-attributed.
<Mjark> It should happen auto-magically in all the APIs that do IO or otherwise defer to the reactor.
<Erik> Mjark: the other place is if we branch, e.g. using defer.gatherResults
Unanswered: how and when should we preserve log context?

View File

@@ -1,7 +0,0 @@
.header {
border-bottom: 4px solid #e4f7ed ! important;
}
.notif_link a, .footer a {
color: #76CFA6 ! important;
}

View File

@@ -1,156 +0,0 @@
body {
margin: 0px;
}
pre, code {
word-break: break-word;
white-space: pre-wrap;
}
#page {
font-family: 'Open Sans', Helvetica, Arial, Sans-Serif;
font-color: #454545;
font-size: 12pt;
width: 100%;
padding: 20px;
}
#inner {
width: 640px;
}
.header {
width: 100%;
height: 87px;
color: #454545;
border-bottom: 4px solid #e5e5e5;
}
.logo {
text-align: right;
margin-left: 20px;
}
.salutation {
padding-top: 10px;
font-weight: bold;
}
.summarytext {
}
.room {
width: 100%;
color: #454545;
border-bottom: 1px solid #e5e5e5;
}
.room_header td {
padding-top: 38px;
padding-bottom: 10px;
border-bottom: 1px solid #e5e5e5;
}
.room_name {
vertical-align: middle;
font-size: 18px;
font-weight: bold;
}
.room_header h2 {
margin-top: 0px;
margin-left: 75px;
font-size: 20px;
}
.room_avatar {
width: 56px;
line-height: 0px;
text-align: center;
vertical-align: middle;
}
.room_avatar img {
width: 48px;
height: 48px;
object-fit: cover;
border-radius: 24px;
}
.notif {
border-bottom: 1px solid #e5e5e5;
margin-top: 16px;
padding-bottom: 16px;
}
.historical_message .sender_avatar {
opacity: 0.3;
}
/* spell out opacity and historical_message class names for Outlook aka Word */
.historical_message .sender_name {
color: #e3e3e3;
}
.historical_message .message_time {
color: #e3e3e3;
}
.historical_message .message_body {
color: #c7c7c7;
}
.historical_message td,
.message td {
padding-top: 10px;
}
.sender_avatar {
width: 56px;
text-align: center;
vertical-align: top;
}
.sender_avatar img {
margin-top: -2px;
width: 32px;
height: 32px;
border-radius: 16px;
}
.sender_name {
display: inline;
font-size: 13px;
color: #a2a2a2;
}
.message_time {
text-align: right;
width: 100px;
font-size: 11px;
color: #a2a2a2;
}
.message_body {
}
.notif_link td {
padding-top: 10px;
padding-bottom: 10px;
font-weight: bold;
}
.notif_link a, .footer a {
color: #454545;
text-decoration: none;
}
.debug {
font-size: 10px;
color: #888;
}
.footer {
margin-top: 20px;
text-align: center;
}

View File

@@ -1,45 +0,0 @@
{% for message in notif.messages %}
<tr class="{{ "historical_message" if message.is_historical else "message" }}">
<td class="sender_avatar">
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
{% if message.sender_avatar_url %}
<img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" />
{% else %}
{% if message.sender_hash % 3 == 0 %}
<img class="sender_avatar" src="https://vector.im/beta/img/76cfa6.png" />
{% elif message.sender_hash % 3 == 1 %}
<img class="sender_avatar" src="https://vector.im/beta/img/50e2c2.png" />
{% else %}
<img class="sender_avatar" src="https://vector.im/beta/img/f4c371.png" />
{% endif %}
{% endif %}
{% endif %}
</td>
<td class="message_contents">
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
<div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
{% endif %}
<div class="message_body">
{% if message.msgtype == "m.text" %}
{{ message.body_text_html }}
{% elif message.msgtype == "m.emote" %}
{{ message.body_text_html }}
{% elif message.msgtype == "m.notice" %}
{{ message.body_text_html }}
{% elif message.msgtype == "m.image" %}
<img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
{% elif message.msgtype == "m.file" %}
<span class="filename">{{ message.body_text_plain }}</span>
{% endif %}
</div>
</td>
<td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td>
</tr>
{% endfor %}
<tr class="notif_link">
<td></td>
<td>
<a href="{{ notif.link }}">View {{ room.title }}</a>
</td>
<td></td>
</tr>

View File

@@ -1,16 +0,0 @@
{% for message in notif.messages %}
{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
{% if message.msgtype == "m.text" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.emote" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.notice" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.image" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.file" %}
{{ message.body_text_plain }}
{% endif %}
{% endfor %}
View {{ room.title }} at {{ notif.link }}

View File

@@ -1,53 +0,0 @@
<!doctype html>
<html lang="en">
<head>
<style type="text/css">
{% include 'mail.css' without context %}
{% include "mail-%s.css" % app_name ignore missing without context %}
</style>
</head>
<body>
<table id="page">
<tr>
<td> </td>
<td id="inner">
<table class="header">
<tr>
<td>
<div class="salutation">Hi {{ user_display_name }},</div>
<div class="summarytext">{{ summary_text }}</div>
</td>
<td class="logo">
{% if app_name == "Vector" %}
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
{% else %}
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
{% endif %}
</td>
</tr>
</table>
{% for room in rooms %}
{% include 'room.html' with context %}
{% endfor %}
<div class="footer">
<a href="{{ unsubscribe_link }}">Unsubscribe</a>
<br/>
<br/>
<div class="debug">
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
an event was received at {{ reason.received_at|format_ts("%c") }}
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} (delay_before_mail_ms) mins ago,
{% if reason.last_sent_ts %}
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
{% else %}
and we don't have a last time we sent a mail for this room.
{% endif %}
</div>
</div>
</td>
<td> </td>
</tr>
</table>
</body>
</html>

View File

@@ -1,10 +0,0 @@
Hi {{ user_display_name }},
{{ summary_text }}
{% for room in rooms %}
{% include 'room.txt' with context %}
{% endfor %}
You can disable these notifications at {{ unsubscribe_link }}

View File

@@ -1,33 +0,0 @@
<table class="room">
<tr class="room_header">
<td class="room_avatar">
{% if room.avatar_url %}
<img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" />
{% else %}
{% if room.hash % 3 == 0 %}
<img alt="" src="https://vector.im/beta/img/76cfa6.png" />
{% elif room.hash % 3 == 1 %}
<img alt="" src="https://vector.im/beta/img/50e2c2.png" />
{% else %}
<img alt="" src="https://vector.im/beta/img/f4c371.png" />
{% endif %}
{% endif %}
</td>
<td class="room_name" colspan="2">
{{ room.title }}
</td>
</tr>
{% if room.invite %}
<tr>
<td></td>
<td>
<a href="{{ room.link }}">Join the conversation.</a>
</td>
<td></td>
</tr>
{% else %}
{% for notif in room.notifs %}
{% include 'notif.html' with context %}
{% endfor %}
{% endif %}
</table>

View File

@@ -1,9 +0,0 @@
{{ room.title }}
{% if room.invite %}
You've been invited, join at {{ room.link }}
{% else %}
{% for notif in room.notifs %}
{% include 'notif.txt' with context %}
{% endfor %}
{% endif %}

View File

@@ -214,10 +214,6 @@ class Porter(object):
self.progress.add_table(table, postgres_size, table_size)
if table == "event_search":
yield self.handle_search_table(postgres_size, table_size, next_chunk)
return
select = (
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
% (table,)
@@ -236,19 +232,51 @@ class Porter(object):
if rows:
next_chunk = rows[-1][0] + 1
self._convert_rows(table, headers, rows)
if table == "event_search":
# We have to treat event_search differently since it has a
# different structure in the two different databases.
def insert(txn):
sql = (
"INSERT INTO event_search (event_id, room_id, key, sender, vector)"
" VALUES (?,?,?,?,to_tsvector('english', ?))"
)
def insert(txn):
self.postgres_store.insert_many_txn(
txn, table, headers[1:], rows
)
rows_dict = [
dict(zip(headers, row))
for row in rows
]
self.postgres_store._simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": table},
updatevalues={"rowid": next_chunk},
)
txn.executemany(sql, [
(
row["event_id"],
row["room_id"],
row["key"],
row["sender"],
row["value"],
)
for row in rows_dict
])
self.postgres_store._simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": table},
updatevalues={"rowid": next_chunk},
)
else:
self._convert_rows(table, headers, rows)
def insert(txn):
self.postgres_store.insert_many_txn(
txn, table, headers[1:], rows
)
self.postgres_store._simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": table},
updatevalues={"rowid": next_chunk},
)
yield self.postgres_store.execute(insert)
@@ -258,73 +286,6 @@ class Porter(object):
else:
return
@defer.inlineCallbacks
def handle_search_table(self, postgres_size, table_size, next_chunk):
select = (
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
" FROM event_search as es"
" INNER JOIN events AS e USING (event_id, room_id)"
" WHERE es.rowid >= ?"
" ORDER BY es.rowid LIMIT ?"
)
while True:
def r(txn):
txn.execute(select, (next_chunk, self.batch_size,))
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
return headers, rows
headers, rows = yield self.sqlite_store.runInteraction("select", r)
if rows:
next_chunk = rows[-1][0] + 1
# We have to treat event_search differently since it has a
# different structure in the two different databases.
def insert(txn):
sql = (
"INSERT INTO event_search (event_id, room_id, key,"
" sender, vector, origin_server_ts, stream_ordering)"
" VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
)
rows_dict = [
dict(zip(headers, row))
for row in rows
]
txn.executemany(sql, [
(
row["event_id"],
row["room_id"],
row["key"],
row["sender"],
row["value"],
row["origin_server_ts"],
row["stream_ordering"],
)
for row in rows_dict
])
self.postgres_store._simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": "event_search"},
updatevalues={"rowid": next_chunk},
)
yield self.postgres_store.execute(insert)
postgres_size += len(rows)
self.progress.update("event_search", postgres_size)
else:
return
def setup_db(self, db_config, database_engine):
db_conn = database_engine.module.connect(
**{

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains classes for authenticating the user."""
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
@@ -21,7 +22,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, UserID, get_domain_from_id
from synapse.types import Requester, RoomID, UserID, EventID
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.metrics import Measure
@@ -41,20 +42,13 @@ AuthEventTypes = (
class Auth(object):
"""
FIXME: This class contains a mix of functions for authenticating users
of our client-server API and authenticating events added to room graphs.
"""
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
# Docs for these currently lives at
# https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
# In addition, we have type == delete_pusher which grants access only to
# delete pushers.
self._KNOWN_CAVEAT_PREFIXES = set([
"gen = ",
"guest = ",
@@ -97,8 +91,8 @@ class Auth(object):
"Room %r does not exist" % (event.room_id,)
)
creating_domain = get_domain_from_id(event.room_id)
originating_domain = get_domain_from_id(event.sender)
creating_domain = RoomID.from_string(event.room_id).domain
originating_domain = UserID.from_string(event.sender).domain
if creating_domain != originating_domain:
if not self.can_federate(event, auth_events):
raise AuthError(
@@ -126,24 +120,6 @@ class Auth(object):
return allowed
self.check_event_sender_in_room(event, auth_events)
# Special case to allow m.room.third_party_invite events wherever
# a user is allowed to issue invites. Fixes
# https://github.com/vector-im/vector-web/issues/1208 hopefully
if event.type == EventTypes.ThirdPartyInvite:
user_level = self._get_user_power_level(event.user_id, auth_events)
invite_level = self._get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, (
"You cannot issue a third party invite for %s." %
(event.content.display_name,)
)
)
else:
return True
self._can_send_event(event, auth_events)
if event.type == EventTypes.PowerLevels:
@@ -243,7 +219,7 @@ class Auth(object):
for event in curr_state.values():
if event.type == EventTypes.Member:
try:
if get_domain_from_id(event.state_key) != host:
if UserID.from_string(event.state_key).domain != host:
continue
except:
logger.warn("state_key not user_id: %s", event.state_key)
@@ -290,8 +266,8 @@ class Auth(object):
target_user_id = event.state_key
creating_domain = get_domain_from_id(event.room_id)
target_domain = get_domain_from_id(target_user_id)
creating_domain = RoomID.from_string(event.room_id).domain
target_domain = UserID.from_string(target_user_id).domain
if creating_domain != target_domain:
if not self.can_federate(event, auth_events):
raise AuthError(
@@ -531,7 +507,7 @@ class Auth(object):
return default
@defer.inlineCallbacks
def get_user_by_req(self, request, allow_guest=False, rights="access"):
def get_user_by_req(self, request, allow_guest=False):
""" Get a registered user's ID.
Args:
@@ -553,7 +529,7 @@ class Auth(object):
)
access_token = request.args["access_token"][0]
user_info = yield self.get_user_by_access_token(access_token, rights)
user_info = yield self.get_user_by_access_token(access_token)
user = user_info["user"]
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
@@ -614,7 +590,7 @@ class Auth(object):
defer.returnValue(user_id)
@defer.inlineCallbacks
def get_user_by_access_token(self, token, rights="access"):
def get_user_by_access_token(self, token):
""" Get a registered user's ID.
Args:
@@ -625,7 +601,7 @@ class Auth(object):
AuthError if no user by that token exists or the token is invalid.
"""
try:
ret = yield self.get_user_from_macaroon(token, rights)
ret = yield self.get_user_from_macaroon(token)
except AuthError:
# TODO(daniel): Remove this fallback when all existing access tokens
# have been re-issued as macaroons.
@@ -633,11 +609,10 @@ class Auth(object):
defer.returnValue(ret)
@defer.inlineCallbacks
def get_user_from_macaroon(self, macaroon_str, rights="access"):
def get_user_from_macaroon(self, macaroon_str):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self.validate_macaroon(macaroon, rights, self.hs.config.expire_access_token)
self.validate_macaroon(macaroon, "access", False)
user_prefix = "user_id = "
user = None
@@ -660,13 +635,6 @@ class Auth(object):
"is_guest": True,
"token_id": None,
}
elif rights == "delete_pusher":
# We don't store these tokens in the database
ret = {
"user": user,
"is_guest": False,
"token_id": None,
}
else:
# This codepath exists so that we can actually return a
# token ID, because we use token IDs in place of device
@@ -698,8 +666,7 @@ class Auth(object):
Args:
macaroon(pymacaroons.Macaroon): The macaroon to validate
type_string(str): The kind of token required (e.g. "access", "refresh",
"delete_pusher")
type_string(str): The kind of token this is (e.g. "access", "refresh")
verify_expiry(bool): Whether to verify whether the macaroon has expired.
This should really always be True, but no clients currently implement
token refresh, so we can't enforce expiry yet.
@@ -922,8 +889,8 @@ class Auth(object):
if user_level >= redact_level:
return False
redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = get_domain_from_id(event.redacts)
redacter_domain = EventID.from_string(event.event_id).domain
redactee_domain = EventID.from_string(event.redacts).domain
if redacter_domain == redactee_domain:
return True

View File

@@ -15,8 +15,6 @@
from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID
from twisted.internet import defer
import ujson as json
@@ -26,10 +24,10 @@ class Filtering(object):
super(Filtering, self).__init__()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id):
result = yield self.store.get_user_filter(user_localpart, filter_id)
defer.returnValue(FilterCollection(result))
result = self.store.get_user_filter(user_localpart, filter_id)
result.addCallback(FilterCollection)
return result
def add_user_filter(self, user_localpart, user_filter):
self.check_valid_filter(user_filter)

View File

@@ -16,9 +16,12 @@
import synapse
import contextlib
import logging
import os
import re
import sys
import time
from synapse.config._base import ConfigError
from synapse.python_dependencies import (
@@ -32,11 +35,18 @@ from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_d
from synapse.server import HomeServer
from twisted.conch.manhole import ColoredManhole
from twisted.conch.insults import insults
from twisted.conch import manhole_ssh
from twisted.cred import checkers, portal
from twisted.internet import reactor, task, defer
from twisted.application import service
from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File
from twisted.web.server import GzipEncoderFactory
from twisted.web.server import Site, GzipEncoderFactory, Request
from synapse.http.server import RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
@@ -56,10 +66,6 @@ from synapse.federation.transport.server import TransportLayerServer
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.http.site import SynapseSite
from synapse import events
@@ -68,6 +74,9 @@ from daemonize import Daemonize
logger = logging.getLogger("synapse.app.homeserver")
ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
@@ -165,12 +174,7 @@ class SynapseHomeServer(HomeServer):
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationResource(self)
if WEB_CLIENT_PREFIX in resources:
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
else:
root_resource = Resource()
root_resource = create_resource_tree(resources, root_resource)
root_resource = create_resource_tree(resources)
if tls:
reactor.listenSSL(
port,
@@ -203,13 +207,24 @@ class SynapseHomeServer(HomeServer):
if listener["type"] == "http":
self._listener_http(config, listener)
elif listener["type"] == "manhole":
checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
matrix="rabbithole"
)
rlm = manhole_ssh.TerminalRealm()
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
ColoredManhole,
{
"__name__": "__console__",
"hs": self,
}
)
f = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
f,
interface=listener.get("bind_address", '127.0.0.1')
)
else:
@@ -356,6 +371,210 @@ class SynapseService(service.Service):
return self._port.stopListening()
class SynapseRequest(Request):
def __init__(self, site, *args, **kw):
Request.__init__(self, *args, **kw)
self.site = site
self.authenticated_entity = None
self.start_time = 0
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
self.__class__.__name__,
id(self),
self.method,
self.get_redacted_uri(),
self.clientproto,
self.site.site_tag,
)
def get_redacted_uri(self):
return ACCESS_TOKEN_RE.sub(
r'\1<redacted>\3',
self.uri
)
def get_user_agent(self):
return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
def started_processing(self):
self.site.access_logger.info(
"%s - %s - Received request: %s %s",
self.getClientIP(),
self.site.site_tag,
self.method,
self.get_redacted_uri()
)
self.start_time = int(time.time() * 1000)
def finished_processing(self):
try:
context = LoggingContext.current_context()
ru_utime, ru_stime = context.get_resource_usage()
db_txn_count = context.db_txn_count
db_txn_duration = context.db_txn_duration
except:
ru_utime, ru_stime = (0, 0)
db_txn_count, db_txn_duration = (0, 0)
self.site.access_logger.info(
"%s - %s - {%s}"
" Processed request: %dms (%dms, %dms) (%dms/%d)"
" %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(),
self.site.site_tag,
self.authenticated_entity,
int(time.time() * 1000) - self.start_time,
int(ru_utime * 1000),
int(ru_stime * 1000),
int(db_txn_duration * 1000),
int(db_txn_count),
self.sentLength,
self.code,
self.method,
self.get_redacted_uri(),
self.clientproto,
self.get_user_agent(),
)
@contextlib.contextmanager
def processing(self):
self.started_processing()
yield
self.finished_processing()
class XForwardedForRequest(SynapseRequest):
def __init__(self, *args, **kw):
SynapseRequest.__init__(self, *args, **kw)
"""
Add a layer on top of another request that only uses the value of an
X-Forwarded-For header as the result of C{getClientIP}.
"""
def getClientIP(self):
"""
@return: The client address (the first address) in the value of the
I{X-Forwarded-For header}. If the header is not present, return
C{b"-"}.
"""
return self.requestHeaders.getRawHeaders(
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
class SynapseRequestFactory(object):
def __init__(self, site, x_forwarded_for):
self.site = site
self.x_forwarded_for = x_forwarded_for
def __call__(self, *args, **kwargs):
if self.x_forwarded_for:
return XForwardedForRequest(self.site, *args, **kwargs)
else:
return SynapseRequest(self.site, *args, **kwargs)
class SynapseSite(Site):
"""
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
Site.__init__(self, resource, *args, **kwargs)
self.site_tag = site_tag
proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name)
def log(self, request):
pass
def create_resource_tree(desired_tree, redirect_root_to_web_client=True):
"""Create the resource tree for this Home Server.
This in unduly complicated because Twisted does not support putting
child resources more than 1 level deep at a time.
Args:
web_client (bool): True to enable the web client.
redirect_root_to_web_client (bool): True to redirect '/' to the
location of the web client. This does nothing if web_client is not
True.
"""
if redirect_root_to_web_client and WEB_CLIENT_PREFIX in desired_tree:
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
else:
root_resource = Resource()
# ideally we'd just use getChild and putChild but getChild doesn't work
# unless you give it a Request object IN ADDITION to the name :/ So
# instead, we'll store a copy of this mapping so we can actually add
# extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {}
for full_path, res in desired_tree.items():
logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource
for path_seg in full_path.split('/')[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
child_resource = Resource()
last_resource.putChild(path_seg, child_resource)
res_id = _resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource
last_resource = child_resource
else:
# we have an existing Resource, use that instead.
res_id = _resource_id(last_resource, path_seg)
last_resource = resource_mappings[res_id]
# ===========================
# now attach the actual desired resource
last_path_seg = full_path.split('/')[-1]
# if there is already a resource here, thieve its children and
# replace it
res_id = _resource_id(last_resource, last_path_seg)
if res_id in resource_mappings:
# there is a dummy resource at this path already, which needs
# to be replaced with the desired resource.
existing_dummy_resource = resource_mappings[res_id]
for child_name in existing_dummy_resource.listNames():
child_res_id = _resource_id(
existing_dummy_resource, child_name
)
child_resource = resource_mappings[child_res_id]
# steal the children
res.putChild(child_name, child_resource)
# finally, insert the desired resource in the right place
last_resource.putChild(last_path_seg, res)
res_id = _resource_id(last_resource, last_path_seg)
resource_mappings[res_id] = res
return root_resource
def _resource_id(resource, path_seg):
"""Construct an arbitrary resource ID so you can retrieve the mapping
later.
If you want to represent resource A putChild resource B with path C,
the mapping should looks like _resource_id(A,C) = B.
Args:
resource (Resource): The *parent* Resourceb
path_seg (str): The name of the child Resource to be attached.
Returns:
str: A unique string which can be a key to the child Resource.
"""
return "%s-%s" % (resource, path_seg)
def run(hs):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:

View File

@@ -17,31 +17,19 @@
import synapse
from synapse.server import HomeServer
from synapse.util.versionstring import get_version_string
from synapse.config._base import ConfigError
from synapse.config.database import DatabaseConfig
from synapse.config.logger import LoggingConfig
from synapse.config.emailconfig import EmailConfig
from synapse.config.key import KeyConfig
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.storage.roommember import RoomMemberStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.storage.engines import create_engine
from synapse.storage import DataStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.util.logcontext import (LoggingContext, preserve_fn)
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
@@ -53,114 +41,30 @@ class SlaveConfig(DatabaseConfig):
def read_config(self, config):
self.replication_url = config["replication_url"]
self.server_name = config["server_name"]
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
"use_insecure_ssl_client_just_for_testing_do_not_use", False
)
self.use_insecure_ssl_client_just_for_testing_do_not_use = True
self.user_agent_suffix = None
self.start_pushers = True
self.listeners = config["listeners"]
self.soft_file_limit = config.get("soft_file_limit")
self.daemonize = config.get("daemonize")
self.pid_file = self.abspath(config.get("pid_file"))
self.public_baseurl = config["public_baseurl"]
# some things used by the auth handler but not actually used in the
# pusher codebase
self.bcrypt_rounds = None
self.ldap_enabled = None
self.ldap_server = None
self.ldap_port = None
self.ldap_tls = None
self.ldap_search_base = None
self.ldap_search_property = None
self.ldap_email_property = None
self.ldap_full_name_property = None
# We would otherwise try to use the registration shared secret as the
# macaroon shared secret if there was no macaroon_shared_secret, but
# that means pulling in RegistrationConfig too. We don't need to be
# backwards compaitible in the pusher codebase so just make people set
# macaroon_shared_secret. We set this to None to prevent it referencing
# an undefined key.
self.registration_shared_secret = None
def default_config(self, server_name, **kwargs):
pid_file = self.abspath("pusher.pid")
def default_config(self, **kwargs):
return """\
# Slave configuration
# The replication listener on the synapse to talk to.
## Slave ##
#replication_url: https://localhost:{replication_port}/_synapse/replication
server_name: "%(server_name)s"
listeners: []
# Enable a ssh manhole listener on the pusher.
# - type: manhole
# port: {manhole_port}
# bind_address: 127.0.0.1
# Enable a metric listener on the pusher.
# - type: http
# port: {metrics_port}
# bind_address: 127.0.0.1
# resources:
# - names: ["metrics"]
# compress: False
report_stats: False
daemonize: False
pid_file: %(pid_file)s
""" % locals()
"""
class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig, KeyConfig):
class PusherSlaveConfig(SlaveConfig, LoggingConfig):
pass
class PusherSlaveStore(
SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore,
SlavedAccountDataStore
SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore
):
update_pusher_last_stream_ordering_and_success = (
DataStore.update_pusher_last_stream_ordering_and_success.__func__
)
update_pusher_failing_since = (
DataStore.update_pusher_failing_since.__func__
)
update_pusher_last_stream_ordering = (
DataStore.update_pusher_last_stream_ordering.__func__
)
get_throttle_params_by_room = (
DataStore.get_throttle_params_by_room.__func__
)
set_throttle_params = (
DataStore.set_throttle_params.__func__
)
get_time_of_last_push_action_before = (
DataStore.get_time_of_last_push_action_before.__func__
)
get_profile_displayname = (
DataStore.get_profile_displayname.__func__
)
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
class PusherServer(HomeServer):
@@ -194,53 +98,12 @@ class PusherServer(HomeServer):
}]
})
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse pusher now listening on port %d", port)
def start_listening(self):
for listener in self.config.listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.replication_url
pusher_pool = self.get_pusherpool()
clock = self.get_clock()
def stop_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
@@ -292,21 +155,11 @@ class PusherServer(HomeServer):
min_stream_id, max_stream_id, affected_room_ids
)
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
next_expire_broken_caches_ms = 0
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result)
poke_pushers(result)
except:
@@ -323,9 +176,6 @@ def setup(config_options):
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
if not config:
sys.exit(0)
config.setup_logging()
database_engine = create_engine(config.database_config)
@@ -339,9 +189,6 @@ def setup(config_options):
)
ps.setup()
ps.start_listening()
change_resource_limit(ps.config.soft_file_limit)
def start():
ps.replicate()
@@ -356,22 +203,4 @@ def setup(config_options):
if __name__ == '__main__':
with LoggingContext("main"):
ps = setup(sys.argv[1:])
if ps.config.daemonize:
def run():
with LoggingContext("run"):
change_resource_limit(ps.config.soft_file_limit)
reactor.run()
daemon = Daemonize(
app="synapse-pusher",
pid=ps.config.pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
reactor.run()
reactor.run()

View File

@@ -1,468 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.api.constants import EventTypes
from synapse.config._base import ConfigError
from synapse.config.database import DatabaseConfig
from synapse.config.logger import LoggingConfig
from synapse.config.appservice import AppServiceConfig
from synapse.events import FrozenEvent
from synapse.handlers.presence import PresenceHandler
from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState
from synapse.storage.roommember import RoomMemberStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.stringutils import random_string
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import contextlib
import ujson as json
logger = logging.getLogger("synapse.app.synchrotron")
class SynchrotronConfig(DatabaseConfig, LoggingConfig, AppServiceConfig):
def read_config(self, config):
self.replication_url = config["replication_url"]
self.server_name = config["server_name"]
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
"use_insecure_ssl_client_just_for_testing_do_not_use", False
)
self.user_agent_suffix = None
self.listeners = config["listeners"]
self.soft_file_limit = config.get("soft_file_limit")
self.daemonize = config.get("daemonize")
self.pid_file = self.abspath(config.get("pid_file"))
self.macaroon_secret_key = config["macaroon_secret_key"]
self.expire_access_token = config.get("expire_access_token", False)
def default_config(self, server_name, **kwargs):
pid_file = self.abspath("synchroton.pid")
return """\
# Slave configuration
# The replication listener on the synapse to talk to.
#replication_url: https://localhost:{replication_port}/_synapse/replication
server_name: "%(server_name)s"
listeners:
# Enable a /sync listener on the synchrontron
#- type: http
# port: {http_port}
# bind_address: ""
# Enable a ssh manhole listener on the synchrotron
# - type: manhole
# port: {manhole_port}
# bind_address: 127.0.0.1
# Enable a metric listener on the synchrotron
# - type: http
# port: {metrics_port}
# bind_address: 127.0.0.1
# resources:
# - names: ["metrics"]
# compress: False
report_stats: False
daemonize: False
pid_file: %(pid_file)s
""" % locals()
class SynchrotronSlavedStore(
SlavedPushRuleStore,
SlavedEventStore,
SlavedReceiptsStore,
SlavedAccountDataStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedFilteringStore,
SlavedPresenceStore,
):
def get_presence_list_accepted(self, user_localpart):
return ()
def insert_client_ip(self, user, access_token, ip, user_agent):
pass
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
class SynchrotronPresence(object):
def __init__(self, hs):
self.http_client = hs.get_simple_http_client()
self.store = hs.get_datastore()
self.user_to_num_current_syncs = {}
self.syncing_users_url = hs.config.replication_url + "/syncing_users"
self.clock = hs.get_clock()
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {
state.user_id: state
for state in active_presence
}
self.process_id = random_string(16)
logger.info("Presence process_id is %r", self.process_id)
def set_state(self, user, state):
# TODO Hows this supposed to work?
pass
get_states = PresenceHandler.get_states.__func__
current_state_for_users = PresenceHandler.current_state_for_users.__func__
@defer.inlineCallbacks
def user_syncing(self, user_id, affect_presence):
if affect_presence:
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1
# TODO: Send this less frequently.
# TODO: Make sure this doesn't race. Currently we can lose updates
# if two users come online in quick sucession and the second http
# to the master completes before the first.
# TODO: Don't block the sync request on this HTTP hit.
yield self._send_syncing_users()
def _end():
if affect_presence:
self.user_to_num_current_syncs[user_id] -= 1
@contextlib.contextmanager
def _user_syncing():
try:
yield
finally:
_end()
defer.returnValue(_user_syncing())
def _send_syncing_users(self):
return self.http_client.post_json_get_json(self.syncing_users_url, {
"process_id": self.process_id,
"syncing_users": [
user_id for user_id, count in self.user_to_num_current_syncs.items()
if count > 0
],
})
def process_replication(self, result):
stream = result.get("presence", {"rows": []})
for row in stream["rows"]:
(
position, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
) = row
self.user_to_current_state[user_id] = UserPresenceState(
user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
)
class SynchrotronTyping(object):
def __init__(self, hs):
self._latest_room_serial = 0
self._room_serials = {}
self._room_typing = {}
def stream_positions(self):
return {"typing": self._latest_room_serial}
def process_replication(self, result):
stream = result.get("typing")
if stream:
self._latest_room_serial = int(stream["position"])
for row in stream["rows"]:
position, room_id, typing_json = row
typing = json.loads(typing_json)
self._room_serials[room_id] = position
self._room_typing[room_id] = typing
class SynchrotronApplicationService(object):
def notify_interested_services(self, event):
pass
class SynchrotronServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse synchrotron now listening on port %d", port)
def start_listening(self):
for listener in self.config.listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.replication_url
clock = self.get_clock()
notifier = self.get_notifier()
presence_handler = self.get_presence_handler()
typing_handler = self.get_typing_handler()
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
def notify_from_stream(
result, stream_name, stream_key, room=None, user=None
):
stream = result.get(stream_name)
if stream:
position_index = stream["field_names"].index("position")
if room:
room_index = stream["field_names"].index(room)
if user:
user_index = stream["field_names"].index(user)
users = ()
rooms = ()
for row in stream["rows"]:
position = row[position_index]
if user:
users = (row[user_index],)
if room:
rooms = (row[room_index],)
notifier.on_new_event(
stream_key, position, users=users, rooms=rooms
)
def notify(result):
stream = result.get("events")
if stream:
max_position = stream["position"]
for row in stream["rows"]:
position = row[0]
internal = json.loads(row[1])
event_json = json.loads(row[2])
event = FrozenEvent(event_json, internal_metadata_dict=internal)
extra_users = ()
if event.type == EventTypes.Member:
extra_users = (event.state_key,)
notifier.on_new_room_event(
event, position, max_position, extra_users
)
notify_from_stream(
result, "push_rules", "push_rules_key", user="user_id"
)
notify_from_stream(
result, "user_account_data", "account_data_key", user="user_id"
)
notify_from_stream(
result, "room_account_data", "account_data_key", user="user_id"
)
notify_from_stream(
result, "tag_account_data", "account_data_key", user="user_id"
)
notify_from_stream(
result, "receipts", "receipt_key", room="room_id"
)
notify_from_stream(
result, "typing", "typing_key", room="room_id"
)
next_expire_broken_caches_ms = 0
while True:
try:
args = store.stream_positions()
args.update(typing_handler.stream_positions())
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
logger.error("FENRIS %r", result)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result)
typing_handler.process_replication(result)
presence_handler.process_replication(result)
notify(result)
except:
logger.exception("Error replicating from %r", replication_url)
sleep(5)
def build_presence_handler(self):
return SynchrotronPresence(self)
def build_typing_handler(self):
return SynchrotronTyping(self)
def setup(config_options):
try:
config = SynchrotronConfig.load_config(
"Synapse synchrotron", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
if not config:
sys.exit(0)
config.setup_logging()
database_engine = create_engine(config.database_config)
ss = SynchrotronServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string=get_version_string("Synapse", synapse),
database_engine=database_engine,
application_service_handler=SynchrotronApplicationService(),
)
ss.setup()
ss.start_listening()
change_resource_limit(ss.config.soft_file_limit)
def start():
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
return ss
if __name__ == '__main__':
with LoggingContext("main"):
ps = setup(sys.argv[1:])
if ps.config.daemonize:
def run():
with LoggingContext("run"):
change_resource_limit(ps.config.soft_file_limit)
reactor.run()
daemon = Daemonize(
app="synapse-pusher",
pid=ps.config.pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
reactor.run()

View File

@@ -66,10 +66,6 @@ def main():
config = yaml.load(open(configfile))
pidfile = config["pid_file"]
cache_factor = config.get("synctl_cache_factor", None)
if cache_factor:
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
action = sys.argv[1] if sys.argv[1:] else "usage"
if action == "start":

View File

@@ -56,22 +56,22 @@ import logging
logger = logging.getLogger(__name__)
class ApplicationServiceScheduler(object):
class AppServiceScheduler(object):
""" Public facing API for this module. Does the required DI to tie the
components together. This also serves as the "event_pool", which in this
case is a simple array.
"""
def __init__(self, hs):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.as_api = hs.get_application_service_api()
def __init__(self, clock, store, as_api):
self.clock = clock
self.store = store
self.as_api = as_api
def create_recoverer(service, callback):
return _Recoverer(self.clock, self.store, self.as_api, service, callback)
return _Recoverer(clock, store, as_api, service, callback)
self.txn_ctrl = _TransactionController(
self.clock, self.store, self.as_api, create_recoverer
clock, store, as_api, create_recoverer
)
self.queuer = _ServiceQueuer(self.txn_ctrl)

View File

@@ -12,16 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config, ConfigError
from synapse.appservice import ApplicationService
from synapse.types import UserID
import urllib
import yaml
import logging
logger = logging.getLogger(__name__)
from ._base import Config
class AppServiceConfig(Config):
@@ -34,99 +25,3 @@ class AppServiceConfig(Config):
# A list of application service config file to use
app_service_config_files: []
"""
def load_appservices(hostname, config_files):
"""Returns a list of Application Services from the config files."""
if not isinstance(config_files, list):
logger.warning(
"Expected %s to be a list of AS config files.", config_files
)
return []
# Dicts of value -> filename
seen_as_tokens = {}
seen_ids = {}
appservices = []
for config_file in config_files:
try:
with open(config_file, 'r') as f:
appservice = _load_appservice(
hostname, yaml.load(f), config_file
)
if appservice.id in seen_ids:
raise ConfigError(
"Cannot reuse ID across application services: "
"%s (files: %s, %s)" % (
appservice.id, config_file, seen_ids[appservice.id],
)
)
seen_ids[appservice.id] = config_file
if appservice.token in seen_as_tokens:
raise ConfigError(
"Cannot reuse as_token across application services: "
"%s (files: %s, %s)" % (
appservice.token,
config_file,
seen_as_tokens[appservice.token],
)
)
seen_as_tokens[appservice.token] = config_file
logger.info("Loaded application service: %s", appservice)
appservices.append(appservice)
except Exception as e:
logger.error("Failed to load appservice from '%s'", config_file)
logger.exception(e)
raise
return appservices
def _load_appservice(hostname, as_info, config_filename):
required_string_fields = [
"id", "url", "as_token", "hs_token", "sender_localpart"
]
for field in required_string_fields:
if not isinstance(as_info.get(field), basestring):
raise KeyError("Required string field: '%s' (%s)" % (
field, config_filename,
))
localpart = as_info["sender_localpart"]
if urllib.quote(localpart) != localpart:
raise ValueError(
"sender_localpart needs characters which are not URL encoded."
)
user = UserID(localpart, hostname)
user_id = user.to_string()
# namespace checks
if not isinstance(as_info.get("namespaces"), dict):
raise KeyError("Requires 'namespaces' object.")
for ns in ApplicationService.NS_LIST:
# specific namespaces are optional
if ns in as_info["namespaces"]:
# expect a list of dicts with exclusive and regex keys
for regex_obj in as_info["namespaces"][ns]:
if not isinstance(regex_obj, dict):
raise ValueError(
"Expected namespace entry in %s to be an object,"
" but got %s", ns, regex_obj
)
if not isinstance(regex_obj.get("regex"), basestring):
raise ValueError(
"Missing/bad type 'regex' key in %s", regex_obj
)
if not isinstance(regex_obj.get("exclusive"), bool):
raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj
)
return ApplicationService(
token=as_info["as_token"],
url=as_info["url"],
namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["id"],
)

View File

@@ -1,98 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file can't be called email.py because if it is, we cannot:
import email.utils
from ._base import Config
class EmailConfig(Config):
def read_config(self, config):
self.email_enable_notifs = False
email_config = config.get("email", {})
self.email_enable_notifs = email_config.get("enable_notifs", False)
if self.email_enable_notifs:
# make sure we can import the required deps
import jinja2
import bleach
# prevent unused warnings
jinja2
bleach
required = [
"smtp_host",
"smtp_port",
"notif_from",
"template_dir",
"notif_template_html",
"notif_template_text",
]
missing = []
for k in required:
if k not in email_config:
missing.append(k)
if (len(missing) > 0):
raise RuntimeError(
"email.enable_notifs is True but required keys are missing: %s" %
(", ".join(["email." + k for k in missing]),)
)
if config.get("public_baseurl") is None:
raise RuntimeError(
"email.enable_notifs is True but no public_baseurl is set"
)
self.email_smtp_host = email_config["smtp_host"]
self.email_smtp_port = email_config["smtp_port"]
self.email_notif_from = email_config["notif_from"]
self.email_template_dir = email_config["template_dir"]
self.email_notif_template_html = email_config["notif_template_html"]
self.email_notif_template_text = email_config["notif_template_text"]
self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True
)
if "app_name" in email_config:
self.email_app_name = email_config["app_name"]
else:
self.email_app_name = "Matrix"
# make sure it's valid
parsed = email.utils.parseaddr(self.email_notif_from)
if parsed[1] == '':
raise RuntimeError("Invalid notif_from address")
else:
self.email_enable_notifs = False
# Not much point setting defaults for the rest: it would be an
# error for them to be used.
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable sending emails for notification events
#email:
# enable_notifs: false
# smtp_host: "localhost"
# smtp_port: 25
# notif_from: Your Friendly Matrix Home Server <noreply@example.com>
# app_name: Matrix
# template_dir: res/templates
# notif_template_html: notif_mail.html
# notif_template_text: notif_mail.txt
# notif_for_new_users: True
"""

View File

@@ -31,14 +31,13 @@ from .cas import CasConfig
from .password import PasswordConfig
from .jwt import JWTConfig
from .ldap import LDAPConfig
from .emailconfig import EmailConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, LDAPConfig, PasswordConfig, EmailConfig,):
JWTConfig, LDAPConfig, PasswordConfig,):
pass

View File

@@ -13,16 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config, ConfigError
MISSING_JWT = (
"""Missing jwt library. This is required for jwt login.
Install by running:
pip install pyjwt
"""
)
from ._base import Config
class JWTConfig(Config):
@@ -32,12 +23,6 @@ class JWTConfig(Config):
self.jwt_enabled = jwt_config.get("enabled", False)
self.jwt_secret = jwt_config["secret"]
self.jwt_algorithm = jwt_config["algorithm"]
try:
import jwt
jwt # To stop unused lint.
except ImportError:
raise ConfigError(MISSING_JWT)
else:
self.jwt_enabled = False
self.jwt_secret = None
@@ -45,8 +30,6 @@ class JWTConfig(Config):
def default_config(self, **kwargs):
return """\
# The JWT needs to contain a globally unique "sub" (subject) claim.
#
# jwt_config:
# enabled: true
# secret: "a secret"

View File

@@ -57,8 +57,6 @@ class KeyConfig(Config):
seed = self.signing_key[0].seed
self.macaroon_secret_key = hashlib.sha256(seed)
self.expire_access_token = config.get("expire_access_token", False)
def default_config(self, config_dir_path, server_name, is_generating_file=False,
**kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
@@ -71,9 +69,6 @@ class KeyConfig(Config):
return """\
macaroon_secret_key: "%(macaroon_secret_key)s"
# Used to enable access token expiration.
expire_access_token: False
## Signing Keys ##
# Path to the signing key to sign messages with

View File

@@ -32,7 +32,6 @@ class RegistrationConfig(Config):
)
self.registration_shared_secret = config.get("registration_shared_secret")
self.user_creation_max_duration = int(config["user_creation_max_duration"])
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
@@ -55,11 +54,6 @@ class RegistrationConfig(Config):
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
# Sets the expiry for the short term user creation in
# milliseconds. For instance the bellow duration is two weeks
# in milliseconds.
user_creation_max_duration: 1209600000
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12.

View File

@@ -100,13 +100,8 @@ class ContentRepositoryConfig(Config):
"to work"
)
self.url_preview_ip_range_whitelist = IPSet(
config.get("url_preview_ip_range_whitelist", ())
)
self.url_preview_url_blacklist = config.get(
"url_preview_url_blacklist", ()
)
if "url_preview_url_blacklist" in config:
self.url_preview_url_blacklist = config["url_preview_url_blacklist"]
def default_config(self, **kwargs):
media_store = self.default_path("media_store")
@@ -167,15 +162,6 @@ class ContentRepositoryConfig(Config):
# - '10.0.0.0/8'
# - '172.16.0.0/12'
# - '192.168.0.0/16'
#
# List of IP address CIDR ranges that the URL preview spider is allowed
# to access even if they are specified in url_preview_ip_range_blacklist.
# This is useful for specifying exceptions to wide-ranging blacklisted
# target IP ranges - e.g. for enabling URL previews for a specific private
# website only visible in your network.
#
# url_preview_ip_range_whitelist:
# - '192.168.1.1'
# Optional list of URL matches that the URL preview spider is
# denied from accessing. You should use url_preview_ip_range_blacklist

View File

@@ -28,12 +28,6 @@ class ServerConfig(Config):
self.print_pidfile = config.get("print_pidfile")
self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.public_baseurl = config.get("public_baseurl")
self.secondary_directory_servers = config.get("secondary_directory_servers", [])
if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/':
self.public_baseurl += '/'
self.start_pushers = config.get("start_pushers", True)
self.listeners = config.get("listeners", [])
@@ -149,23 +143,11 @@ class ServerConfig(Config):
# Whether to serve a web client from the HTTP/HTTPS root resource.
web_client: True
# The public-facing base URL for the client API (not including _matrix/...)
# public_baseurl: https://example.com:8448/
# Set the soft limit on the number of file descriptors synapse can use
# Zero is used to indicate synapse should set the soft limit to the
# hard limit.
soft_file_limit: 0
# A list of other Home Servers to fetch the public room directory from
# and include in the public room directory of this home server
# This is a temporary stopgap solution to populate new server with a
# list of rooms until there exists a good solution of a decentralized
# room directory.
# secondary_directory_servers:
# - matrix.org
# - vector.im
# List of ports that Synapse should listen on, their purpose and their
# configuration.
listeners:

View File

@@ -24,7 +24,6 @@ from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
)
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.events import FrozenEvent
@@ -551,25 +550,6 @@ class FederationClient(FederationBase):
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks
def get_public_rooms(self, destinations):
results_by_server = {}
@defer.inlineCallbacks
def _get_result(s):
if s == self.server_name:
defer.returnValue()
try:
result = yield self.transport_layer.get_public_rooms(s)
results_by_server[s] = result
except:
logger.exception("Error getting room list from server %r", s)
yield concurrently_execute(_get_result, destinations, 3)
defer.returnValue(results_by_server)
@defer.inlineCallbacks
def query_auth(self, destination, room_id, event_id, local_auth):
"""

View File

@@ -387,11 +387,6 @@ class FederationServer(FederationBase):
"events": [ev.get_pdu_json(time_now) for ev in missing_events],
})
@log_function
def on_openid_userinfo(self, token):
ts_now_ms = self._clock.time_msec()
return self.store.get_user_id_for_open_id_token(token, ts_now_ms)
@log_function
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id.

View File

@@ -20,7 +20,6 @@ from .persistence import TransactionActions
from .units import Transaction
from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.retryutils import (
@@ -200,8 +199,6 @@ class TransactionQueue(object):
@defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
yield run_on_reactor()
# list of (pending_pdu, deferred, order)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending

View File

@@ -224,18 +224,6 @@ class TransportLayerClient(object):
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def get_public_rooms(self, remote_server):
path = PREFIX + "/publicRooms"
response = yield self.client.get_json(
destination=remote_server,
path=path,
)
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):

View File

@@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
from synapse.http.servlet import parse_json_object_from_request, parse_string
from synapse.http.servlet import parse_json_object_from_request
from synapse.util.ratelimitutils import FederationRateLimiter
import functools
@@ -134,12 +134,10 @@ class Authenticator(object):
class BaseFederationServlet(object):
def __init__(self, handler, authenticator, ratelimiter, server_name,
room_list_handler):
def __init__(self, handler, authenticator, ratelimiter, server_name):
self.handler = handler
self.authenticator = authenticator
self.ratelimiter = ratelimiter
self.room_list_handler = room_list_handler
def _wrap(self, code):
authenticator = self.authenticator
@@ -325,7 +323,7 @@ class FederationSendLeaveServlet(BaseFederationServlet):
class FederationEventAuthServlet(BaseFederationServlet):
PATH = "/event_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
PATH = "/event_auth(?P<context>[^/]*)/(?P<event_id>[^/]*)"
def on_GET(self, origin, content, query, context, event_id):
return self.handler.on_event_auth(origin, context, event_id)
@@ -450,94 +448,6 @@ class On3pidBindServlet(BaseFederationServlet):
return code
class OpenIdUserInfo(BaseFederationServlet):
"""
Exchange a bearer token for information about a user.
The response format should be compatible with:
http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
GET /openid/userinfo?access_token=ABDEFGH HTTP/1.1
HTTP/1.1 200 OK
Content-Type: application/json
{
"sub": "@userpart:example.org",
}
"""
PATH = "/openid/userinfo"
@defer.inlineCallbacks
def on_GET(self, request):
token = parse_string(request, "access_token")
if token is None:
defer.returnValue((401, {
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
}))
return
user_id = yield self.handler.on_openid_userinfo(token)
if user_id is None:
defer.returnValue((401, {
"errcode": "M_UNKNOWN_TOKEN",
"error": "Access Token unknown or expired"
}))
defer.returnValue((200, {"sub": user_id}))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
class PublicRoomList(BaseFederationServlet):
"""
Fetch the public room list for this server.
This API returns information in the same format as /publicRooms on the
client API, but will only ever include local public rooms and hence is
intended for consumption by other home servers.
GET /publicRooms HTTP/1.1
HTTP/1.1 200 OK
Content-Type: application/json
{
"chunk": [
{
"aliases": [
"#test:localhost"
],
"guest_can_join": false,
"name": "test room",
"num_joined_members": 3,
"room_id": "!whkydVegtvatLfXmPN:localhost",
"world_readable": false
}
],
"end": "END",
"start": "START"
}
"""
PATH = "/publicRooms"
@defer.inlineCallbacks
def on_GET(self, request):
data = yield self.room_list_handler.get_local_public_room_list()
defer.returnValue((200, data))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
SERVLET_CLASSES = (
FederationSendServlet,
FederationPullServlet,
@@ -558,8 +468,6 @@ SERVLET_CLASSES = (
FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
OpenIdUserInfo,
PublicRoomList,
)
@@ -570,5 +478,4 @@ def register_servlets(hs, resource, authenticator, ratelimiter):
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
room_list_handler=hs.get_room_list_handler(),
).register(resource)

View File

@@ -13,17 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.appservice.scheduler import AppServiceScheduler
from synapse.appservice.api import ApplicationServiceApi
from .register import RegistrationHandler
from .room import (
RoomCreationHandler, RoomContextHandler,
RoomCreationHandler, RoomListHandler, RoomContextHandler,
)
from .room_member import RoomMemberHandler
from .message import MessageHandler
from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler
from .profile import ProfileHandler
from .presence import PresenceHandler
from .directory import DirectoryHandler
from .typing import TypingNotificationHandler
from .admin import AdminHandler
from .appservice import ApplicationServicesHandler
from .sync import SyncHandler
from .auth import AuthHandler
from .identity import IdentityHandler
from .receipts import ReceiptsHandler
from .search import SearchHandler
@@ -46,9 +53,22 @@ class Handlers(object):
self.event_handler = EventHandler(hs)
self.federation_handler = FederationHandler(hs)
self.profile_handler = ProfileHandler(hs)
self.presence_handler = PresenceHandler(hs)
self.room_list_handler = RoomListHandler(hs)
self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs)
self.receipts_handler = ReceiptsHandler(hs)
asapi = ApplicationServiceApi(hs)
self.appservice_handler = ApplicationServicesHandler(
hs, asapi, AppServiceScheduler(
clock=hs.get_clock(),
store=hs.get_datastore(),
as_api=asapi
)
)
self.sync_handler = SyncHandler(hs)
self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs)
self.room_context_handler = RoomContextHandler(hs)

View File

@@ -15,10 +15,13 @@
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.api.errors import LimitExceededError, SynapseError, AuthError
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, Requester
from synapse.types import UserID, RoomAlias, Requester
from synapse.push.action_generator import ActionGenerator
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
import logging
@@ -26,6 +29,23 @@ import logging
logger = logging.getLogger(__name__)
VISIBILITY_PRIORITY = (
"world_readable",
"shared",
"invited",
"joined",
)
MEMBERSHIP_PRIORITY = (
Membership.JOIN,
Membership.INVITE,
Membership.KNOCK,
Membership.LEAVE,
Membership.BAN,
)
class BaseHandler(object):
"""
Common base class for the event handlers.
@@ -45,10 +65,161 @@ class BaseHandler(object):
self.clock = hs.get_clock()
self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
def filter_events_for_clients(self, user_tuples, events, event_id_to_state):
""" Returns dict of user_id -> list of events that user is allowed to
see.
Args:
user_tuples (str, bool): (user id, is_peeking) for each user to be
checked. is_peeking should be true if:
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the
given events
events ([synapse.events.EventBase]): list of events to filter
"""
forgotten = yield defer.gatherResults([
self.store.who_forgot_in_room(
room_id,
)
for room_id in frozenset(e.room_id for e in events)
], consumeErrors=True)
# Set of membership event_ids that have been forgotten
event_id_forgotten = frozenset(
row["event_id"] for rows in forgotten for row in rows
)
def allowed(event, user_id, is_peeking):
"""
Args:
event (synapse.events.EventBase): event to check
user_id (str)
is_peeking (bool)
"""
state = event_id_to_state[event.event_id]
# get the room_visibility at the time of the event.
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility not in VISIBILITY_PRIORITY:
visibility = "shared"
# if it was world_readable, it's easy: everyone can read it
if visibility == "world_readable":
return True
# Always allow history visibility events on boundaries. This is done
# by setting the effective visibility to the least restrictive
# of the old vs new.
if event.type == EventTypes.RoomHistoryVisibility:
prev_content = event.unsigned.get("prev_content", {})
prev_visibility = prev_content.get("history_visibility", None)
if prev_visibility not in VISIBILITY_PRIORITY:
prev_visibility = "shared"
new_priority = VISIBILITY_PRIORITY.index(visibility)
old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
if old_priority < new_priority:
visibility = prev_visibility
# likewise, if the event is the user's own membership event, use
# the 'most joined' membership
membership = None
if event.type == EventTypes.Member and event.state_key == user_id:
membership = event.content.get("membership", None)
if membership not in MEMBERSHIP_PRIORITY:
membership = "leave"
prev_content = event.unsigned.get("prev_content", {})
prev_membership = prev_content.get("membership", None)
if prev_membership not in MEMBERSHIP_PRIORITY:
prev_membership = "leave"
new_priority = MEMBERSHIP_PRIORITY.index(membership)
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
if old_priority < new_priority:
membership = prev_membership
# otherwise, get the user's membership at the time of the event.
if membership is None:
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
if membership_event.event_id not in event_id_forgotten:
membership = membership_event.membership
# if the user was a member of the room at the time of the event,
# they can see it.
if membership == Membership.JOIN:
return True
if visibility == "joined":
# we weren't a member at the time of the event, so we can't
# see this event.
return False
elif visibility == "invited":
# user can also see the event if they were *invited* at the time
# of the event.
return membership == Membership.INVITE
else:
# visibility is shared: user can also see the event if they have
# become a member since the event
#
# XXX: if the user has subsequently joined and then left again,
# ideally we would share history up to the point they left. But
# we don't know when they left.
return not is_peeking
defer.returnValue({
user_id: [
event
for event in events
if allowed(event, user_id, is_peeking)
]
for user_id, is_peeking in user_tuples
})
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_peeking=False):
"""
Check which events a user is allowed to see
Args:
user_id(str): user id to be checked
events([synapse.events.EventBase]): list of events to be checked
is_peeking(bool): should be True if:
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the given
events
Returns:
[synapse.events.EventBase]
"""
types = (
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=types
)
res = yield self.filter_events_for_clients(
[(user_id, is_peeking)], events, event_id_to_state
)
defer.returnValue(res.get(user_id, []))
def ratelimit(self, requester):
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message(
@@ -61,6 +232,56 @@ class BaseHandler(object):
retry_after_ms=int(1000 * (time_allowed - time_now)),
)
@defer.inlineCallbacks
def _create_new_client_event(self, builder, prev_event_ids=None):
if prev_event_ids:
prev_events = yield self.store.add_event_hashes(prev_event_ids)
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
depth = prev_max_depth + 1
else:
latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
builder.room_id,
)
if latest_ret:
depth = max([d for _, _, d in latest_ret]) + 1
else:
depth = 1
prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in latest_ret
]
builder.prev_events = prev_events
builder.depth = depth
state_handler = self.state_handler
context = yield state_handler.compute_event_context(builder)
if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
)
yield self.auth.add_auth_events(builder, context)
add_hashes_and_signatures(
builder, self.server_name, self.signing_key
)
event = builder.build()
logger.debug(
"Created event %s with current state: %s",
event.event_id, context.current_state,
)
defer.returnValue(
(event, context,)
)
def is_host_in_room(self, current_state):
room_members = [
(state_key, event.membership)
@@ -75,12 +296,153 @@ class BaseHandler(object):
return True
for (state_key, membership) in room_members:
if (
self.hs.is_mine_id(state_key)
UserID.from_string(state_key).domain == self.hs.hostname
and membership == Membership.JOIN
):
return True
return False
@defer.inlineCallbacks
def handle_new_client_event(
self,
requester,
event,
context,
ratelimit=True,
extra_users=[]
):
# We now need to go and hit out to wherever we need to hit out to.
if ratelimit:
self.ratelimit(requester)
try:
self.auth.check(event, auth_events=context.current_state)
except AuthError as err:
logger.warn("Denying new event %r because %s", event, err)
raise err
yield self.maybe_kick_guest_users(event, context.current_state.values())
if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least)
room_alias_str = event.content.get("alias", None)
if room_alias_str:
room_alias = RoomAlias.from_string(room_alias_str)
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (
room_alias_str,
)
)
federation_handler = self.hs.get_handlers().federation_handler
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
def is_inviter_member_event(e):
return (
e.type == EventTypes.Member and
e.sender == event.sender
)
event.unsigned["invite_room_state"] = [
{
"type": e.type,
"state_key": e.state_key,
"content": e.content,
"sender": e.sender,
}
for k, e in context.current_state.items()
if e.type in self.hs.config.room_invite_state_types
or is_inviter_member_event(e)
]
invitee = UserID.from_string(event.state_key)
if not self.hs.is_mine(invitee):
# TODO: Can we add signature from remote server in a nicer
# way? If we have been invited by a remote server, we need
# to get them to sign the event.
returned_invite = yield federation_handler.send_invite(
invitee.domain,
event,
)
event.unsigned.pop("room_state", None)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(
returned_invite.signatures
)
if event.type == EventTypes.Redaction:
if self.auth.check_redaction(event, auth_events=context.current_state):
original_event = yield self.store.get_event(
event.redacts,
check_redacted=False,
get_prev_content=False,
allow_rejected=False,
allow_none=False
)
if event.user_id != original_event.user_id:
raise AuthError(
403,
"You don't have permission to redact events"
)
if event.type == EventTypes.Create and context.current_state:
raise AuthError(
403,
"Changing the room create event is forbidden",
)
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
event, context, self
)
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
event_stream_id, max_stream_id
)
destinations = set()
for k, s in context.current_state.items():
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(
UserID.from_string(s.state_key).domain
)
except SynapseError:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
with PreserveLoggingContext():
# Don't block waiting on waking up all the listeners.
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)
federation_handler.handle_new_event(
event, destinations=destinations,
)
@defer.inlineCallbacks
def maybe_kick_guest_users(self, event, current_state):
# Technically this function invalidates current_state by changing it.

View File

@@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.appservice import ApplicationService
from synapse.types import UserID
import logging
@@ -34,13 +35,16 @@ def log_failure(failure):
)
# NB: Purposefully not inheriting BaseHandler since that contains way too much
# setup code which this handler does not need or use. This makes testing a lot
# easier.
class ApplicationServicesHandler(object):
def __init__(self, hs):
def __init__(self, hs, appservice_api, appservice_scheduler):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
self.appservice_api = hs.get_application_service_api()
self.scheduler = hs.get_application_service_scheduler()
self.hs = hs
self.appservice_api = appservice_api
self.scheduler = appservice_scheduler
self.started_scheduler = False
@defer.inlineCallbacks
@@ -165,7 +169,8 @@ class ApplicationServicesHandler(object):
@defer.inlineCallbacks
def _is_unknown_user(self, user_id):
if not self.is_mine_id(user_id):
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
# we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes.
defer.returnValue(False)

View File

@@ -18,7 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.api.errors import AuthError, LoginError, Codes
from synapse.util.async import run_on_reactor
from twisted.web.client import PartialDownloadError
@@ -521,19 +521,14 @@ class AuthHandler(BaseHandler):
))
return m.serialize()
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
def generate_short_term_login_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
expiry = now + (2 * 60 * 1000)
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def generate_delete_pusher_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token):
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)
@@ -568,12 +563,7 @@ class AuthHandler(BaseHandler):
except_access_token_ids = [requester.access_token_id] if requester else []
try:
yield self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens(
user_id, except_access_token_ids
)
@@ -625,7 +615,4 @@ class AuthHandler(BaseHandler):
Returns:
Whether self.hash(password) == stored_hash (bool).
"""
if stored_hash:
return bcrypt.hashpw(password, stored_hash) == stored_hash
else:
return False
return bcrypt.hashpw(password, stored_hash) == stored_hash

View File

@@ -33,7 +33,6 @@ class DirectoryHandler(BaseHandler):
super(DirectoryHandler, self).__init__(hs)
self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler()
self.federation = hs.get_replication_layer()
self.federation.register_query_handler(
@@ -282,7 +281,7 @@ class DirectoryHandler(BaseHandler):
)
if not result:
# Query AS to see if it exists
as_handler = self.appservice_handler
as_handler = self.hs.get_handlers().appservice_handler
result = yield as_handler.query_room_alias_exists(room_alias)
defer.returnValue(result)

View File

@@ -58,7 +58,7 @@ class EventStreamHandler(BaseHandler):
If `only_keys` is not None, events from keys will be sent down.
"""
auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_presence_handler()
presence_handler = self.hs.get_handlers().presence_handler
context = yield presence_handler.user_syncing(
auth_user_id, affect_presence=affect_presence,

View File

@@ -33,7 +33,7 @@ from synapse.util.frozenutils import unfreeze
from synapse.crypto.event_signing import (
compute_event_signature, add_hashes_and_signatures,
)
from synapse.types import UserID, get_domain_from_id
from synapse.types import UserID
from synapse.events.utils import prune_event
@@ -453,7 +453,7 @@ class FederationHandler(BaseHandler):
joined_domains = {}
for u, d in joined_users:
try:
dom = get_domain_from_id(u)
dom = UserID.from_string(u).domain
old_d = joined_domains.get(dom)
if old_d:
joined_domains[dom] = min(d, old_d)
@@ -682,8 +682,7 @@ class FederationHandler(BaseHandler):
})
try:
message_handler = self.hs.get_handlers().message_handler
event, context = yield message_handler._create_new_client_event(
event, context = yield self._create_new_client_event(
builder=builder,
)
except AuthError as e:
@@ -744,7 +743,9 @@ class FederationHandler(BaseHandler):
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(get_domain_from_id(s.state_key))
destinations.add(
UserID.from_string(s.state_key).domain
)
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
@@ -914,8 +915,7 @@ class FederationHandler(BaseHandler):
"state_key": user_id,
})
message_handler = self.hs.get_handlers().message_handler
event, context = yield message_handler._create_new_client_event(
event, context = yield self._create_new_client_event(
builder=builder,
)
@@ -970,7 +970,9 @@ class FederationHandler(BaseHandler):
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.LEAVE:
destinations.add(get_domain_from_id(s.state_key))
destinations.add(
UserID.from_string(s.state_key).domain
)
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
@@ -1113,7 +1115,7 @@ class FederationHandler(BaseHandler):
if not event.internal_metadata.is_outlier():
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
event, context
event, context, self
)
event_stream_id, max_stream_id = yield self.store.persist_event(
@@ -1690,10 +1692,7 @@ class FederationHandler(BaseHandler):
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
message_handler = self.hs.get_handlers().message_handler
event, context = yield message_handler._create_new_client_event(
builder=builder
)
event, context = yield self._create_new_client_event(builder=builder)
event, context = yield self.add_display_name_to_third_party_invite(
event_dict, event, context
@@ -1721,8 +1720,7 @@ class FederationHandler(BaseHandler):
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
builder = self.event_builder_factory.new(event_dict)
message_handler = self.hs.get_handlers().message_handler
event, context = yield message_handler._create_new_client_event(
event, context = yield self._create_new_client_event(
builder=builder,
)
@@ -1761,8 +1759,7 @@ class FederationHandler(BaseHandler):
event_dict["content"]["third_party_invite"]["display_name"] = display_name
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
message_handler = self.hs.get_handlers().message_handler
event, context = yield message_handler._create_new_client_event(builder=builder)
event, context = yield self._create_new_client_event(builder=builder)
defer.returnValue((event, context))
@defer.inlineCallbacks

View File

@@ -17,19 +17,13 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.push.action_generator import ActionGenerator
from synapse.streams.config import PaginationConfig
from synapse.types import (
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
)
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
from synapse.visibility import filter_events_for_client
from synapse.types import UserID, RoomStreamToken, StreamToken
from ._base import BaseHandler
@@ -129,8 +123,7 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(),
})
events = yield filter_events_for_client(
self.store,
events = yield self._filter_events_for_client(
user_id,
events,
is_peeking=(member_event_id is None),
@@ -236,7 +229,7 @@ class MessageHandler(BaseHandler):
)
if event.type == EventTypes.Message:
presence = self.hs.get_presence_handler()
presence = self.hs.get_handlers().presence_handler
yield presence.bump_presence_active_time(user)
def deduplicate_state_event(self, event, context):
@@ -490,8 +483,8 @@ class MessageHandler(BaseHandler):
]
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages
messages = yield self._filter_events_for_client(
user_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0])
@@ -626,8 +619,8 @@ class MessageHandler(BaseHandler):
end_token=stream_token
)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking
messages = yield self._filter_events_for_client(
user_id, messages, is_peeking=is_peeking
)
start_token = StreamToken.START.copy_and_replace("room_key", token[0])
@@ -674,7 +667,7 @@ class MessageHandler(BaseHandler):
and m.content["membership"] == Membership.JOIN
]
presence_handler = self.hs.get_presence_handler()
presence_handler = self.hs.get_handlers().presence_handler
@defer.inlineCallbacks
def get_presence():
@@ -707,8 +700,8 @@ class MessageHandler(BaseHandler):
consumeErrors=True,
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking,
messages = yield self._filter_events_for_client(
user_id, messages, is_peeking=is_peeking,
)
start_token = now_token.copy_and_replace("room_key", token[0])
@@ -731,193 +724,3 @@ class MessageHandler(BaseHandler):
ret["membership"] = membership
defer.returnValue(ret)
@defer.inlineCallbacks
def _create_new_client_event(self, builder, prev_event_ids=None):
if prev_event_ids:
prev_events = yield self.store.add_event_hashes(prev_event_ids)
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
depth = prev_max_depth + 1
else:
latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
builder.room_id,
)
if latest_ret:
depth = max([d for _, _, d in latest_ret]) + 1
else:
depth = 1
prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in latest_ret
]
builder.prev_events = prev_events
builder.depth = depth
state_handler = self.state_handler
context = yield state_handler.compute_event_context(builder)
if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
)
yield self.auth.add_auth_events(builder, context)
signing_key = self.hs.config.signing_key[0]
add_hashes_and_signatures(
builder, self.server_name, signing_key
)
event = builder.build()
logger.debug(
"Created event %s with current state: %s",
event.event_id, context.current_state,
)
defer.returnValue(
(event, context,)
)
@defer.inlineCallbacks
def handle_new_client_event(
self,
requester,
event,
context,
ratelimit=True,
extra_users=[]
):
# We now need to go and hit out to wherever we need to hit out to.
if ratelimit:
self.ratelimit(requester)
try:
self.auth.check(event, auth_events=context.current_state)
except AuthError as err:
logger.warn("Denying new event %r because %s", event, err)
raise err
yield self.maybe_kick_guest_users(event, context.current_state.values())
if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least)
room_alias_str = event.content.get("alias", None)
if room_alias_str:
room_alias = RoomAlias.from_string(room_alias_str)
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (
room_alias_str,
)
)
federation_handler = self.hs.get_handlers().federation_handler
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
def is_inviter_member_event(e):
return (
e.type == EventTypes.Member and
e.sender == event.sender
)
event.unsigned["invite_room_state"] = [
{
"type": e.type,
"state_key": e.state_key,
"content": e.content,
"sender": e.sender,
}
for k, e in context.current_state.items()
if e.type in self.hs.config.room_invite_state_types
or is_inviter_member_event(e)
]
invitee = UserID.from_string(event.state_key)
if not self.hs.is_mine(invitee):
# TODO: Can we add signature from remote server in a nicer
# way? If we have been invited by a remote server, we need
# to get them to sign the event.
returned_invite = yield federation_handler.send_invite(
invitee.domain,
event,
)
event.unsigned.pop("room_state", None)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(
returned_invite.signatures
)
if event.type == EventTypes.Redaction:
if self.auth.check_redaction(event, auth_events=context.current_state):
original_event = yield self.store.get_event(
event.redacts,
check_redacted=False,
get_prev_content=False,
allow_rejected=False,
allow_none=False
)
if event.user_id != original_event.user_id:
raise AuthError(
403,
"You don't have permission to redact events"
)
if event.type == EventTypes.Create and context.current_state:
raise AuthError(
403,
"Changing the room create event is forbidden",
)
action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event(
event, context
)
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
event_stream_id, max_stream_id
)
destinations = set()
for k, s in context.current_state.items():
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(get_domain_from_id(s.state_key))
except SynapseError:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
with PreserveLoggingContext():
# Don't block waiting on waking up all the listeners.
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)
federation_handler.handle_new_event(
event, destinations=destinations,
)

View File

@@ -33,9 +33,11 @@ from synapse.util.logcontext import preserve_fn
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
from synapse.types import UserID, get_domain_from_id
from synapse.types import UserID
import synapse.metrics
from ._base import BaseHandler
import logging
@@ -68,18 +70,14 @@ FEDERATION_TIMEOUT = 30 * 60 * 1000
# How often to resend presence to remote servers
FEDERATION_PING_INTERVAL = 25 * 60 * 1000
# How long we will wait before assuming that the syncs from an external process
# are dead.
EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000
assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class PresenceHandler(object):
class PresenceHandler(BaseHandler):
def __init__(self, hs):
self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id
super(PresenceHandler, self).__init__(hs)
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.wheel_timer = WheelTimer()
@@ -140,7 +138,7 @@ class PresenceHandler(object):
obj=state.user_id,
then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
)
if self.is_mine_id(state.user_id):
if self.hs.is_mine_id(state.user_id):
self.wheel_timer.insert(
now=now,
obj=state.user_id,
@@ -162,26 +160,15 @@ class PresenceHandler(object):
self.serial_to_user = {}
self._next_serial = 1
# Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline.
# Keeps track of the number of *ongoing* syncs. While this is non zero
# a user will never go offline.
self.user_to_num_current_syncs = {}
# Keeps track of the number of *ongoing* syncs on other processes.
# While any sync is ongoing on another process the user will never
# go offline.
# Each process has a unique identifier and an update frequency. If
# no update is received from that process within the update period then
# we assume that all the sync requests on that process have stopped.
# Stored as a dict from process_id to set of user_id, and a dict of
# process_id to millisecond timestamp last updated.
self.external_process_to_current_syncs = {}
self.external_process_last_updated_ms = {}
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
self.clock.call_later(
30 * 1000,
0 * 1000,
self.clock.looping_call,
self._handle_timeouts,
5000,
@@ -241,7 +228,7 @@ class PresenceHandler(object):
new_state, should_notify, should_ping = handle_update(
prev_state, new_state,
is_mine=self.is_mine_id(user_id),
is_mine=self.hs.is_mine_id(user_id),
wheel_timer=self.wheel_timer,
now=now
)
@@ -287,34 +274,21 @@ class PresenceHandler(object):
# Fetch the list of users that *may* have timed out. Things may have
# changed since the timeout was set, so we won't necessarily have to
# take any action.
users_to_check = set(self.wheel_timer.fetch(now))
# Check whether the lists of syncing processes from an external
# process have expired.
expired_process_ids = [
process_id for process_id, last_update
in self.external_process_last_update.items()
if now - last_update > EXTERNAL_PROCESS_EXPIRY
]
for process_id in expired_process_ids:
users_to_check.update(
self.external_process_to_current_syncs.pop(process_id, ())
)
self.external_process_last_update.pop(process_id)
users_to_check = self.wheel_timer.fetch(now)
states = [
self.user_to_current_state.get(
user_id, UserPresenceState.default(user_id)
)
for user_id in users_to_check
for user_id in set(users_to_check)
]
timers_fired_counter.inc_by(len(states))
changes = handle_timeouts(
states,
is_mine_fn=self.is_mine_id,
syncing_users=self.get_syncing_users(),
is_mine_fn=self.hs.is_mine_id,
user_to_num_current_syncs=self.user_to_num_current_syncs,
now=now,
)
@@ -391,73 +365,6 @@ class PresenceHandler(object):
defer.returnValue(_user_syncing())
def get_currently_syncing_users(self):
"""Get the set of user ids that are currently syncing on this HS.
Returns:
set(str): A set of user_id strings.
"""
syncing_user_ids = {
user_id for user_id, count in self.user_to_num_current_syncs.items()
if count
}
syncing_user_ids.update(self.external_process_to_current_syncs.values())
return syncing_user_ids
@defer.inlineCallbacks
def update_external_syncs(self, process_id, syncing_user_ids):
"""Update the syncing users for an external process
Args:
process_id(str): An identifier for the process the users are
syncing against. This allows synapse to process updates
as user start and stop syncing against a given process.
syncing_user_ids(set(str)): The set of user_ids that are
currently syncing on that server.
"""
# Grab the previous list of user_ids that were syncing on that process
prev_syncing_user_ids = (
self.external_process_to_current_syncs.get(process_id, set())
)
# Grab the current presence state for both the users that are syncing
# now and the users that were syncing before this update.
prev_states = yield self.current_state_for_users(
syncing_user_ids | prev_syncing_user_ids
)
updates = []
time_now_ms = self.clock.time_msec()
# For each new user that is syncing check if we need to mark them as
# being online.
for new_user_id in syncing_user_ids - prev_syncing_user_ids:
prev_state = prev_states[new_user_id]
if prev_state.state == PresenceState.OFFLINE:
updates.append(prev_state.copy_and_replace(
state=PresenceState.ONLINE,
last_active_ts=time_now_ms,
last_user_sync_ts=time_now_ms,
))
else:
updates.append(prev_state.copy_and_replace(
last_user_sync_ts=time_now_ms,
))
# For each user that is still syncing or stopped syncing update the
# last sync time so that we will correctly apply the grace period when
# they stop syncing.
for old_user_id in prev_syncing_user_ids:
prev_state = prev_states[old_user_id]
updates.append(prev_state.copy_and_replace(
last_user_sync_ts=time_now_ms,
))
yield self._update_states(updates)
# Update the last updated time for the process. We expire the entries
# if we don't receive an update in the given timeframe.
self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
self.external_process_to_current_syncs[process_id] = syncing_user_ids
@defer.inlineCallbacks
def current_state_for_user(self, user_id):
"""Get the current presence state for a user.
@@ -520,7 +427,7 @@ class PresenceHandler(object):
hosts_to_states = {}
for room_id, states in room_ids_to_states.items():
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states)
if not local_states:
continue
@@ -529,11 +436,11 @@ class PresenceHandler(object):
hosts_to_states.setdefault(host, []).extend(local_states)
for user_id, states in users_to_states.items():
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states)
if not local_states:
continue
host = get_domain_from_id(user_id)
host = UserID.from_string(user_id).domain
hosts_to_states.setdefault(host, []).extend(local_states)
# TODO: de-dup hosts_to_states, as a single host might have multiple
@@ -704,14 +611,14 @@ class PresenceHandler(object):
# don't need to send to local clients here, as that is done as part
# of the event stream/sync.
# TODO: Only send to servers not already in the room.
if self.is_mine(user):
if self.hs.is_mine(user):
state = yield self.current_state_for_user(user.to_string())
hosts = yield self.store.get_joined_hosts_for_room(room_id)
self._push_to_remotes({host: (state,) for host in hosts})
else:
user_ids = yield self.store.get_users_in_room(room_id)
user_ids = filter(self.is_mine_id, user_ids)
user_ids = filter(self.hs.is_mine_id, user_ids)
states = yield self.current_state_for_users(user_ids)
@@ -721,7 +628,7 @@ class PresenceHandler(object):
def get_presence_list(self, observer_user, accepted=None):
"""Returns the presence for all users in their presence list.
"""
if not self.is_mine(observer_user):
if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
presence_list = yield self.store.get_presence_list(
@@ -752,7 +659,7 @@ class PresenceHandler(object):
observer_user.localpart, observed_user.to_string()
)
if self.is_mine(observed_user):
if self.hs.is_mine(observed_user):
yield self.invite_presence(observed_user, observer_user)
else:
yield self.federation.send_edu(
@@ -768,11 +675,11 @@ class PresenceHandler(object):
def invite_presence(self, observed_user, observer_user):
"""Handles new presence invites.
"""
if not self.is_mine(observed_user):
if not self.hs.is_mine(observed_user):
raise SynapseError(400, "User is not hosted on this Home Server")
# TODO: Don't auto accept
if self.is_mine(observer_user):
if self.hs.is_mine(observer_user):
yield self.accept_presence(observed_user, observer_user)
else:
self.federation.send_edu(
@@ -835,7 +742,7 @@ class PresenceHandler(object):
Returns:
A Deferred.
"""
if not self.is_mine(observer_user):
if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
yield self.store.del_presence_list(
@@ -927,11 +834,7 @@ def _format_user_presence_state(state, now):
class PresenceEventSource(object):
def __init__(self, hs):
# We can't call get_presence_handler here because there's a cycle:
#
# Presence -> Notifier -> PresenceEventSource -> Presence
#
self.get_presence_handler = hs.get_presence_handler
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@@ -957,7 +860,7 @@ class PresenceEventSource(object):
from_key = int(from_key)
room_ids = room_ids or []
presence = self.get_presence_handler()
presence = self.hs.get_handlers().presence_handler
stream_change_cache = self.store.presence_stream_cache
if not room_ids:
@@ -1030,14 +933,15 @@ class PresenceEventSource(object):
return self.get_new_events(user, from_key=None, include_offline=False)
def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now):
"""Checks the presence of users that have timed out and updates as
appropriate.
Args:
user_states(list): List of UserPresenceState's to check.
is_mine_fn (fn): Function that returns if a user_id is ours
syncing_user_ids (set): Set of user_ids with active syncs.
user_to_num_current_syncs (dict): Mapping of user_id to number of currently
active syncs.
now (int): Current time in ms.
Returns:
@@ -1048,20 +952,21 @@ def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
for state in user_states:
is_mine = is_mine_fn(state.user_id)
new_state = handle_timeout(state, is_mine, syncing_user_ids, now)
new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now)
if new_state:
changes[state.user_id] = new_state
return changes.values()
def handle_timeout(state, is_mine, syncing_user_ids, now):
def handle_timeout(state, is_mine, user_to_num_current_syncs, now):
"""Checks the presence of the user to see if any of the timers have elapsed
Args:
state (UserPresenceState)
is_mine (bool): Whether the user is ours
syncing_user_ids (set): Set of user_ids with active syncs.
user_to_num_current_syncs (dict): Mapping of user_id to number of currently
active syncs.
now (int): Current time in ms.
Returns:
@@ -1095,7 +1000,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
# If there are have been no sync for a while (and none ongoing),
# set presence to offline
if user_id not in syncing_user_ids:
if not user_to_num_current_syncs.get(user_id, 0):
if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT:
state = state.copy_and_replace(
state=PresenceState.OFFLINE,

View File

@@ -29,8 +29,6 @@ class ReceiptsHandler(BaseHandler):
def __init__(self, hs):
super(ReceiptsHandler, self).__init__(hs)
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
self.hs = hs
self.federation = hs.get_replication_layer()
self.federation.register_edu_handler(
@@ -133,9 +131,12 @@ class ReceiptsHandler(BaseHandler):
event_ids = receipt["event_ids"]
data = receipt["data"]
remotedomains = yield self.store.get_joined_hosts_for_room(room_id)
remotedomains = remotedomains.copy()
remotedomains.discard(self.server_name)
remotedomains = set()
rm_handler = self.hs.get_handlers().room_member_handler
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=None, remotedomains=remotedomains
)
logger.debug("Sending receipt to: %r", remotedomains)

View File

@@ -16,7 +16,7 @@
"""Contains functions for registering clients."""
from twisted.internet import defer
from synapse.types import UserID, Requester
from synapse.types import UserID
from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
)
@@ -358,62 +358,8 @@ class RegistrationHandler(BaseHandler):
)
defer.returnValue(data)
@defer.inlineCallbacks
def get_or_create_user(self, localpart, displayname, duration_seconds):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
Args:
localpart : The local part of the user ID to register. If None,
one will be randomly generated.
Returns:
A tuple of (user_id, access_token).
Raises:
RegistrationError if there was a problem registering.
"""
yield run_on_reactor()
if localpart is None:
raise SynapseError(400, "Request must include user id")
need_register = True
try:
yield self.check_username(localpart)
except SynapseError as e:
if e.errcode == Codes.USER_IN_USE:
need_register = False
else:
raise
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
auth_handler = self.hs.get_handlers().auth_handler
token = auth_handler.generate_short_term_login_token(user_id, duration_seconds)
if need_register:
yield self.store.register(
user_id=user_id,
token=token,
password_hash=None
)
yield registered_user(self.distributor, user)
else:
yield self.store.user_delete_access_tokens(user_id=user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler
yield profile_handler.set_displayname(
user, Requester(user, token, False), displayname
)
defer.returnValue((user_id, token))
def auth_handler(self):
return self.hs.get_auth_handler()
return self.hs.get_handlers().auth_handler
@defer.inlineCallbacks
def guest_access_token_for(self, medium, address, inviter_user_id):

View File

@@ -26,7 +26,6 @@ from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.util import stringutils
from synapse.util.async import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
from collections import OrderedDict
@@ -36,8 +35,6 @@ import string
logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
id_server_scheme = "https://"
@@ -346,14 +343,8 @@ class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
self.response_cache = ResponseCache()
self.remote_list_request_cache = ResponseCache()
self.remote_list_cache = {}
self.fetch_looping_call = hs.get_clock().looping_call(
self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL
)
self.fetch_all_remote_lists()
def get_local_public_room_list(self):
def get_public_room_list(self):
result = self.response_cache.get(())
if not result:
result = self.response_cache.set((), self._get_public_room_list())
@@ -435,55 +426,6 @@ class RoomListHandler(BaseHandler):
# FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": results})
@defer.inlineCallbacks
def fetch_all_remote_lists(self):
deferred = self.hs.get_replication_layer().get_public_rooms(
self.hs.config.secondary_directory_servers
)
self.remote_list_request_cache.set((), deferred)
self.remote_list_cache = yield deferred
@defer.inlineCallbacks
def get_aggregated_public_room_list(self):
"""
Get the public room list from this server and the servers
specified in the secondary_directory_servers config option.
XXX: Pagination...
"""
# We return the results from out cache which is updated by a looping call,
# unless we're missing a cache entry, in which case wait for the result
# of the fetch if there's one in progress. If not, omit that server.
wait = False
for s in self.hs.config.secondary_directory_servers:
if s not in self.remote_list_cache:
logger.warn("No cached room list from %s: waiting for fetch", s)
wait = True
break
if wait and self.remote_list_request_cache.get(()):
yield self.remote_list_request_cache.get(())
public_rooms = yield self.get_local_public_room_list()
# keep track of which room IDs we've seen so we can de-dup
room_ids = set()
# tag all the ones in our list with our server name.
# Also add the them to the de-deping set
for room in public_rooms['chunk']:
room["server_name"] = self.hs.hostname
room_ids.add(room["room_id"])
# Now add the results from federation
for server_name, server_result in self.remote_list_cache.items():
for room in server_result["chunk"]:
if room["room_id"] not in room_ids:
room["server_name"] = server_name
public_rooms["chunk"].append(room)
room_ids.add(room["room_id"])
defer.returnValue(public_rooms)
class RoomContextHandler(BaseHandler):
@defer.inlineCallbacks
@@ -507,12 +449,10 @@ class RoomContextHandler(BaseHandler):
now_token = yield self.hs.get_event_sources().get_current_token()
def filter_evts(events):
return filter_events_for_client(
self.store,
return self._filter_events_for_client(
user.to_string(),
events,
is_peeking=is_guest
)
is_peeking=is_guest)
event = yield self.store.get_event(event_id, get_prev_content=True,
allow_none=True)

View File

@@ -55,6 +55,35 @@ class RoomMemberHandler(BaseHandler):
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def get_room_members(self, room_id):
users = yield self.store.get_users_in_room(room_id)
defer.returnValue([UserID.from_string(u) for u in users])
@defer.inlineCallbacks
def fetch_room_distributions_into(self, room_id, localusers=None,
remotedomains=None, ignore_user=None):
"""Fetch the distribution of a room, adding elements to either
'localusers' or 'remotedomains', which should be a set() if supplied.
If ignore_user is set, ignore that user.
This function returns nothing; its result is performed by the
side-effect on the two passed sets. This allows easy accumulation of
member lists of multiple rooms at once if required.
"""
members = yield self.get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if self.hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
@defer.inlineCallbacks
def _local_membership_update(
self, requester, target, room_id, membership,
@@ -84,7 +113,7 @@ class RoomMemberHandler(BaseHandler):
prev_event_ids=prev_event_ids,
)
yield msg_handler.handle_new_client_event(
yield self.handle_new_client_event(
requester,
event,
context,
@@ -203,7 +232,7 @@ class RoomMemberHandler(BaseHandler):
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was banned" % (action,),
"Cannot %s user who was is banned" % (action,),
errcode=Codes.BAD_STATE
)
@@ -328,7 +357,7 @@ class RoomMemberHandler(BaseHandler):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
yield message_handler.handle_new_client_event(
yield self.handle_new_client_event(
requester,
event,
context,
@@ -397,6 +426,21 @@ class RoomMemberHandler(BaseHandler):
if invite:
defer.returnValue(UserID.from_string(invite.sender))
@defer.inlineCallbacks
def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
membership states in."""
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
)
# For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should
room_ids = set(r.room_id for r in rooms)
defer.returnValue(room_ids)
@defer.inlineCallbacks
def do_3pid_invite(
self,
@@ -413,7 +457,8 @@ class RoomMemberHandler(BaseHandler):
)
if invitee:
yield self.update_membership(
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,

View File

@@ -21,7 +21,6 @@ from synapse.api.constants import Membership, EventTypes
from synapse.api.filtering import Filter
from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event
from synapse.visibility import filter_events_for_client
from unpaddedbase64 import decode_base64, encode_base64
@@ -173,8 +172,8 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client(
self.store, user.to_string(), filtered_events
events = yield self._filter_events_for_client(
user.to_string(), filtered_events
)
events.sort(key=lambda e: -rank_map[e.event_id])
@@ -224,8 +223,8 @@ class SearchHandler(BaseHandler):
r["event"] for r in results
])
events = yield filter_events_for_client(
self.store, user.to_string(), filtered_events
events = yield self._filter_events_for_client(
user.to_string(), filtered_events
)
room_events.extend(events)
@@ -282,12 +281,12 @@ class SearchHandler(BaseHandler):
event.room_id, event.event_id, before_limit, after_limit
)
res["events_before"] = yield filter_events_for_client(
self.store, user.to_string(), res["events_before"]
res["events_before"] = yield self._filter_events_for_client(
user.to_string(), res["events_before"]
)
res["events_after"] = yield filter_events_for_client(
self.store, user.to_string(), res["events_after"]
res["events_after"] = yield self._filter_events_for_client(
user.to_string(), res["events_after"]
)
res["start"] = now_token.copy_and_replace(

File diff suppressed because it is too large Load Diff

View File

@@ -15,6 +15,8 @@
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.metrics import Measure
@@ -30,16 +32,14 @@ logger = logging.getLogger(__name__)
# A tiny object useful for storing a user's membership in a room, as a mapping
# key
RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
RoomMember = namedtuple("RoomMember", ("room_id", "user"))
class TypingHandler(object):
class TypingNotificationHandler(BaseHandler):
def __init__(self, hs):
self.store = hs.get_datastore()
self.server_name = hs.config.server_name
self.auth = hs.get_auth()
self.is_mine_id = hs.is_mine_id
self.notifier = hs.get_notifier()
super(TypingNotificationHandler, self).__init__(hs)
self.homeserver = hs
self.clock = hs.get_clock()
@@ -67,23 +67,20 @@ class TypingHandler(object):
@defer.inlineCallbacks
def started_typing(self, target_user, auth_user, room_id, timeout):
target_user_id = target_user.to_string()
auth_user_id = auth_user.to_string()
if not self.is_mine_id(target_user_id):
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server")
if target_user_id != auth_user_id:
if target_user != auth_user:
raise AuthError(400, "Cannot set another user's typing state")
yield self.auth.check_joined_room(room_id, target_user_id)
yield self.auth.check_joined_room(room_id, target_user.to_string())
logger.debug(
"%s has started typing in %s", target_user_id, room_id
"%s has started typing in %s", target_user.to_string(), room_id
)
until = self.clock.time_msec() + timeout
member = RoomMember(room_id=room_id, user_id=target_user_id)
member = RoomMember(room_id=room_id, user=target_user)
was_present = member in self._member_typing_until
@@ -107,28 +104,25 @@ class TypingHandler(object):
yield self._push_update(
room_id=room_id,
user_id=target_user_id,
user=target_user,
typing=True,
)
@defer.inlineCallbacks
def stopped_typing(self, target_user, auth_user, room_id):
target_user_id = target_user.to_string()
auth_user_id = auth_user.to_string()
if not self.is_mine_id(target_user_id):
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server")
if target_user_id != auth_user_id:
if target_user != auth_user:
raise AuthError(400, "Cannot set another user's typing state")
yield self.auth.check_joined_room(room_id, target_user_id)
yield self.auth.check_joined_room(room_id, target_user.to_string())
logger.debug(
"%s has stopped typing in %s", target_user_id, room_id
"%s has stopped typing in %s", target_user.to_string(), room_id
)
member = RoomMember(room_id=room_id, user_id=target_user_id)
member = RoomMember(room_id=room_id, user=target_user)
if member in self._member_typing_timer:
self.clock.cancel_call_later(self._member_typing_timer[member])
@@ -138,9 +132,8 @@ class TypingHandler(object):
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
user_id = user.to_string()
if self.is_mine_id(user_id):
member = RoomMember(room_id=room_id, user=user_id)
if self.hs.is_mine(user):
member = RoomMember(room_id=room_id, user=user)
yield self._stopped_typing(member)
@defer.inlineCallbacks
@@ -151,7 +144,7 @@ class TypingHandler(object):
yield self._push_update(
room_id=member.room_id,
user_id=member.user_id,
user=member.user,
typing=False,
)
@@ -163,53 +156,61 @@ class TypingHandler(object):
del self._member_typing_timer[member]
@defer.inlineCallbacks
def _push_update(self, room_id, user_id, typing):
domains = yield self.store.get_joined_hosts_for_room(room_id)
def _push_update(self, room_id, user, typing):
localusers = set()
remotedomains = set()
rm_handler = self.homeserver.get_handlers().room_member_handler
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=localusers, remotedomains=remotedomains
)
if localusers:
self._push_update_local(
room_id=room_id,
user=user,
typing=typing
)
deferreds = []
for domain in domains:
if domain == self.server_name:
self._push_update_local(
room_id=room_id,
user_id=user_id,
typing=typing
)
else:
deferreds.append(self.federation.send_edu(
destination=domain,
edu_type="m.typing",
content={
"room_id": room_id,
"user_id": user_id,
"typing": typing,
},
))
for domain in remotedomains:
deferreds.append(self.federation.send_edu(
destination=domain,
edu_type="m.typing",
content={
"room_id": room_id,
"user_id": user.to_string(),
"typing": typing,
},
))
yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks
def _recv_edu(self, origin, content):
room_id = content["room_id"]
user_id = content["user_id"]
user = UserID.from_string(content["user_id"])
# Check that the string is a valid user id
UserID.from_string(user_id)
localusers = set()
domains = yield self.store.get_joined_hosts_for_room(room_id)
rm_handler = self.homeserver.get_handlers().room_member_handler
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=localusers
)
if self.server_name in domains:
if localusers:
self._push_update_local(
room_id=room_id,
user_id=user_id,
user=user,
typing=content["typing"]
)
def _push_update_local(self, room_id, user_id, typing):
def _push_update_local(self, room_id, user, typing):
room_set = self._room_typing.setdefault(room_id, set())
if typing:
room_set.add(user_id)
room_set.add(user)
else:
room_set.discard(user_id)
room_set.discard(user)
self._latest_room_serial += 1
self._room_serials[room_id] = self._latest_room_serial
@@ -225,7 +226,9 @@ class TypingHandler(object):
for room_id, serial in self._room_serials.items():
if last_id < serial and serial <= current_id:
typing = self._room_typing[room_id]
typing_bytes = json.dumps(list(typing), ensure_ascii=False)
typing_bytes = json.dumps([
u.to_string() for u in typing
], ensure_ascii=False)
rows.append((serial, room_id, typing_bytes))
rows.sort()
return rows
@@ -235,26 +238,34 @@ class TypingNotificationEventSource(object):
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
# We can't call get_typing_handler here because there's a cycle:
#
# Typing -> Notifier -> TypingNotificationEventSource -> Typing
#
self.get_typing_handler = hs.get_typing_handler
self._handler = None
self._room_member_handler = None
def handler(self):
# Avoid cyclic dependency in handler setup
if not self._handler:
self._handler = self.hs.get_handlers().typing_notification_handler
return self._handler
def room_member_handler(self):
if not self._room_member_handler:
self._room_member_handler = self.hs.get_handlers().room_member_handler
return self._room_member_handler
def _make_event_for(self, room_id):
typing = self.get_typing_handler()._room_typing[room_id]
typing = self.handler()._room_typing[room_id]
return {
"type": "m.typing",
"room_id": room_id,
"content": {
"user_ids": list(typing),
"user_ids": [u.to_string() for u in typing],
},
}
def get_new_events(self, from_key, room_ids, **kwargs):
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
handler = self.get_typing_handler()
handler = self.handler()
events = []
for room_id in room_ids:
@@ -268,7 +279,7 @@ class TypingNotificationEventSource(object):
return events, handler._latest_room_serial
def get_current_key(self):
return self.get_typing_handler()._latest_room_serial
return self.handler()._latest_room_serial
def get_pagination_rows(self, user, pagination_config, key):
return ([], pagination_config.from_key)

View File

@@ -380,14 +380,13 @@ class CaptchaServerHttpClient(SimpleHttpClient):
class SpiderEndpointFactory(object):
def __init__(self, hs):
self.blacklist = hs.config.url_preview_ip_range_blacklist
self.whitelist = hs.config.url_preview_ip_range_whitelist
self.policyForHTTPS = hs.get_http_client_context_factory()
def endpointForURI(self, uri):
logger.info("Getting endpoint for %s", uri.toBytes())
if uri.scheme == "http":
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
reactor, uri.host, uri.port, self.blacklist,
endpoint=TCP4ClientEndpoint,
endpoint_kw_args={
'timeout': 15
@@ -396,7 +395,7 @@ class SpiderEndpointFactory(object):
elif uri.scheme == "https":
tlsPolicy = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
reactor, uri.host, uri.port, self.blacklist,
endpoint=SSL4ClientEndpoint,
endpoint_kw_args={
'sslContextFactory': tlsPolicy,

View File

@@ -79,13 +79,12 @@ class SpiderEndpoint(object):
"""An endpoint which refuses to connect to blacklisted IP addresses
Implements twisted.internet.interfaces.IStreamClientEndpoint.
"""
def __init__(self, reactor, host, port, blacklist, whitelist,
def __init__(self, reactor, host, port, blacklist,
endpoint=TCP4ClientEndpoint, endpoint_kw_args={}):
self.reactor = reactor
self.host = host
self.port = port
self.blacklist = blacklist
self.whitelist = whitelist
self.endpoint = endpoint
self.endpoint_kw_args = endpoint_kw_args
@@ -94,13 +93,10 @@ class SpiderEndpoint(object):
address = yield self.reactor.resolve(self.host)
from netaddr import IPAddress
ip_address = IPAddress(address)
if ip_address in self.blacklist:
if self.whitelist is None or ip_address not in self.whitelist:
raise ConnectError(
"Refusing to spider blacklisted IP address %s" % address
)
if IPAddress(address) in self.blacklist:
raise ConnectError(
"Refusing to spider blacklisted IP address %s" % address
)
logger.info("Connecting to %s:%s", address, self.port)
endpoint = self.endpoint(

View File

@@ -74,12 +74,7 @@ response_db_txn_duration = metrics.register_distribution(
_next_request_id = 0
def request_handler(report_metrics=True):
"""Decorator for ``wrap_request_handler``"""
return lambda request_handler: wrap_request_handler(request_handler, report_metrics)
def wrap_request_handler(request_handler, report_metrics):
def request_handler(request_handler):
"""Wraps a method that acts as a request handler with the necessary logging
and exception handling.
@@ -101,12 +96,7 @@ def wrap_request_handler(request_handler, report_metrics):
global _next_request_id
request_id = "%s-%s" % (request.method, _next_request_id)
_next_request_id += 1
with LoggingContext(request_id) as request_context:
if report_metrics:
request_metrics = RequestMetrics()
request_metrics.start(self.clock)
request_context.request = request_id
with request.processing():
try:
@@ -143,14 +133,6 @@ def wrap_request_handler(request_handler, report_metrics):
},
send_cors=True
)
finally:
try:
if report_metrics:
request_metrics.stop(
self.clock, request, self.__class__.__name__
)
except:
pass
return wrapped_request_handler
@@ -215,23 +197,19 @@ class JsonResource(HttpServer, resource.Resource):
self._async_render(request)
return server.NOT_DONE_YET
# Disable metric reporting because _async_render does its own metrics.
# It does its own metric reporting because _async_render dispatches to
# a callback and it's the class name of that callback we want to report
# against rather than the JsonResource itself.
@request_handler(report_metrics=False)
@request_handler
@defer.inlineCallbacks
def _async_render(self, request):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
"""
start = self.clock.time_msec()
if request.method == "OPTIONS":
self._send_response(request, 200, {})
return
request_metrics = RequestMetrics()
request_metrics.start(self.clock)
start_context = LoggingContext.current_context()
# Loop through all the registered callbacks to check if the method
# and path regex match
@@ -263,7 +241,40 @@ class JsonResource(HttpServer, resource.Resource):
self._send_response(request, code, response)
try:
request_metrics.stop(self.clock, request, servlet_classname)
context = LoggingContext.current_context()
tag = ""
if context:
tag = context.tag
if context != start_context:
logger.warn(
"Context have unexpectedly changed %r, %r",
context, self.start_context
)
return
incoming_requests_counter.inc(request.method, servlet_classname, tag)
response_timer.inc_by(
self.clock.time_msec() - start, request.method,
servlet_classname, tag
)
ru_utime, ru_stime = context.get_resource_usage()
response_ru_utime.inc_by(
ru_utime, request.method, servlet_classname, tag
)
response_ru_stime.inc_by(
ru_stime, request.method, servlet_classname, tag
)
response_db_txn_count.inc_by(
context.db_txn_count, request.method, servlet_classname, tag
)
response_db_txn_duration.inc_by(
context.db_txn_duration, request.method, servlet_classname, tag
)
except:
pass
@@ -296,48 +307,6 @@ class JsonResource(HttpServer, resource.Resource):
)
class RequestMetrics(object):
def start(self, clock):
self.start = clock.time_msec()
self.start_context = LoggingContext.current_context()
def stop(self, clock, request, servlet_classname):
context = LoggingContext.current_context()
tag = ""
if context:
tag = context.tag
if context != self.start_context:
logger.warn(
"Context have unexpectedly changed %r, %r",
context, self.start_context
)
return
incoming_requests_counter.inc(request.method, servlet_classname, tag)
response_timer.inc_by(
clock.time_msec() - self.start, request.method,
servlet_classname, tag
)
ru_utime, ru_stime = context.get_resource_usage()
response_ru_utime.inc_by(
ru_utime, request.method, servlet_classname, tag
)
response_ru_stime.inc_by(
ru_stime, request.method, servlet_classname, tag
)
response_db_txn_count.inc_by(
context.db_txn_count, request.method, servlet_classname, tag
)
response_db_txn_duration.inc_by(
context.db_txn_duration, request.method, servlet_classname, tag
)
class RootRedirect(resource.Resource):
"""Redirects the root '/' path to another path."""

View File

@@ -1,146 +0,0 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.logcontext import LoggingContext
from twisted.web.server import Site, Request
import contextlib
import logging
import re
import time
ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
class SynapseRequest(Request):
def __init__(self, site, *args, **kw):
Request.__init__(self, *args, **kw)
self.site = site
self.authenticated_entity = None
self.start_time = 0
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
self.__class__.__name__,
id(self),
self.method,
self.get_redacted_uri(),
self.clientproto,
self.site.site_tag,
)
def get_redacted_uri(self):
return ACCESS_TOKEN_RE.sub(
r'\1<redacted>\3',
self.uri
)
def get_user_agent(self):
return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
def started_processing(self):
self.site.access_logger.info(
"%s - %s - Received request: %s %s",
self.getClientIP(),
self.site.site_tag,
self.method,
self.get_redacted_uri()
)
self.start_time = int(time.time() * 1000)
def finished_processing(self):
try:
context = LoggingContext.current_context()
ru_utime, ru_stime = context.get_resource_usage()
db_txn_count = context.db_txn_count
db_txn_duration = context.db_txn_duration
except:
ru_utime, ru_stime = (0, 0)
db_txn_count, db_txn_duration = (0, 0)
self.site.access_logger.info(
"%s - %s - {%s}"
" Processed request: %dms (%dms, %dms) (%dms/%d)"
" %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(),
self.site.site_tag,
self.authenticated_entity,
int(time.time() * 1000) - self.start_time,
int(ru_utime * 1000),
int(ru_stime * 1000),
int(db_txn_duration * 1000),
int(db_txn_count),
self.sentLength,
self.code,
self.method,
self.get_redacted_uri(),
self.clientproto,
self.get_user_agent(),
)
@contextlib.contextmanager
def processing(self):
self.started_processing()
yield
self.finished_processing()
class XForwardedForRequest(SynapseRequest):
def __init__(self, *args, **kw):
SynapseRequest.__init__(self, *args, **kw)
"""
Add a layer on top of another request that only uses the value of an
X-Forwarded-For header as the result of C{getClientIP}.
"""
def getClientIP(self):
"""
@return: The client address (the first address) in the value of the
I{X-Forwarded-For header}. If the header is not present, return
C{b"-"}.
"""
return self.requestHeaders.getRawHeaders(
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
class SynapseRequestFactory(object):
def __init__(self, site, x_forwarded_for):
self.site = site
self.x_forwarded_for = x_forwarded_for
def __call__(self, *args, **kwargs):
if self.x_forwarded_for:
return XForwardedForRequest(self.site, *args, **kwargs)
else:
return SynapseRequest(self.site, *args, **kwargs)
class SynapseSite(Site):
"""
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
Site.__init__(self, resource, *args, **kwargs)
self.site_tag = site_tag
proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name)
def log(self, request):
pass

View File

@@ -21,7 +21,6 @@ from synapse.util.logutils import log_function
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import StreamToken
from synapse.visibility import filter_events_for_client
import synapse.metrics
from collections import namedtuple
@@ -140,6 +139,8 @@ class Notifier(object):
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
def __init__(self, hs):
self.hs = hs
self.user_to_user_stream = {}
self.room_to_user_streams = {}
self.appservice_to_user_streams = {}
@@ -149,8 +150,6 @@ class Notifier(object):
self.pending_new_room_events = []
self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
self.state_handler = hs.get_state_handler()
hs.get_distributor().observe(
"user_joined_room", self._user_joined_room
@@ -232,7 +231,9 @@ class Notifier(object):
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
self.appservice_handler.notify_interested_services(event)
self.hs.get_handlers().appservice_handler.notify_interested_services(
event
)
app_streams = set()
@@ -397,8 +398,8 @@ class Notifier(object):
)
if name == "room":
new_events = yield filter_events_for_client(
self.store,
room_member_handler = self.hs.get_handlers().room_member_handler
new_events = yield room_member_handler._filter_events_for_client(
user.to_string(),
new_events,
is_peeking=is_peeking,
@@ -447,7 +448,7 @@ class Notifier(object):
@defer.inlineCallbacks
def _is_world_readable(self, room_id):
state = yield self.state_handler.get_current_state(
state = yield self.hs.get_state_handler().get_current_state(
room_id,
EventTypes.RoomHistoryVisibility
)

View File

@@ -37,14 +37,14 @@ class ActionGenerator:
# tag (ie. we just need all the users).
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context):
def handle_push_actions_for_event(self, event, context, handler):
with Measure(self.clock, "handle_push_actions_for_event"):
bulk_evaluator = yield evaluator_for_event(
event, self.hs, self.store
)
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
event, context.current_state
event, handler, context.current_state
)
context.push_actions = [

View File

@@ -22,14 +22,12 @@ from .baserules import list_with_base_rules
from .push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.api.constants import EventTypes
from synapse.visibility import filter_events_for_clients
logger = logging.getLogger(__name__)
def decode_rule_json(rule):
rule = dict(rule)
rule['conditions'] = json.loads(rule['conditions'])
rule['actions'] = json.loads(rule['actions'])
return rule
@@ -40,8 +38,6 @@ def _get_rules(room_id, user_ids, store):
rules_by_user = yield store.bulk_get_push_rules(user_ids)
rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
rules_by_user = {
uid: list_with_base_rules([
decode_rule_json(rule_list)
@@ -54,10 +50,11 @@ def _get_rules(room_id, user_ids, store):
# fetch disabled rules, but this won't account for any server default
# rules the user has disabled, so we need to do this too.
for uid in user_ids:
user_enabled_map = rules_enabled_by_user.get(uid)
if not user_enabled_map:
if uid not in rules_enabled_by_user:
continue
user_enabled_map = rules_enabled_by_user[uid]
for i, rule in enumerate(rules_by_user[uid]):
rule_id = rule['rule_id']
@@ -129,7 +126,7 @@ class BulkPushRuleEvaluator:
self.store = store
@defer.inlineCallbacks
def action_for_event_by_user(self, event, current_state):
def action_for_event_by_user(self, event, handler, current_state):
actions_by_user = {}
# None of these users can be peeking since this list of users comes
@@ -139,8 +136,8 @@ class BulkPushRuleEvaluator:
(u, False) for u in self.rules_by_user.keys()
]
filtered_by_user = yield filter_events_for_clients(
self.store, user_tuples, [event], {event.event_id: current_state}
filtered_by_user = yield handler.filter_events_for_clients(
user_tuples, [event], {event.event_id: current_state}
)
room_members = yield self.store.get_users_in_room(self.room_id)

View File

@@ -1,283 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, reactor
import logging
from synapse.util.metrics import Measure
from synapse.util.logcontext import LoggingContext
from mailer import Mailer
logger = logging.getLogger(__name__)
# The amount of time we always wait before ever emailing about a notification
# (to give the user a chance to respond to other push or notice the window)
DELAY_BEFORE_MAIL_MS = 10 * 60 * 1000
# THROTTLE is the minimum time between mail notifications sent for a given room.
# Each room maintains its own throttle counter, but each new mail notification
# sends the pending notifications for all rooms.
THROTTLE_START_MS = 10 * 60 * 1000
THROTTLE_MAX_MS = 24 * 60 * 60 * 1000 # 24h
# THROTTLE_MULTIPLIER = 6 # 10 mins, 1 hour, 6 hours, 24 hours
THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day
# If no event triggers a notification for this long after the previous,
# the throttle is released.
# 12 hours - a gap of 12 hours in conversation is surely enough to merit a new
# notification when things get going again...
THROTTLE_RESET_AFTER_MS = (12 * 60 * 60 * 1000)
# does each email include all unread notifs, or just the ones which have happened
# since the last mail?
# XXX: this is currently broken as it includes ones from parted rooms(!)
INCLUDE_ALL_UNREAD_NOTIFS = False
class EmailPusher(object):
"""
A pusher that sends email notifications about events (approximately)
when they happen.
This shares quite a bit of code with httpusher: it would be good to
factor out the common parts
"""
def __init__(self, hs, pusherdict):
self.hs = hs
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pusher_id = pusherdict['id']
self.user_id = pusherdict['user_name']
self.app_id = pusherdict['app_id']
self.email = pusherdict['pushkey']
self.last_stream_ordering = pusherdict['last_stream_ordering']
self.timed_call = None
self.throttle_params = None
# See httppusher
self.max_stream_ordering = None
self.processing = False
if self.hs.config.email_enable_notifs:
if 'data' in pusherdict and 'brand' in pusherdict['data']:
app_name = pusherdict['data']['brand']
else:
app_name = self.hs.config.email_app_name
self.mailer = Mailer(self.hs, app_name)
else:
self.mailer = None
@defer.inlineCallbacks
def on_started(self):
if self.mailer is not None:
self.throttle_params = yield self.store.get_throttle_params_by_room(
self.pusher_id
)
yield self._process()
def on_stop(self):
if self.timed_call:
self.timed_call.cancel()
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
yield self._process()
def on_new_receipts(self, min_stream_id, max_stream_id):
# We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the
# timer fire
return defer.succeed(None)
@defer.inlineCallbacks
def on_timer(self):
self.timed_call = None
yield self._process()
@defer.inlineCallbacks
def _process(self):
if self.processing:
return
with LoggingContext("emailpush._process"):
with Measure(self.clock, "emailpush._process"):
try:
self.processing = True
# if the max ordering changes while we're running _unsafe_process,
# call it again, and so on until we've caught up.
while True:
starting_max_ordering = self.max_stream_ordering
try:
yield self._unsafe_process()
except:
logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering:
break
finally:
self.processing = False
@defer.inlineCallbacks
def _unsafe_process(self):
"""
Main logic of the push loop without the wrapper function that sets
up logging, measures and guards against multiple instances of it
being run.
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
self.user_id, start, self.max_stream_ordering
)
soonest_due_at = None
for push_action in unprocessed:
received_at = push_action['received_ts']
if received_at is None:
received_at = 0
notif_ready_at = received_at + DELAY_BEFORE_MAIL_MS
room_ready_at = self.room_ready_to_notify_at(
push_action['room_id']
)
should_notify_at = max(notif_ready_at, room_ready_at)
if should_notify_at < self.clock.time_msec():
# one of our notifications is ready for sending, so we send
# *one* email updating the user on their notifications,
# we then consider all previously outstanding notifications
# to be delivered.
reason = {
'room_id': push_action['room_id'],
'now': self.clock.time_msec(),
'received_at': received_at,
'delay_before_mail_ms': DELAY_BEFORE_MAIL_MS,
'last_sent_ts': self.get_room_last_sent_ts(push_action['room_id']),
'throttle_ms': self.get_room_throttle_ms(push_action['room_id']),
}
yield self.send_notification(unprocessed, reason)
yield self.save_last_stream_ordering_and_success(max([
ea['stream_ordering'] for ea in unprocessed
]))
# we update the throttle on all the possible unprocessed push actions
for ea in unprocessed:
yield self.sent_notif_update_throttle(
ea['room_id'], ea
)
break
else:
if soonest_due_at is None or should_notify_at < soonest_due_at:
soonest_due_at = should_notify_at
if self.timed_call is not None:
self.timed_call.cancel()
self.timed_call = None
if soonest_due_at is not None:
self.timed_call = reactor.callLater(
self.seconds_until(soonest_due_at), self.on_timer
)
@defer.inlineCallbacks
def save_last_stream_ordering_and_success(self, last_stream_ordering):
self.last_stream_ordering = last_stream_ordering
yield self.store.update_pusher_last_stream_ordering_and_success(
self.app_id, self.email, self.user_id,
last_stream_ordering, self.clock.time_msec()
)
def seconds_until(self, ts_msec):
return (ts_msec - self.clock.time_msec()) / 1000
def get_room_throttle_ms(self, room_id):
if room_id in self.throttle_params:
return self.throttle_params[room_id]["throttle_ms"]
else:
return 0
def get_room_last_sent_ts(self, room_id):
if room_id in self.throttle_params:
return self.throttle_params[room_id]["last_sent_ts"]
else:
return 0
def room_ready_to_notify_at(self, room_id):
"""
Determines whether throttling should prevent us from sending an email
for the given room
Returns: The timestamp when we are next allowed to send an email notif
for this room
"""
last_sent_ts = self.get_room_last_sent_ts(room_id)
throttle_ms = self.get_room_throttle_ms(room_id)
may_send_at = last_sent_ts + throttle_ms
return may_send_at
@defer.inlineCallbacks
def sent_notif_update_throttle(self, room_id, notified_push_action):
# We have sent a notification, so update the throttle accordingly.
# If the event that triggered the notif happened more than
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a
# notif, we release the throttle. Otherwise, the throttle is increased.
time_of_previous_notifs = yield self.store.get_time_of_last_push_action_before(
notified_push_action['stream_ordering']
)
time_of_this_notifs = notified_push_action['received_ts']
if time_of_previous_notifs is not None and time_of_this_notifs is not None:
gap = time_of_this_notifs - time_of_previous_notifs
else:
# if we don't know the arrival time of one of the notifs (it was not
# stored prior to email notification code) then assume a gap of
# zero which will just not reset the throttle
gap = 0
current_throttle_ms = self.get_room_throttle_ms(room_id)
if gap > THROTTLE_RESET_AFTER_MS:
new_throttle_ms = THROTTLE_START_MS
else:
if current_throttle_ms == 0:
new_throttle_ms = THROTTLE_START_MS
else:
new_throttle_ms = min(
current_throttle_ms * THROTTLE_MULTIPLIER,
THROTTLE_MAX_MS
)
self.throttle_params[room_id] = {
"last_sent_ts": self.clock.time_msec(),
"throttle_ms": new_throttle_ms
}
yield self.store.set_throttle_params(
self.pusher_id, room_id, self.throttle_params[room_id]
)
@defer.inlineCallbacks
def send_notification(self, push_actions, reason):
logger.info("Sending notif email for user %r", self.user_id)
yield self.mailer.send_notification_mail(
self.app_id, self.user_id, self.email, push_actions, reason
)

View File

@@ -1,507 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from twisted.mail.smtp import sendmail
import email.utils
import email.mime.multipart
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from synapse.util.async import concurrently_execute
from synapse.util.presentable_names import (
calculate_room_name, name_from_member_event, descriptor_from_member_events
)
from synapse.types import UserID
from synapse.api.errors import StoreError
from synapse.api.constants import EventTypes
from synapse.visibility import filter_events_for_client
import jinja2
import bleach
import time
import urllib
import logging
logger = logging.getLogger(__name__)
MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \
"in the %(room)s room..."
MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..."
MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..."
MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..."
MESSAGES_IN_ROOM_AND_OTHERS = \
"You have messages on %(app)s in the %(room)s room and others..."
MESSAGES_FROM_PERSON_AND_OTHERS = \
"You have messages on %(app)s from %(person)s and others..."
INVITE_FROM_PERSON_TO_ROOM = "%(person)s has invited you to join the " \
"%(room)s room on %(app)s..."
INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..."
CONTEXT_BEFORE = 1
CONTEXT_AFTER = 1
# From https://github.com/matrix-org/matrix-react-sdk/blob/master/src/HtmlUtils.js
ALLOWED_TAGS = [
'font', # custom to matrix for IRC-style font coloring
'del', # for markdown
# deliberately no h1/h2 to stop people shouting.
'h3', 'h4', 'h5', 'h6', 'blockquote', 'p', 'a', 'ul', 'ol',
'nl', 'li', 'b', 'i', 'u', 'strong', 'em', 'strike', 'code', 'hr', 'br', 'div',
'table', 'thead', 'caption', 'tbody', 'tr', 'th', 'td', 'pre'
]
ALLOWED_ATTRS = {
# custom ones first:
"font": ["color"], # custom to matrix
"a": ["href", "name", "target"], # remote target: custom to matrix
# We don't currently allow img itself by default, but this
# would make sense if we did
"img": ["src"],
}
# When bleach release a version with this option, we can specify schemes
# ALLOWED_SCHEMES = ["http", "https", "ftp", "mailto"]
class Mailer(object):
def __init__(self, hs, app_name):
self.hs = hs
self.store = self.hs.get_datastore()
self.auth_handler = self.hs.get_auth_handler()
self.state_handler = self.hs.get_state_handler()
loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
self.app_name = app_name
logger.info("Created Mailer for app_name %s" % app_name)
env = jinja2.Environment(loader=loader)
env.filters["format_ts"] = format_ts_filter
env.filters["mxc_to_http"] = self.mxc_to_http_filter
self.notif_template_html = env.get_template(
self.hs.config.email_notif_template_html
)
self.notif_template_text = env.get_template(
self.hs.config.email_notif_template_text
)
@defer.inlineCallbacks
def send_notification_mail(self, app_id, user_id, email_address,
push_actions, reason):
raw_from = email.utils.parseaddr(self.hs.config.email_notif_from)[1]
raw_to = email.utils.parseaddr(email_address)[1]
if raw_to == '':
raise RuntimeError("Invalid 'to' address")
rooms_in_order = deduped_ordered_list(
[pa['room_id'] for pa in push_actions]
)
notif_events = yield self.store.get_events(
[pa['event_id'] for pa in push_actions]
)
notifs_by_room = {}
for pa in push_actions:
notifs_by_room.setdefault(pa["room_id"], []).append(pa)
# collect the current state for all the rooms in which we have
# notifications
state_by_room = {}
try:
user_display_name = yield self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
if user_display_name is None:
user_display_name = user_id
except StoreError:
user_display_name = user_id
@defer.inlineCallbacks
def _fetch_room_state(room_id):
room_state = yield self.state_handler.get_current_state(room_id)
state_by_room[room_id] = room_state
# Run at most 3 of these at once: sync does 10 at a time but email
# notifs are much less realtime than sync so we can afford to wait a bit.
yield concurrently_execute(_fetch_room_state, rooms_in_order, 3)
# actually sort our so-called rooms_in_order list, most recent room first
rooms_in_order.sort(
key=lambda r: -(notifs_by_room[r][-1]['received_ts'] or 0)
)
rooms = []
for r in rooms_in_order:
roomvars = yield self.get_room_vars(
r, user_id, notifs_by_room[r], notif_events, state_by_room[r]
)
rooms.append(roomvars)
reason['room_name'] = calculate_room_name(
state_by_room[reason['room_id']], user_id, fallback_to_members=True
)
summary_text = self.make_summary_text(
notifs_by_room, state_by_room, notif_events, user_id, reason
)
template_vars = {
"user_display_name": user_display_name,
"unsubscribe_link": self.make_unsubscribe_link(
user_id, app_id, email_address
),
"summary_text": summary_text,
"app_name": self.app_name,
"rooms": rooms,
"reason": reason,
}
html_text = self.notif_template_html.render(**template_vars)
html_part = MIMEText(html_text, "html", "utf8")
plain_text = self.notif_template_text.render(**template_vars)
text_part = MIMEText(plain_text, "plain", "utf8")
multipart_msg = MIMEMultipart('alternative')
multipart_msg['Subject'] = "[%s] %s" % (self.app_name, summary_text)
multipart_msg['From'] = self.hs.config.email_notif_from
multipart_msg['To'] = email_address
multipart_msg['Date'] = email.utils.formatdate()
multipart_msg['Message-ID'] = email.utils.make_msgid()
multipart_msg.attach(text_part)
multipart_msg.attach(html_part)
logger.info("Sending email push notification to %s" % email_address)
# logger.debug(html_text)
yield sendmail(
self.hs.config.email_smtp_host,
raw_from, raw_to, multipart_msg.as_string(),
port=self.hs.config.email_smtp_port
)
@defer.inlineCallbacks
def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state):
my_member_event = room_state[("m.room.member", user_id)]
is_invite = my_member_event.content["membership"] == "invite"
room_vars = {
"title": calculate_room_name(room_state, user_id),
"hash": string_ordinal_total(room_id), # See sender avatar hash
"notifs": [],
"invite": is_invite,
"link": self.make_room_link(room_id),
}
if not is_invite:
for n in notifs:
notifvars = yield self.get_notif_vars(
n, user_id, notif_events[n['event_id']], room_state
)
# merge overlapping notifs together.
# relies on the notifs being in chronological order.
merge = False
if room_vars['notifs'] and 'messages' in room_vars['notifs'][-1]:
prev_messages = room_vars['notifs'][-1]['messages']
for message in notifvars['messages']:
pm = filter(lambda pm: pm['id'] == message['id'], prev_messages)
if pm:
if not message["is_historical"]:
pm[0]["is_historical"] = False
merge = True
elif merge:
# we're merging, so append any remaining messages
# in this notif to the previous one
prev_messages.append(message)
if not merge:
room_vars['notifs'].append(notifvars)
defer.returnValue(room_vars)
@defer.inlineCallbacks
def get_notif_vars(self, notif, user_id, notif_event, room_state):
results = yield self.store.get_events_around(
notif['room_id'], notif['event_id'],
before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER
)
ret = {
"link": self.make_notif_link(notif),
"ts": notif['received_ts'],
"messages": [],
}
the_events = yield filter_events_for_client(
self.store, user_id, results["events_before"]
)
the_events.append(notif_event)
for event in the_events:
messagevars = self.get_message_vars(notif, event, room_state)
if messagevars is not None:
ret['messages'].append(messagevars)
defer.returnValue(ret)
def get_message_vars(self, notif, event, room_state):
if event.type != EventTypes.Message:
return None
sender_state_event = room_state[("m.room.member", event.sender)]
sender_name = name_from_member_event(sender_state_event)
sender_avatar_url = None
if "avatar_url" in sender_state_event.content:
sender_avatar_url = sender_state_event.content["avatar_url"]
# 'hash' for deterministically picking default images: use
# sender_hash % the number of default images to choose from
sender_hash = string_ordinal_total(event.sender)
ret = {
"msgtype": event.content["msgtype"],
"is_historical": event.event_id != notif['event_id'],
"id": event.event_id,
"ts": event.origin_server_ts,
"sender_name": sender_name,
"sender_avatar_url": sender_avatar_url,
"sender_hash": sender_hash,
}
if event.content["msgtype"] == "m.text":
self.add_text_message_vars(ret, event)
elif event.content["msgtype"] == "m.image":
self.add_image_message_vars(ret, event)
if "body" in event.content:
ret["body_text_plain"] = event.content["body"]
return ret
def add_text_message_vars(self, messagevars, event):
if "format" in event.content:
msgformat = event.content["format"]
else:
msgformat = None
messagevars["format"] = msgformat
if msgformat == "org.matrix.custom.html":
messagevars["body_text_html"] = safe_markup(event.content["formatted_body"])
else:
messagevars["body_text_html"] = safe_text(event.content["body"])
return messagevars
def add_image_message_vars(self, messagevars, event):
messagevars["image_url"] = event.content["url"]
return messagevars
def make_summary_text(self, notifs_by_room, state_by_room,
notif_events, user_id, reason):
if len(notifs_by_room) == 1:
# Only one room has new stuff
room_id = notifs_by_room.keys()[0]
# If the room has some kind of name, use it, but we don't
# want the generated-from-names one here otherwise we'll
# end up with, "new message from Bob in the Bob room"
room_name = calculate_room_name(
state_by_room[room_id], user_id, fallback_to_members=False
)
my_member_event = state_by_room[room_id][("m.room.member", user_id)]
if my_member_event.content["membership"] == "invite":
inviter_member_event = state_by_room[room_id][
("m.room.member", my_member_event.sender)
]
inviter_name = name_from_member_event(inviter_member_event)
if room_name is None:
return INVITE_FROM_PERSON % {
"person": inviter_name,
"app": self.app_name
}
else:
return INVITE_FROM_PERSON_TO_ROOM % {
"person": inviter_name,
"room": room_name,
"app": self.app_name,
}
sender_name = None
if len(notifs_by_room[room_id]) == 1:
# There is just the one notification, so give some detail
event = notif_events[notifs_by_room[room_id][0]["event_id"]]
if ("m.room.member", event.sender) in state_by_room[room_id]:
state_event = state_by_room[room_id][("m.room.member", event.sender)]
sender_name = name_from_member_event(state_event)
if sender_name is not None and room_name is not None:
return MESSAGE_FROM_PERSON_IN_ROOM % {
"person": sender_name,
"room": room_name,
"app": self.app_name,
}
elif sender_name is not None:
return MESSAGE_FROM_PERSON % {
"person": sender_name,
"app": self.app_name,
}
else:
# There's more than one notification for this room, so just
# say there are several
if room_name is not None:
return MESSAGES_IN_ROOM % {
"room": room_name,
"app": self.app_name,
}
else:
# If the room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
sender_ids = list(set([
notif_events[n['event_id']].sender
for n in notifs_by_room[room_id]
]))
return MESSAGES_FROM_PERSON % {
"person": descriptor_from_member_events([
state_by_room[room_id][("m.room.member", s)]
for s in sender_ids
]),
"app": self.app_name,
}
else:
# Stuff's happened in multiple different rooms
# ...but we still refer to the 'reason' room which triggered the mail
if reason['room_name'] is not None:
return MESSAGES_IN_ROOM_AND_OTHERS % {
"room": reason['room_name'],
"app": self.app_name,
}
else:
# If the reason room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
sender_ids = list(set([
notif_events[n['event_id']].sender
for n in notifs_by_room[reason['room_id']]
]))
return MESSAGES_FROM_PERSON_AND_OTHERS % {
"person": descriptor_from_member_events([
state_by_room[reason['room_id']][("m.room.member", s)]
for s in sender_ids
]),
"app": self.app_name,
}
def make_room_link(self, room_id):
# need /beta for Universal Links to work on iOS
if self.app_name == "Vector":
return "https://vector.im/beta/#/room/%s" % (room_id,)
else:
return "https://matrix.to/#/%s" % (room_id,)
def make_notif_link(self, notif):
# need /beta for Universal Links to work on iOS
if self.app_name == "Vector":
return "https://vector.im/beta/#/room/%s/%s" % (
notif['room_id'], notif['event_id']
)
else:
return "https://matrix.to/#/%s/%s" % (
notif['room_id'], notif['event_id']
)
def make_unsubscribe_link(self, user_id, app_id, email_address):
params = {
"access_token": self.auth_handler.generate_delete_pusher_token(user_id),
"app_id": app_id,
"pushkey": email_address,
}
# XXX: make r0 once API is stable
return "%s_matrix/client/unstable/pushers/remove?%s" % (
self.hs.config.public_baseurl,
urllib.urlencode(params),
)
def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
if value[0:6] != "mxc://":
return ""
serverAndMediaId = value[6:]
fragment = None
if '#' in serverAndMediaId:
(serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
fragment = "#" + fragment
params = {
"width": width,
"height": height,
"method": resize_method,
}
return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
self.hs.config.public_baseurl,
serverAndMediaId,
urllib.urlencode(params),
fragment or "",
)
def safe_markup(raw_html):
return jinja2.Markup(bleach.linkify(bleach.clean(
raw_html, tags=ALLOWED_TAGS, attributes=ALLOWED_ATTRS,
# bleach master has this, but it isn't released yet
# protocols=ALLOWED_SCHEMES,
strip=True
)))
def safe_text(raw_text):
"""
Process text: treat it as HTML but escape any tags (ie. just escape the
HTML) then linkify it.
"""
return jinja2.Markup(bleach.linkify(bleach.clean(
raw_text, tags=[], attributes={},
strip=False
)))
def deduped_ordered_list(l):
seen = set()
ret = []
for item in l:
if item not in seen:
seen.add(item)
ret.append(item)
return ret
def string_ordinal_total(s):
tot = 0
for c in s:
tot += ord(c)
return tot
def format_ts_filter(value, format):
return time.strftime(format, time.localtime(value / 1000))

View File

@@ -38,9 +38,7 @@ def get_badge_count(store, user_id):
r.room_id, user_id, last_unread_event_id
)
)
# return one badge count per conversation, as count per
# message is so noisy as to be almost useless
badge += 1 if notifs["notify_count"] else 0
badge += notifs["notify_count"]
defer.returnValue(badge)

View File

@@ -1,47 +1,10 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from httppusher import HttpPusher
import logging
logger = logging.getLogger(__name__)
# We try importing this if we can (it will fail if we don't
# have the optional email dependencies installed). We don't
# yet have the config to know if we need the email pusher,
# but importing this after daemonizing seems to fail
# (even though a simple test of importing from a daemonized
# process works fine)
try:
from synapse.push.emailpusher import EmailPusher
except:
pass
PUSHER_TYPES = {
'http': HttpPusher
}
def create_pusher(hs, pusherdict):
logger.info("trying to create_pusher for %r", pusherdict)
PUSHER_TYPES = {
"http": HttpPusher,
}
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs:
PUSHER_TYPES["email"] = EmailPusher
logger.info("defined email pusher type")
if pusherdict['kind'] in PUSHER_TYPES:
logger.info("found pusher")
return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict)

View File

@@ -17,6 +17,7 @@
from twisted.internet import defer
import pusher
from synapse.push import PusherConfigException
from synapse.util.logcontext import preserve_fn
from synapse.util.async import run_on_reactor
@@ -49,7 +50,6 @@ class PusherPool:
# recreated, added and started: this means we have only one
# code path adding pushers.
pusher.create_pusher(self.hs, {
"id": None,
"user_name": user_id,
"kind": kind,
"app_id": app_id,
@@ -185,8 +185,8 @@ class PusherPool:
for pusherdict in pushers:
try:
p = pusher.create_pusher(self.hs, pusherdict)
except:
logger.exception("Couldn't start a pusher: caught Exception")
except PusherConfigException:
logger.exception("Couldn't start a pusher: caught PusherConfigException")
continue
if p:
appid_pushkey = "%s:%s" % (

View File

@@ -36,6 +36,7 @@ REQUIREMENTS = {
"blist": ["blist"],
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
"pymacaroons-pynacl": ["pymacaroons"],
"pyjwt": ["jwt"],
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {
@@ -44,10 +45,6 @@ CONDITIONAL_REQUIREMENTS = {
"preview_url": {
"netaddr>=0.7.18": ["netaddr"],
},
"email.enable_notifs": {
"Jinja2>=2.8": ["Jinja2>=2.8"],
"bleach>=1.4.2": ["bleach>=1.4.2"],
},
}

View File

@@ -1,59 +0,0 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import respond_with_json_bytes, request_handler
from synapse.http.servlet import parse_json_object_from_request
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
class PresenceResource(Resource):
"""
HTTP endpoint for marking users as syncing.
POST /_synapse/replication/presence HTTP/1.1
Content-Type: application/json
{
"process_id": "<process_id>",
"syncing_users": ["<user_id>"]
}
"""
def __init__(self, hs):
Resource.__init__(self) # Resource is old-style, so no super()
self.version_string = hs.version_string
self.clock = hs.get_clock()
self.presence_handler = hs.get_presence_handler()
def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET
@request_handler()
@defer.inlineCallbacks
def _async_render_POST(self, request):
content = parse_json_object_from_request(request)
process_id = content["process_id"]
syncing_user_ids = content["syncing_users"]
yield self.presence_handler.update_external_syncs(
process_id, set(syncing_user_ids)
)
respond_with_json_bytes(request, 200, "{}")

View File

@@ -31,13 +31,12 @@ class PusherResource(Resource):
self.version_string = hs.version_string
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.clock = hs.get_clock()
def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET
@request_handler()
@request_handler
@defer.inlineCallbacks
def _async_render_POST(self, request):
content = parse_json_object_from_request(request)

View File

@@ -16,7 +16,6 @@
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.server import request_handler, finish_request
from synapse.replication.pusher_resource import PusherResource
from synapse.replication.presence_resource import PresenceResource
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@@ -110,13 +109,11 @@ class ReplicationResource(Resource):
self.version_string = hs.version_string
self.store = hs.get_datastore()
self.sources = hs.get_event_sources()
self.presence_handler = hs.get_presence_handler()
self.typing_handler = hs.get_typing_handler()
self.presence_handler = hs.get_handlers().presence_handler
self.typing_handler = hs.get_handlers().typing_notification_handler
self.notifier = hs.notifier
self.clock = hs.get_clock()
self.putChild("remove_pushers", PusherResource(hs))
self.putChild("syncing_users", PresenceResource(hs))
def render_GET(self, request):
self._async_render_GET(request)
@@ -142,7 +139,7 @@ class ReplicationResource(Resource):
state_token,
))
@request_handler()
@request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
limit = parse_integer(request, "limit", 100)
@@ -161,15 +158,6 @@ class ReplicationResource(Resource):
result = yield self.notifier.wait_for_replication(replicate, timeout)
for stream_name, stream_content in result.items():
logger.info(
"Replicating %d rows of %s from %s -> %s",
len(stream_content["rows"]),
stream_name,
request_streams.get(stream_name),
stream_content["position"],
)
request.write(json.dumps(result, ensure_ascii=False))
finish_request(request)
@@ -394,7 +382,7 @@ class _Writer(object):
position = rows[-1][0]
self.streams[name] = {
"position": position if type(position) is int else str(position),
"position": str(position),
"field_names": fields,
"rows": rows,
}

View File

@@ -1,100 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.account_data import AccountDataStore
from synapse.storage.tags import TagsStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedAccountDataStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedAccountDataStore, self).__init__(db_conn, hs)
self._account_data_id_gen = SlavedIdTracker(
db_conn, "account_data_max_stream_id", "stream_id",
)
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache",
self._account_data_id_gen.get_current_token(),
)
get_account_data_for_user = (
AccountDataStore.__dict__["get_account_data_for_user"]
)
get_global_account_data_by_type_for_users = (
AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
)
get_global_account_data_by_type_for_user = (
AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
)
get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
get_updated_tags = DataStore.get_updated_tags.__func__
get_updated_account_data_for_user = (
DataStore.get_updated_account_data_for_user.__func__
)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedAccountDataStore, self).stream_positions()
position = self._account_data_id_gen.get_current_token()
result["user_account_data"] = position
result["room_account_data"] = position
result["tag_account_data"] = position
return result
def process_replication(self, result):
stream = result.get("user_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id, data_type = row[:3]
self.get_global_account_data_by_type_for_user.invalidate(
(data_type, user_id,)
)
self.get_account_data_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
stream = result.get("room_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.get_account_data_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
stream = result.get("tag_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.get_tags_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
return super(SlavedAccountDataStore, self).process_replication(result)

View File

@@ -1,30 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.config.appservice import load_appservices
class SlavedApplicationServiceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedApplicationServiceStore, self).__init__(db_conn, hs)
self.services_cache = load_appservices(
hs.config.server_name,
hs.config.app_service_config_files
)
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__

View File

@@ -23,7 +23,6 @@ from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore
from synapse.storage.state import StateStore
from synapse.storage.stream import StreamStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
import ujson as json
@@ -58,9 +57,6 @@ class SlavedEventStore(BaseSlavedStore):
"EventsRoomStreamChangeCache", min_event_val,
prefilled_cache=event_cache_prefill,
)
self._membership_stream_cache = StreamChangeCache(
"MembershipStreamChangeCache", events_max,
)
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
@@ -79,21 +75,6 @@ class SlavedEventStore(BaseSlavedStore):
get_unread_event_push_actions_by_room_for_user = (
EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
)
_get_state_group_for_events = (
StateStore.__dict__["_get_state_group_for_events"]
)
_get_state_group_for_event = (
StateStore.__dict__["_get_state_group_for_event"]
)
_get_state_groups_from_groups = (
StateStore.__dict__["_get_state_groups_from_groups"]
)
_get_state_group_from_group = (
StateStore.__dict__["_get_state_group_from_group"]
)
get_recent_event_ids_for_room = (
StreamStore.__dict__["get_recent_event_ids_for_room"]
)
get_unread_push_actions_for_user_in_range = (
DataStore.get_unread_push_actions_for_user_in_range.__func__
@@ -102,7 +83,6 @@ class SlavedEventStore(BaseSlavedStore):
DataStore.get_push_action_users_in_range.__func__
)
get_event = DataStore.get_event.__func__
get_events = DataStore.get_events.__func__
get_current_state = DataStore.get_current_state.__func__
get_current_state_for_key = DataStore.get_current_state_for_key.__func__
get_rooms_for_user_where_membership_is = (
@@ -115,17 +95,8 @@ class SlavedEventStore(BaseSlavedStore):
get_room_events_stream_for_room = (
DataStore.get_room_events_stream_for_room.__func__
)
get_events_around = DataStore.get_events_around.__func__
get_state_for_event = DataStore.get_state_for_event.__func__
get_state_for_events = DataStore.get_state_for_events.__func__
get_state_groups = DataStore.get_state_groups.__func__
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
get_room_events_stream_for_rooms = (
DataStore.get_room_events_stream_for_rooms.__func__
)
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
_set_before_and_after = DataStore._set_before_and_after
_get_events = DataStore._get_events.__func__
_get_events_from_cache = DataStore._get_events_from_cache.__func__
@@ -133,7 +104,6 @@ class SlavedEventStore(BaseSlavedStore):
_invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
_parse_events_txn = DataStore._parse_events_txn.__func__
_get_events_txn = DataStore._get_events_txn.__func__
_get_event_txn = DataStore._get_event_txn.__func__
_enqueue_events = DataStore._enqueue_events.__func__
_do_fetch = DataStore._do_fetch.__func__
_fetch_events_txn = DataStore._fetch_events_txn.__func__
@@ -144,15 +114,11 @@ class SlavedEventStore(BaseSlavedStore):
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
)
_get_members_rows_txn = DataStore._get_members_rows_txn.__func__
_get_state_for_groups = DataStore._get_state_for_groups.__func__
_get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
_get_events_around_txn = DataStore._get_events_around_txn.__func__
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
result["backfill"] = -self._backfill_id_gen.get_current_token()
result["backfill"] = self._backfill_id_gen.get_current_token()
return result
def process_replication(self, result):
@@ -162,7 +128,7 @@ class SlavedEventStore(BaseSlavedStore):
stream = result.get("events")
if stream:
self._stream_id_gen.advance(int(stream["position"]))
self._stream_id_gen.advance(stream["position"])
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=False, state_resets=state_resets
@@ -170,7 +136,7 @@ class SlavedEventStore(BaseSlavedStore):
stream = result.get("backfill")
if stream:
self._backfill_id_gen.advance(-int(stream["position"]))
self._backfill_id_gen.advance(stream["position"])
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=True, state_resets=state_resets
@@ -178,14 +144,12 @@ class SlavedEventStore(BaseSlavedStore):
stream = result.get("forward_ex_outliers")
if stream:
self._stream_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
event_id = row[1]
self._invalidate_get_event_cache(event_id)
stream = result.get("backward_ex_outliers")
if stream:
self._backfill_id_gen.advance(-int(stream["position"]))
for row in stream["rows"]:
event_id = row[1]
self._invalidate_get_event_cache(event_id)
@@ -233,9 +197,9 @@ class SlavedEventStore(BaseSlavedStore):
self.get_rooms_for_user.invalidate((event.state_key,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_users_in_room.invalidate((event.room_id,))
self._membership_stream_cache.entity_has_changed(
event.state_key, event.internal_metadata.stream_ordering
)
# self._membership_stream_cache.entity_has_changed(
# event.state_key, event.internal_metadata.stream_ordering
# )
self.get_invited_rooms_for_user.invalidate((event.state_key,))
if not event.is_state():

View File

@@ -1,24 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage.filtering import FilteringStore
class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedFilteringStore, self).__init__(db_conn, hs)
get_user_filter = FilteringStore.__dict__["get_user_filter"]

View File

@@ -1,59 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.storage import DataStore
class SlavedPresenceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedPresenceStore, self).__init__(db_conn, hs)
self._presence_id_gen = SlavedIdTracker(
db_conn, "presence_stream", "stream_id",
)
self._presence_on_startup = self._get_active_presence(db_conn)
self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
)
_get_active_presence = DataStore._get_active_presence.__func__
take_presence_startup_info = DataStore.take_presence_startup_info.__func__
get_presence_for_users = DataStore.get_presence_for_users.__func__
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedPresenceStore, self).stream_positions()
position = self._presence_id_gen.get_current_token()
result["presence"] = position
return result
def process_replication(self, result):
stream = result.get("presence")
if stream:
self._presence_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.presence_stream_cache.entity_has_changed(
user_id, position
)
return super(SlavedPresenceStore, self).process_replication(result)

View File

@@ -1,67 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .events import SlavedEventStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.push_rule import PushRuleStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedPushRuleStore(SlavedEventStore):
def __init__(self, db_conn, hs):
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id",
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache",
self._push_rules_stream_id_gen.get_current_token(),
)
get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
get_push_rules_enabled_for_user = (
PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
)
have_push_rules_changed_for_user = (
DataStore.have_push_rules_changed_for_user.__func__
)
def get_push_rules_stream_token(self):
return (
self._push_rules_stream_id_gen.get_current_token(),
self._stream_id_gen.get_current_token(),
)
def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("push_rules")
if stream:
for row in stream["rows"]:
position = row[0]
user_id = row[2]
self.get_push_rules_for_user.invalidate((user_id,))
self.get_push_rules_enabled_for_user.invalidate((user_id,))
self.push_rules_stream_cache.entity_has_changed(
user_id, position
)
self._push_rules_stream_id_gen.advance(int(stream["position"]))
return super(SlavedPushRuleStore, self).process_replication(result)

View File

@@ -43,10 +43,10 @@ class SlavedPusherStore(BaseSlavedStore):
def process_replication(self, result):
stream = result.get("pushers")
if stream:
self._pushers_id_gen.advance(int(stream["position"]))
self._pushers_id_gen.advance(stream["position"])
stream = result.get("deleted_pushers")
if stream:
self._pushers_id_gen.advance(int(stream["position"]))
self._pushers_id_gen.advance(stream["position"])
return super(SlavedPusherStore, self).process_replication(result)

View File

@@ -18,7 +18,6 @@ from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.receipts import ReceiptsStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
@@ -38,28 +37,11 @@ class SlavedReceiptsStore(BaseSlavedStore):
db_conn, "receipts_linearized", "stream_id"
)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
)
get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
get_linearized_receipts_for_room = (
ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
)
_get_linearized_receipts_for_rooms = (
ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
)
get_last_receipt_event_id_for_user = (
ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
)
get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
get_linearized_receipts_for_rooms = (
DataStore.get_linearized_receipts_for_rooms.__func__
)
def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
result["receipts"] = self._receipts_id_gen.get_current_token()
@@ -68,17 +50,12 @@ class SlavedReceiptsStore(BaseSlavedStore):
def process_replication(self, result):
stream = result.get("receipts")
if stream:
self._receipts_id_gen.advance(int(stream["position"]))
self._receipts_id_gen.advance(stream["position"])
for row in stream["rows"]:
position, room_id, receipt_type, user_id = row[:4]
room_id, receipt_type, user_id = row[1:4]
self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
self._receipts_stream_cache.entity_has_changed(room_id, position)
return super(SlavedReceiptsStore, self).process_replication(result)
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self.get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)

View File

@@ -1,30 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.registration import RegistrationStore
class SlavedRegistrationStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedRegistrationStore, self).__init__(db_conn, hs)
# TODO: use the cached version and invalidate deleted tokens
get_user_by_access_token = RegistrationStore.__dict__[
"get_user_by_access_token"
].orig
_query_for_auth = DataStore._query_for_auth.__func__

View File

@@ -44,8 +44,6 @@ from synapse.rest.client.v2_alpha import (
tokenrefresh,
tags,
account_data,
report_event,
openid,
)
from synapse.http.server import JsonResource
@@ -88,5 +86,3 @@ class ClientRestResource(JsonResource):
tokenrefresh.register_servlets(hs, client_resource)
tags.register_servlets(hs, client_resource)
account_data.register_servlets(hs, client_resource)
report_event.register_servlets(hs, client_resource)
openid.register_servlets(hs, client_resource)

View File

@@ -33,6 +33,9 @@ from saml2.client import Saml2Client
import xml.etree.ElementTree as ET
import jwt
from jwt.exceptions import InvalidTokenError
logger = logging.getLogger(__name__)
@@ -58,7 +61,6 @@ class LoginRestServlet(ClientV1RestServlet):
self.cas_required_attributes = hs.config.cas_required_attributes
self.servername = hs.config.server_name
self.http_client = hs.get_simple_http_client()
self.auth_handler = self.hs.get_auth_handler()
def on_GET(self, request):
flows = []
@@ -144,7 +146,7 @@ class LoginRestServlet(ClientV1RestServlet):
user_id, self.hs.hostname
).to_string()
auth_handler = self.auth_handler
auth_handler = self.handlers.auth_handler
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
user_id=user_id,
password=login_submission["password"])
@@ -161,7 +163,7 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def do_token_login(self, login_submission):
token = login_submission['token']
auth_handler = self.auth_handler
auth_handler = self.handlers.auth_handler
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
@@ -195,7 +197,7 @@ class LoginRestServlet(ClientV1RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
@@ -222,29 +224,21 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def do_jwt_login(self, login_submission):
token = login_submission.get("token", None)
token = login_submission['token']
if token is None:
raise LoginError(
401, "Token field for JWT is missing",
errcode=Codes.UNAUTHORIZED
)
import jwt
from jwt.exceptions import InvalidTokenError
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
try:
payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
except jwt.ExpiredSignatureError:
raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
except InvalidTokenError:
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user = payload.get("sub", None)
user = payload['user']
if user is None:
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
@@ -413,7 +407,7 @@ class CasTicketServlet(ClientV1RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if not user_exists:
user_id, _ = (

View File

@@ -30,24 +30,20 @@ logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
def __init__(self, hs):
super(PresenceStatusRestServlet, self).__init__(hs)
self.presence_handler = hs.get_presence_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if requester.user != user:
allowed = yield self.presence_handler.is_visible(
allowed = yield self.handlers.presence_handler.is_visible(
observed_user=user, observer_user=requester.user,
)
if not allowed:
raise AuthError(403, "You are not allowed to see their presence.")
state = yield self.presence_handler.get_state(target_user=user)
state = yield self.handlers.presence_handler.get_state(target_user=user)
defer.returnValue((200, state))
@@ -78,7 +74,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
except:
raise SynapseError(400, "Unable to parse state")
yield self.presence_handler.set_state(user, state)
yield self.handlers.presence_handler.set_state(user, state)
defer.returnValue((200, {}))
@@ -89,10 +85,6 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
class PresenceListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/presence/list/(?P<user_id>[^/]*)")
def __init__(self, hs):
super(PresenceListRestServlet, self).__init__(hs)
self.presence_handler = hs.get_presence_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
requester = yield self.auth.get_user_by_req(request)
@@ -104,7 +96,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if requester.user != user:
raise SynapseError(400, "Cannot get another user's presence list")
presence = yield self.presence_handler.get_presence_list(
presence = yield self.handlers.presence_handler.get_presence_list(
observer_user=user, accepted=True
)
@@ -131,7 +123,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if len(u) == 0:
continue
invited_user = UserID.from_string(u)
yield self.presence_handler.send_presence_invite(
yield self.handlers.presence_handler.send_presence_invite(
observer_user=user, observed_user=invited_user
)
@@ -142,7 +134,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if len(u) == 0:
continue
dropped_user = UserID.from_string(u)
yield self.presence_handler.drop(
yield self.handlers.presence_handler.drop(
observer_user=user, observed_user=dropped_user
)

View File

@@ -17,11 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes
from synapse.push import PusherConfigException
from synapse.http.servlet import (
parse_json_object_from_request, parse_string, RestServlet
)
from synapse.http.server import finish_request
from synapse.api.errors import StoreError
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns
@@ -140,57 +136,6 @@ class PushersSetRestServlet(ClientV1RestServlet):
return 200, {}
class PushersRemoveRestServlet(RestServlet):
"""
To allow pusher to be delete by clicking a link (ie. GET request)
"""
PATTERNS = client_path_patterns("/pushers/remove$")
SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
super(RestServlet, self).__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_v1auth()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
user = requester.user
app_id = parse_string(request, "app_id", required=True)
pushkey = parse_string(request, "pushkey", required=True)
pusher_pool = self.hs.get_pusherpool()
try:
yield pusher_pool.remove_pusher(
app_id=app_id,
pushkey=pushkey,
user_id=user.to_string(),
)
except StoreError as se:
if se.code != 404:
# This is fine: they're already unsubscribed
raise
self.notifier.on_new_replication_data()
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (
len(PushersRemoveRestServlet.SUCCESS_HTML),
))
request.write(PushersRemoveRestServlet.SUCCESS_HTML)
finish_request(request)
defer.returnValue(None)
def on_OPTIONS(self, _):
return 200, {}
def register_servlets(hs, http_server):
PushersRestServlet(hs).register(http_server)
PushersSetRestServlet(hs).register(http_server)
PushersRemoveRestServlet(hs).register(http_server)

View File

@@ -355,76 +355,5 @@ class RegisterRestServlet(ClientV1RestServlet):
)
class CreateUserRestServlet(ClientV1RestServlet):
"""Handles user creation via a server-to-server interface
"""
PATTERNS = client_path_patterns("/createUser$", releases=())
def __init__(self, hs):
super(CreateUserRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
@defer.inlineCallbacks
def on_POST(self, request):
user_json = parse_json_object_from_request(request)
if "access_token" not in request.args:
raise SynapseError(400, "Expected application service token.")
app_service = yield self.store.get_app_service_by_token(
request.args["access_token"][0]
)
if not app_service:
raise SynapseError(403, "Invalid application service token.")
logger.debug("creating user: %s", user_json)
response = yield self._do_create(user_json)
defer.returnValue((200, response))
def on_OPTIONS(self, request):
return 403, {}
@defer.inlineCallbacks
def _do_create(self, user_json):
yield run_on_reactor()
if "localpart" not in user_json:
raise SynapseError(400, "Expected 'localpart' key.")
if "displayname" not in user_json:
raise SynapseError(400, "Expected 'displayname' key.")
if "duration_seconds" not in user_json:
raise SynapseError(400, "Expected 'duration_seconds' key.")
localpart = user_json["localpart"].encode("utf-8")
displayname = user_json["displayname"].encode("utf-8")
duration_seconds = 0
try:
duration_seconds = int(user_json["duration_seconds"])
except ValueError:
raise SynapseError(400, "Failed to parse 'duration_seconds'")
if duration_seconds > self.direct_user_creation_max_duration:
duration_seconds = self.direct_user_creation_max_duration
handler = self.handlers.registration_handler
user_id, token = yield handler.get_or_create_user(
localpart=localpart,
displayname=displayname,
duration_seconds=duration_seconds
)
defer.returnValue({
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
})
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)
CreateUserRestServlet(hs).register(http_server)

View File

@@ -232,10 +232,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
if RoomID.is_valid(room_identifier):
room_id = room_identifier
try:
remote_room_hosts = request.args["server_name"]
except:
remote_room_hosts = None
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
handler = self.handlers.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
@@ -279,9 +276,8 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
handler = self.hs.get_room_list_handler()
data = yield handler.get_aggregated_public_room_list()
handler = self.handlers.room_list_handler
data = yield handler.get_public_room_list()
defer.returnValue((200, data))
@@ -574,8 +570,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomTypingRestServlet, self).__init__(hs)
self.presence_handler = hs.get_presence_handler()
self.typing_handler = hs.get_typing_handler()
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id):
@@ -586,17 +581,19 @@ class RoomTypingRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request)
typing_handler = self.handlers.typing_notification_handler
yield self.presence_handler.bump_presence_active_time(requester.user)
if content["typing"]:
yield self.typing_handler.started_typing(
yield typing_handler.started_typing(
target_user=target_user,
auth_user=requester.user,
room_id=room_id,
timeout=content.get("timeout", 30000),
)
else:
yield self.typing_handler.stopped_typing(
yield typing_handler.stopped_typing(
target_user=target_user,
auth_user=requester.user,
room_id=room_id,

View File

@@ -35,7 +35,7 @@ class PasswordRestServlet(RestServlet):
super(PasswordRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.auth_handler = hs.get_handlers().auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
@@ -52,7 +52,6 @@ class PasswordRestServlet(RestServlet):
defer.returnValue((401, result))
user_id = None
requester = None
if LoginType.PASSWORD in result:
# if using password, they should also be logged in
@@ -97,7 +96,7 @@ class ThreepidRestServlet(RestServlet):
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.auth_handler = hs.get_handlers().auth_handler
@defer.inlineCallbacks
def on_GET(self, request):

View File

@@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet):
super(AuthRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.auth_handler = hs.get_handlers().auth_handler
self.registration_handler = hs.get_handlers().registration_handler
@defer.inlineCallbacks

View File

@@ -1,96 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import client_v2_patterns
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.api.errors import AuthError
from synapse.util.stringutils import random_string
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
class IdTokenServlet(RestServlet):
"""
Get a bearer token that may be passed to a third party to confirm ownership
of a matrix user id.
The format of the response could be made compatible with the format given
in http://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
But instead of returning a signed "id_token" the response contains the
name of the issuing matrix homeserver. This means that for now the third
party will need to check the validity of the "id_token" against the
federation /openid/userinfo endpoint of the homeserver.
Request:
POST /user/{user_id}/openid/request_token?access_token=... HTTP/1.1
{}
Response:
HTTP/1.1 200 OK
{
"access_token": "ABDEFGH",
"token_type": "Bearer",
"matrix_server_name": "example.com",
"expires_in": 3600,
}
"""
PATTERNS = client_v2_patterns(
"/user/(?P<user_id>[^/]*)/openid/request_token"
)
EXPIRES_MS = 3600 * 1000
def __init__(self, hs):
super(IdTokenServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
@defer.inlineCallbacks
def on_POST(self, request, user_id):
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot request tokens for other users.")
# Parse the request body to make sure it's JSON, but ignore the contents
# for now.
parse_json_object_from_request(request)
token = random_string(24)
ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS
yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
defer.returnValue((200, {
"access_token": token,
"token_type": "Bearer",
"matrix_server_name": self.server_name,
"expires_in": self.EXPIRES_MS / 1000,
}))
def register_servlets(hs, http_server):
IdTokenServlet(hs).register(http_server)

View File

@@ -37,7 +37,7 @@ class ReceiptRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.receipts_handler = hs.get_handlers().receipts_handler
self.presence_handler = hs.get_presence_handler()
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id):

View File

@@ -48,8 +48,7 @@ class RegisterRestServlet(RestServlet):
super(RegisterRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.auth_handler = hs.get_auth_handler()
self.auth_handler = hs.get_handlers().auth_handler
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
@@ -215,34 +214,6 @@ class RegisterRestServlet(RestServlet):
threepid['validated_at'],
)
# And we add an email pusher for them by default, but only
# if email notifications are enabled (so people don't start
# getting mail spam where they weren't before if email
# notifs are set up on a home server)
if (
self.hs.config.email_enable_notifs and
self.hs.config.email_notif_for_new_users
):
# Pull the ID of the access token back out of the db
# It would really make more sense for this to be passed
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = yield self.store.get_user_by_access_token(
token
)
yield self.hs.get_pusherpool().add_pusher(
user_id=user_id,
access_token=user_tuple["token_id"],
kind="email",
app_id="m.email",
app_display_name="Email Notifications",
device_display_name=threepid["address"],
pushkey=threepid["address"],
lang=None, # We don't know a user's language here
data={},
)
if 'bind_email' in params and params['bind_email']:
logger.info("bind_email specified: binding")

View File

@@ -1,59 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns
import logging
logger = logging.getLogger(__name__)
class ReportEventRestServlet(RestServlet):
PATTERNS = client_v2_patterns(
"/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$"
)
def __init__(self, hs):
super(ReportEventRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
yield self.store.add_event_report(
room_id=room_id,
event_id=event_id,
user_id=user_id,
reason=body.get("reason"),
content=body,
received_ts=self.clock.time_msec(),
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
ReportEventRestServlet(hs).register(http_server)

View File

@@ -79,10 +79,11 @@ class SyncRestServlet(RestServlet):
def __init__(self, hs):
super(SyncRestServlet, self).__init__()
self.auth = hs.get_auth()
self.sync_handler = hs.get_sync_handler()
self.event_stream_handler = hs.get_handlers().event_stream_handler
self.sync_handler = hs.get_handlers().sync_handler
self.clock = hs.get_clock()
self.filtering = hs.get_filtering()
self.presence_handler = hs.get_presence_handler()
self.presence_handler = hs.get_handlers().presence_handler
@defer.inlineCallbacks
def on_GET(self, request):

View File

@@ -38,7 +38,7 @@ class TokenRefreshRestServlet(RestServlet):
body = parse_json_object_from_request(request)
try:
old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_auth_handler()
auth_handler = self.hs.get_handlers().auth_handler
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
old_refresh_token, auth_handler.generate_refresh_token)
new_access_token = yield auth_handler.issue_access_token(user_id)

View File

@@ -49,6 +49,7 @@ class LocalKey(Resource):
"""
def __init__(self, hs):
self.hs = hs
self.version_string = hs.version_string
self.response_body = encode_canonical_json(
self.response_json_object(hs.config)

View File

@@ -97,7 +97,7 @@ class RemoteKey(Resource):
self.async_render_GET(request)
return NOT_DONE_YET
@request_handler()
@request_handler
@defer.inlineCallbacks
def async_render_GET(self, request):
if len(request.postpath) == 1:
@@ -122,7 +122,7 @@ class RemoteKey(Resource):
self.async_render_POST(request)
return NOT_DONE_YET
@request_handler()
@request_handler
@defer.inlineCallbacks
def async_render_POST(self, request):
content = parse_json_object_from_request(request)

View File

@@ -36,13 +36,12 @@ class DownloadResource(Resource):
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.version_string = hs.version_string
self.clock = hs.get_clock()
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@request_handler()
@request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
server_name, media_id, name = parse_media_id(request)

View File

@@ -56,7 +56,8 @@ class PreviewUrlResource(Resource):
self.client = SpiderHttpClient(hs)
self.media_repo = media_repo
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
if hasattr(hs.config, "url_preview_url_blacklist"):
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
# simple memory cache mapping urls to OG metadata
self.cache = ExpiringCache(
@@ -73,7 +74,7 @@ class PreviewUrlResource(Resource):
self._async_render_GET(request)
return NOT_DONE_YET
@request_handler()
@request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
@@ -85,37 +86,39 @@ class PreviewUrlResource(Resource):
else:
ts = self.clock.time_msec()
url_tuple = urlparse.urlsplit(url)
for entry in self.url_preview_url_blacklist:
match = True
for attrib in entry:
pattern = entry[attrib]
value = getattr(url_tuple, attrib)
logger.debug((
"Matching attrib '%s' with value '%s' against"
" pattern '%s'"
) % (attrib, value, pattern))
# impose the URL pattern blacklist
if hasattr(self, "url_preview_url_blacklist"):
url_tuple = urlparse.urlsplit(url)
for entry in self.url_preview_url_blacklist:
match = True
for attrib in entry:
pattern = entry[attrib]
value = getattr(url_tuple, attrib)
logger.debug((
"Matching attrib '%s' with value '%s' against"
" pattern '%s'"
) % (attrib, value, pattern))
if value is None:
match = False
continue
if pattern.startswith('^'):
if not re.match(pattern, getattr(url_tuple, attrib)):
if value is None:
match = False
continue
else:
if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern):
match = False
continue
if match:
logger.warn(
"URL %s blocked by url_blacklist entry %s", url, entry
)
raise SynapseError(
403, "URL blocked by url pattern blacklist entry",
Codes.UNKNOWN
)
if pattern.startswith('^'):
if not re.match(pattern, getattr(url_tuple, attrib)):
match = False
continue
else:
if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern):
match = False
continue
if match:
logger.warn(
"URL %s blocked by url_blacklist entry %s", url, entry
)
raise SynapseError(
403, "URL blocked by url pattern blacklist entry",
Codes.UNKNOWN
)
# first check the memory cache - good to handle all the clients on this
# HS thundering away to preview the same URL at the same time.

View File

@@ -39,13 +39,12 @@ class ThumbnailResource(Resource):
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
self.version_string = hs.version_string
self.clock = hs.get_clock()
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@request_handler()
@request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
server_name, media_id, _ = parse_media_id(request)

View File

@@ -41,7 +41,6 @@ class UploadResource(Resource):
self.auth = hs.get_auth()
self.max_upload_size = hs.config.max_upload_size
self.version_string = hs.version_string
self.clock = hs.get_clock()
def render_POST(self, request):
self._async_render_POST(request)
@@ -51,7 +50,7 @@ class UploadResource(Resource):
respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET
@request_handler()
@request_handler
@defer.inlineCallbacks
def _async_render_POST(self, request):
requester = yield self.auth.get_user_by_req(request)

View File

@@ -22,19 +22,11 @@
from twisted.web.client import BrowserLikePolicyForHTTPS
from twisted.enterprise import adbapi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.appservice.api import ApplicationServiceApi
from synapse.federation import initialize_http_replication
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.notifier import Notifier
from synapse.api.auth import Auth
from synapse.handlers import Handlers
from synapse.handlers.presence import PresenceHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.room import RoomListHandler
from synapse.handlers.auth import AuthHandler
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.state import StateHandler
from synapse.storage import DataStore
from synapse.util import Clock
@@ -86,14 +78,6 @@ class HomeServer(object):
'auth',
'rest_servlet_factory',
'state_handler',
'presence_handler',
'sync_handler',
'typing_handler',
'room_list_handler',
'auth_handler',
'application_service_api',
'application_service_scheduler',
'application_service_handler',
'notifier',
'distributor',
'client_resource',
@@ -180,30 +164,6 @@ class HomeServer(object):
def build_state_handler(self):
return StateHandler(self)
def build_presence_handler(self):
return PresenceHandler(self)
def build_typing_handler(self):
return TypingHandler(self)
def build_sync_handler(self):
return SyncHandler(self)
def build_room_list_handler(self):
return RoomListHandler(self)
def build_auth_handler(self):
return AuthHandler(self)
def build_application_service_api(self):
return ApplicationServiceApi(self)
def build_application_service_scheduler(self):
return ApplicationServiceScheduler(self)
def build_application_service_handler(self):
return ApplicationServicesHandler(self)
def build_event_sources(self):
return EventSources(self)

View File

@@ -17,7 +17,7 @@ from twisted.internet import defer
from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore
)
from ._base import Cache, LoggingTransaction
from ._base import Cache
from .directory import DirectoryStore
from .events import EventsStore
from .presence import PresenceStore, UserPresenceState
@@ -44,7 +44,6 @@ from .receipts import ReceiptsStore
from .search import SearchStore
from .tags import TagsStore
from .account_data import AccountDataStore
from .openid import OpenIdStore
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
@@ -82,13 +81,11 @@ class DataStore(RoomMemberStore, RoomStore,
SearchStore,
TagsStore,
AccountDataStore,
EventPushActionsStore,
OpenIdStore,
EventPushActionsStore
):
def __init__(self, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = hs.database_engine
self.client_ip_last_seen = Cache(
@@ -117,7 +114,6 @@ class DataStore(RoomMemberStore, RoomStore,
self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._push_rules_stream_id_gen = ChainedIdGenerator(
@@ -149,7 +145,7 @@ class DataStore(RoomMemberStore, RoomStore,
"AccountDataAndTagsChangeCache", account_max,
)
self._presence_on_startup = self._get_active_presence(db_conn)
self.__presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(
db_conn, "presence_stream",
@@ -174,24 +170,11 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=push_rules_prefill,
)
cur = LoggingTransaction(
db_conn.cursor(),
name="_find_stream_orderings_for_times_txn",
database_engine=self.database_engine,
after_callbacks=[]
)
self._find_stream_orderings_for_times_txn(cur)
cur.close()
self.find_stream_orderings_looping_call = self._clock.looping_call(
self._find_stream_orderings_for_times, 60 * 60 * 1000
)
super(DataStore, self).__init__(hs)
def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
active_on_startup = self.__presence_on_startup
self.__presence_on_startup = None
return active_on_startup
def _get_active_presence(self, db_conn):

View File

@@ -152,8 +152,8 @@ class SQLBaseStore(object):
def __init__(self, hs):
self.hs = hs
self._clock = hs.get_clock()
self._db_pool = hs.get_db_pool()
self._clock = hs.get_clock()
self._previous_txn_total_time = 0
self._current_txn_total_time = 0
@@ -453,9 +453,7 @@ class SQLBaseStore(object):
keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values
insertion_values (dict): key/values to use when inserting
Returns:
Deferred(bool): True if a new entry was created, False if an
existing one was updated.
Returns: A deferred
"""
return self.runInteraction(
desc,
@@ -500,10 +498,6 @@ class SQLBaseStore(object):
)
txn.execute(sql, allvalues.values())
return True
else:
return False
def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to

View File

@@ -16,8 +16,6 @@
from ._base import SQLBaseStore
from twisted.internet import defer
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
import ujson as json
import logging
@@ -26,7 +24,6 @@ logger = logging.getLogger(__name__)
class AccountDataStore(SQLBaseStore):
@cached()
def get_account_data_for_user(self, user_id):
"""Get all the client account_data for a user.
@@ -63,47 +60,6 @@ class AccountDataStore(SQLBaseStore):
"get_account_data_for_user", get_account_data_for_user_txn
)
@cachedInlineCallbacks(num_args=2)
def get_global_account_data_by_type_for_user(self, data_type, user_id):
"""
Returns:
Deferred: A dict
"""
result = yield self._simple_select_one_onecol(
table="account_data",
keyvalues={
"user_id": user_id,
"account_data_type": data_type,
},
retcol="content",
desc="get_global_account_data_by_type_for_user",
allow_none=True,
)
if result:
defer.returnValue(json.loads(result))
else:
defer.returnValue(None)
@cachedList(cached_method_name="get_global_account_data_by_type_for_user",
num_args=2, list_name="user_ids", inlineCallbacks=True)
def get_global_account_data_by_type_for_users(self, data_type, user_ids):
rows = yield self._simple_select_many_batch(
table="account_data",
column="user_id",
iterable=user_ids,
keyvalues={
"account_data_type": data_type,
},
retcols=("user_id", "content",),
desc="get_global_account_data_by_type_for_users",
)
defer.returnValue({
row["user_id"]: json.loads(row["content"]) if row["content"] else None
for row in rows
})
def get_account_data_for_room(self, user_id, room_id):
"""Get all the client account_data for a user for a room.
@@ -237,7 +193,6 @@ class AccountDataStore(SQLBaseStore):
self._account_data_stream_cache.entity_has_changed,
user_id, next_id,
)
txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
self._update_max_stream_id(txn, next_id)
with self._account_data_id_gen.get_next() as next_id:
@@ -277,11 +232,6 @@ class AccountDataStore(SQLBaseStore):
self._account_data_stream_cache.entity_has_changed,
user_id, next_id,
)
txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
txn.call_after(
self.get_global_account_data_by_type_for_user.invalidate,
(account_data_type, user_id,)
)
self._update_max_stream_id(txn, next_id)
with self._account_data_id_gen.get_next() as next_id:

View File

@@ -13,13 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
import yaml
import simplejson as json
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.appservice import ApplicationService, AppServiceTransaction
from synapse.config._base import ConfigError
from synapse.storage.roommember import RoomsForUser
from synapse.types import UserID
from ._base import SQLBaseStore
@@ -31,7 +34,7 @@ class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs):
super(ApplicationServiceStore, self).__init__(hs)
self.hostname = hs.hostname
self.services_cache = load_appservices(
self.services_cache = ApplicationServiceStore.load_appservices(
hs.hostname,
hs.config.app_service_config_files
)
@@ -141,6 +144,102 @@ class ApplicationServiceStore(SQLBaseStore):
return rooms_for_user_matching_user_id
@classmethod
def _load_appservice(cls, hostname, as_info, config_filename):
required_string_fields = [
"id", "url", "as_token", "hs_token", "sender_localpart"
]
for field in required_string_fields:
if not isinstance(as_info.get(field), basestring):
raise KeyError("Required string field: '%s' (%s)" % (
field, config_filename,
))
localpart = as_info["sender_localpart"]
if urllib.quote(localpart) != localpart:
raise ValueError(
"sender_localpart needs characters which are not URL encoded."
)
user = UserID(localpart, hostname)
user_id = user.to_string()
# namespace checks
if not isinstance(as_info.get("namespaces"), dict):
raise KeyError("Requires 'namespaces' object.")
for ns in ApplicationService.NS_LIST:
# specific namespaces are optional
if ns in as_info["namespaces"]:
# expect a list of dicts with exclusive and regex keys
for regex_obj in as_info["namespaces"][ns]:
if not isinstance(regex_obj, dict):
raise ValueError(
"Expected namespace entry in %s to be an object,"
" but got %s", ns, regex_obj
)
if not isinstance(regex_obj.get("regex"), basestring):
raise ValueError(
"Missing/bad type 'regex' key in %s", regex_obj
)
if not isinstance(regex_obj.get("exclusive"), bool):
raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj
)
return ApplicationService(
token=as_info["as_token"],
url=as_info["url"],
namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["id"],
)
@classmethod
def load_appservices(cls, hostname, config_files):
"""Returns a list of Application Services from the config files."""
if not isinstance(config_files, list):
logger.warning(
"Expected %s to be a list of AS config files.", config_files
)
return []
# Dicts of value -> filename
seen_as_tokens = {}
seen_ids = {}
appservices = []
for config_file in config_files:
try:
with open(config_file, 'r') as f:
appservice = ApplicationServiceStore._load_appservice(
hostname, yaml.load(f), config_file
)
if appservice.id in seen_ids:
raise ConfigError(
"Cannot reuse ID across application services: "
"%s (files: %s, %s)" % (
appservice.id, config_file, seen_ids[appservice.id],
)
)
seen_ids[appservice.id] = config_file
if appservice.token in seen_as_tokens:
raise ConfigError(
"Cannot reuse as_token across application services: "
"%s (files: %s, %s)" % (
appservice.token,
config_file,
seen_as_tokens[appservice.token],
)
)
seen_as_tokens[appservice.token] = config_file
logger.info("Loaded application service: %s", appservice)
appservices.append(appservice)
except Exception as e:
logger.error("Failed to load appservice from '%s'", config_file)
logger.exception(e)
raise
return appservices
class ApplicationServiceTransactionStore(SQLBaseStore):

View File

@@ -173,12 +173,11 @@ class BackgroundUpdateStore(SQLBaseStore):
logger.info(
"Updating %r. Updated %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r)",
update_name, items_updated, duration_ms,
performance.total_items_per_ms(),
performance.average_items_per_ms(),
performance.total_item_count,
batch_size,
)
performance.update(items_updated, duration_ms)

View File

@@ -24,10 +24,6 @@ logger = logging.getLogger(__name__)
class EventPushActionsStore(SQLBaseStore):
def __init__(self, hs):
self.stream_ordering_month_ago = None
super(EventPushActionsStore, self).__init__(hs)
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
"""
Args:
@@ -119,23 +115,19 @@ class EventPushActionsStore(SQLBaseStore):
@defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range(self, user_id,
min_stream_ordering,
max_stream_ordering=None,
limit=20):
max_stream_ordering=None):
def get_after_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, "
"e.received_ts "
"FROM ("
" SELECT room_id, user_id, "
" max(topological_ordering) as topological_ordering, "
" max(stream_ordering) as stream_ordering "
"SELECT ep.event_id, ep.stream_ordering, ep.actions "
"FROM event_push_actions AS ep, ("
" SELECT room_id, user_id,"
" max(topological_ordering) as topological_ordering,"
" max(stream_ordering) as stream_ordering"
" FROM events"
" NATURAL JOIN receipts_linearized WHERE receipt_type = 'm.read'"
" GROUP BY room_id, user_id"
") AS rl,"
" event_push_actions AS ep"
" INNER JOIN events AS e USING (room_id, event_id)"
" WHERE"
") AS rl "
"WHERE"
" ep.room_id = rl.room_id"
" AND ("
" ep.topological_ordering > rl.topological_ordering"
@@ -152,8 +144,7 @@ class EventPushActionsStore(SQLBaseStore):
if max_stream_ordering is not None:
sql += " AND ep.stream_ordering <= ?"
args.append(max_stream_ordering)
sql += " ORDER BY ep.stream_ordering ASC LIMIT ?"
args.append(limit)
sql += " ORDER BY ep.stream_ordering ASC"
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.runInteraction(
@@ -162,13 +153,11 @@ class EventPushActionsStore(SQLBaseStore):
def get_no_receipt(txn):
sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" e.received_ts"
" FROM event_push_actions AS ep"
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
" WHERE ep.room_id not in ("
"SELECT ep.event_id, ep.stream_ordering, ep.actions "
"FROM event_push_actions AS ep "
"WHERE ep.room_id not in ("
" SELECT room_id FROM events NATURAL JOIN receipts_linearized"
" WHERE receipt_type = 'm.read' AND user_id = ?"
" WHERE receipt_type = 'm.read' AND user_id = ? "
" GROUP BY room_id"
") AND ep.user_id = ? AND ep.stream_ordering > ?"
)
@@ -186,29 +175,11 @@ class EventPushActionsStore(SQLBaseStore):
defer.returnValue([
{
"event_id": row[0],
"room_id": row[1],
"stream_ordering": row[2],
"actions": json.loads(row[3]),
"received_ts": row[4],
"stream_ordering": row[1],
"actions": json.loads(row[2]),
} for row in after_read_receipt + no_read_receipt
])
@defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering):
def f(txn):
sql = (
"SELECT e.received_ts"
" FROM event_push_actions AS ep"
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
" WHERE ep.stream_ordering > ?"
" ORDER BY ep.stream_ordering ASC"
" LIMIT 1"
)
txn.execute(sql, (stream_ordering,))
return txn.fetchone()
result = yield self.runInteraction("get_time_of_last_push_action_before", f)
defer.returnValue(result[0] if result else None)
@defer.inlineCallbacks
def get_latest_push_action_stream_ordering(self):
def f(txn):
@@ -230,93 +201,6 @@ class EventPushActionsStore(SQLBaseStore):
(room_id, event_id)
)
def _remove_old_push_actions_before_txn(self, txn, room_id, user_id,
topological_ordering):
"""
Purges old, stale push actions for a user and room before a given
topological_ordering
Args:
txn: The transcation
room_id: Room ID to delete from
user_id: user ID to delete for
topological_ordering: The lowest topological ordering which will
not be deleted.
"""
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(room_id, user_id, )
)
# We need to join on the events table to get the received_ts for
# event_push_actions and sqlite won't let us use a join in a delete so
# we can't just delete where received_ts < x. Furthermore we can
# only identify event_push_actions by a tuple of room_id, event_id
# we we can't use a subquery.
# Instead, we look up the stream ordering for the last event in that
# room received before the threshold time and delete event_push_actions
# in the room with a stream_odering before that.
txn.execute(
"DELETE FROM event_push_actions "
" WHERE user_id = ? AND room_id = ? AND "
" topological_ordering < ? AND stream_ordering < ?",
(user_id, room_id, topological_ordering, self.stream_ordering_month_ago)
)
@defer.inlineCallbacks
def _find_stream_orderings_for_times(self):
yield self.runInteraction(
"_find_stream_orderings_for_times",
self._find_stream_orderings_for_times_txn
)
def _find_stream_orderings_for_times_txn(self, txn):
logger.info("Searching for stream ordering 1 month ago")
self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
)
logger.info(
"Found stream ordering 1 month ago: it's %d",
self.stream_ordering_month_ago
)
def _find_first_stream_ordering_after_ts_txn(self, txn, ts):
"""
Find the stream_ordering of the first event that was received after
a given timestamp. This is relatively slow as there is no index on
received_ts but we can then use this to delete push actions before
this.
received_ts must necessarily be in the same order as stream_ordering
and stream_ordering is indexed, so we manually binary search using
stream_ordering
"""
txn.execute("SELECT MAX(stream_ordering) FROM events")
max_stream_ordering = txn.fetchone()[0]
if max_stream_ordering is None:
return 0
range_start = 0
range_end = max_stream_ordering
sql = (
"SELECT received_ts FROM events"
" WHERE stream_ordering > ?"
" ORDER BY stream_ordering"
" LIMIT 1"
)
while range_end - range_start > 1:
middle = int((range_end + range_start) / 2)
txn.execute(sql, (middle,))
middle_ts = txn.fetchone()[0]
if ts > middle_ts:
range_start = middle
else:
range_end = middle
return range_end
def _action_has_highlight(actions):
for action in actions:

Some files were not shown because too many files have changed in this diff Show More