Compare commits
9 Commits
markjh/syn
...
markjh/spl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3a676b8ee3 | ||
|
|
0d5622b088 | ||
|
|
712030aeef | ||
|
|
0b282d33af | ||
|
|
03c8df54f0 | ||
|
|
c214d3e36e | ||
|
|
1c1b2de975 | ||
|
|
f41b1a8723 | ||
|
|
1209d3174e |
@@ -11,7 +11,6 @@ recursive-include synapse/storage/schema *.sql
|
||||
recursive-include synapse/storage/schema *.py
|
||||
|
||||
recursive-include docs *
|
||||
recursive-include res *
|
||||
recursive-include scripts *
|
||||
recursive-include scripts-dev *
|
||||
recursive-include tests *.py
|
||||
|
||||
@@ -32,4 +32,5 @@ The format of the AS configuration file is as follows:
|
||||
|
||||
See the spec_ for further details on how application services work.
|
||||
|
||||
.. _spec: https://matrix.org/docs/spec/application_service/unstable.html
|
||||
.. _spec: https://github.com/matrix-org/matrix-doc/blob/master/specification/25_application_service_api.rst#application-service-api
|
||||
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
What do I do about "Unexpected logging context" debug log-lines everywhere?
|
||||
|
||||
<Mjark> The logging context lives in thread local storage
|
||||
<Mjark> Sometimes it gets out of sync with what it should actually be, usually because something scheduled something to run on the reactor without preserving the logging context.
|
||||
<Matthew> what is the impact of it getting out of sync? and how and when should we preserve log context?
|
||||
<Mjark> The impact is that some of the CPU and database metrics will be under-reported, and some log lines will be mis-attributed.
|
||||
<Mjark> It should happen auto-magically in all the APIs that do IO or otherwise defer to the reactor.
|
||||
<Erik> Mjark: the other place is if we branch, e.g. using defer.gatherResults
|
||||
|
||||
Unanswered: how and when should we preserve log context?
|
||||
@@ -1,7 +0,0 @@
|
||||
.header {
|
||||
border-bottom: 4px solid #e4f7ed ! important;
|
||||
}
|
||||
|
||||
.notif_link a, .footer a {
|
||||
color: #76CFA6 ! important;
|
||||
}
|
||||
@@ -1,156 +0,0 @@
|
||||
body {
|
||||
margin: 0px;
|
||||
}
|
||||
|
||||
pre, code {
|
||||
word-break: break-word;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
#page {
|
||||
font-family: 'Open Sans', Helvetica, Arial, Sans-Serif;
|
||||
font-color: #454545;
|
||||
font-size: 12pt;
|
||||
width: 100%;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
#inner {
|
||||
width: 640px;
|
||||
}
|
||||
|
||||
.header {
|
||||
width: 100%;
|
||||
height: 87px;
|
||||
color: #454545;
|
||||
border-bottom: 4px solid #e5e5e5;
|
||||
}
|
||||
|
||||
.logo {
|
||||
text-align: right;
|
||||
margin-left: 20px;
|
||||
}
|
||||
|
||||
.salutation {
|
||||
padding-top: 10px;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.summarytext {
|
||||
}
|
||||
|
||||
.room {
|
||||
width: 100%;
|
||||
color: #454545;
|
||||
border-bottom: 1px solid #e5e5e5;
|
||||
}
|
||||
|
||||
.room_header td {
|
||||
padding-top: 38px;
|
||||
padding-bottom: 10px;
|
||||
border-bottom: 1px solid #e5e5e5;
|
||||
}
|
||||
|
||||
.room_name {
|
||||
vertical-align: middle;
|
||||
font-size: 18px;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.room_header h2 {
|
||||
margin-top: 0px;
|
||||
margin-left: 75px;
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
.room_avatar {
|
||||
width: 56px;
|
||||
line-height: 0px;
|
||||
text-align: center;
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
.room_avatar img {
|
||||
width: 48px;
|
||||
height: 48px;
|
||||
object-fit: cover;
|
||||
border-radius: 24px;
|
||||
}
|
||||
|
||||
.notif {
|
||||
border-bottom: 1px solid #e5e5e5;
|
||||
margin-top: 16px;
|
||||
padding-bottom: 16px;
|
||||
}
|
||||
|
||||
.historical_message .sender_avatar {
|
||||
opacity: 0.3;
|
||||
}
|
||||
|
||||
/* spell out opacity and historical_message class names for Outlook aka Word */
|
||||
.historical_message .sender_name {
|
||||
color: #e3e3e3;
|
||||
}
|
||||
|
||||
.historical_message .message_time {
|
||||
color: #e3e3e3;
|
||||
}
|
||||
|
||||
.historical_message .message_body {
|
||||
color: #c7c7c7;
|
||||
}
|
||||
|
||||
.historical_message td,
|
||||
.message td {
|
||||
padding-top: 10px;
|
||||
}
|
||||
|
||||
.sender_avatar {
|
||||
width: 56px;
|
||||
text-align: center;
|
||||
vertical-align: top;
|
||||
}
|
||||
|
||||
.sender_avatar img {
|
||||
margin-top: -2px;
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
border-radius: 16px;
|
||||
}
|
||||
|
||||
.sender_name {
|
||||
display: inline;
|
||||
font-size: 13px;
|
||||
color: #a2a2a2;
|
||||
}
|
||||
|
||||
.message_time {
|
||||
text-align: right;
|
||||
width: 100px;
|
||||
font-size: 11px;
|
||||
color: #a2a2a2;
|
||||
}
|
||||
|
||||
.message_body {
|
||||
}
|
||||
|
||||
.notif_link td {
|
||||
padding-top: 10px;
|
||||
padding-bottom: 10px;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.notif_link a, .footer a {
|
||||
color: #454545;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
.debug {
|
||||
font-size: 10px;
|
||||
color: #888;
|
||||
}
|
||||
|
||||
.footer {
|
||||
margin-top: 20px;
|
||||
text-align: center;
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
{% for message in notif.messages %}
|
||||
<tr class="{{ "historical_message" if message.is_historical else "message" }}">
|
||||
<td class="sender_avatar">
|
||||
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
|
||||
{% if message.sender_avatar_url %}
|
||||
<img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" />
|
||||
{% else %}
|
||||
{% if message.sender_hash % 3 == 0 %}
|
||||
<img class="sender_avatar" src="https://vector.im/beta/img/76cfa6.png" />
|
||||
{% elif message.sender_hash % 3 == 1 %}
|
||||
<img class="sender_avatar" src="https://vector.im/beta/img/50e2c2.png" />
|
||||
{% else %}
|
||||
<img class="sender_avatar" src="https://vector.im/beta/img/f4c371.png" />
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
</td>
|
||||
<td class="message_contents">
|
||||
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
|
||||
<div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
|
||||
{% endif %}
|
||||
<div class="message_body">
|
||||
{% if message.msgtype == "m.text" %}
|
||||
{{ message.body_text_html }}
|
||||
{% elif message.msgtype == "m.emote" %}
|
||||
{{ message.body_text_html }}
|
||||
{% elif message.msgtype == "m.notice" %}
|
||||
{{ message.body_text_html }}
|
||||
{% elif message.msgtype == "m.image" %}
|
||||
<img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
|
||||
{% elif message.msgtype == "m.file" %}
|
||||
<span class="filename">{{ message.body_text_plain }}</span>
|
||||
{% endif %}
|
||||
</div>
|
||||
</td>
|
||||
<td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
<tr class="notif_link">
|
||||
<td></td>
|
||||
<td>
|
||||
<a href="{{ notif.link }}">View {{ room.title }}</a>
|
||||
</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
@@ -1,16 +0,0 @@
|
||||
{% for message in notif.messages %}
|
||||
{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
|
||||
{% if message.msgtype == "m.text" %}
|
||||
{{ message.body_text_plain }}
|
||||
{% elif message.msgtype == "m.emote" %}
|
||||
{{ message.body_text_plain }}
|
||||
{% elif message.msgtype == "m.notice" %}
|
||||
{{ message.body_text_plain }}
|
||||
{% elif message.msgtype == "m.image" %}
|
||||
{{ message.body_text_plain }}
|
||||
{% elif message.msgtype == "m.file" %}
|
||||
{{ message.body_text_plain }}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
View {{ room.title }} at {{ notif.link }}
|
||||
@@ -1,53 +0,0 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<style type="text/css">
|
||||
{% include 'mail.css' without context %}
|
||||
{% include "mail-%s.css" % app_name ignore missing without context %}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<table id="page">
|
||||
<tr>
|
||||
<td> </td>
|
||||
<td id="inner">
|
||||
<table class="header">
|
||||
<tr>
|
||||
<td>
|
||||
<div class="salutation">Hi {{ user_display_name }},</div>
|
||||
<div class="summarytext">{{ summary_text }}</div>
|
||||
</td>
|
||||
<td class="logo">
|
||||
{% if app_name == "Vector" %}
|
||||
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
|
||||
{% else %}
|
||||
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
|
||||
{% endif %}
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
{% for room in rooms %}
|
||||
{% include 'room.html' with context %}
|
||||
{% endfor %}
|
||||
<div class="footer">
|
||||
<a href="{{ unsubscribe_link }}">Unsubscribe</a>
|
||||
<br/>
|
||||
<br/>
|
||||
<div class="debug">
|
||||
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
|
||||
an event was received at {{ reason.received_at|format_ts("%c") }}
|
||||
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} (delay_before_mail_ms) mins ago,
|
||||
{% if reason.last_sent_ts %}
|
||||
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
|
||||
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
|
||||
{% else %}
|
||||
and we don't have a last time we sent a mail for this room.
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
</td>
|
||||
<td> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,10 +0,0 @@
|
||||
Hi {{ user_display_name }},
|
||||
|
||||
{{ summary_text }}
|
||||
|
||||
{% for room in rooms %}
|
||||
{% include 'room.txt' with context %}
|
||||
{% endfor %}
|
||||
|
||||
You can disable these notifications at {{ unsubscribe_link }}
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
<table class="room">
|
||||
<tr class="room_header">
|
||||
<td class="room_avatar">
|
||||
{% if room.avatar_url %}
|
||||
<img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" />
|
||||
{% else %}
|
||||
{% if room.hash % 3 == 0 %}
|
||||
<img alt="" src="https://vector.im/beta/img/76cfa6.png" />
|
||||
{% elif room.hash % 3 == 1 %}
|
||||
<img alt="" src="https://vector.im/beta/img/50e2c2.png" />
|
||||
{% else %}
|
||||
<img alt="" src="https://vector.im/beta/img/f4c371.png" />
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
</td>
|
||||
<td class="room_name" colspan="2">
|
||||
{{ room.title }}
|
||||
</td>
|
||||
</tr>
|
||||
{% if room.invite %}
|
||||
<tr>
|
||||
<td></td>
|
||||
<td>
|
||||
<a href="{{ room.link }}">Join the conversation.</a>
|
||||
</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
{% else %}
|
||||
{% for notif in room.notifs %}
|
||||
{% include 'notif.html' with context %}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
</table>
|
||||
@@ -1,9 +0,0 @@
|
||||
{{ room.title }}
|
||||
|
||||
{% if room.invite %}
|
||||
You've been invited, join at {{ room.link }}
|
||||
{% else %}
|
||||
{% for notif in room.notifs %}
|
||||
{% include 'notif.txt' with context %}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
@@ -214,10 +214,6 @@ class Porter(object):
|
||||
|
||||
self.progress.add_table(table, postgres_size, table_size)
|
||||
|
||||
if table == "event_search":
|
||||
yield self.handle_search_table(postgres_size, table_size, next_chunk)
|
||||
return
|
||||
|
||||
select = (
|
||||
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
|
||||
% (table,)
|
||||
@@ -236,19 +232,51 @@ class Porter(object):
|
||||
if rows:
|
||||
next_chunk = rows[-1][0] + 1
|
||||
|
||||
self._convert_rows(table, headers, rows)
|
||||
if table == "event_search":
|
||||
# We have to treat event_search differently since it has a
|
||||
# different structure in the two different databases.
|
||||
def insert(txn):
|
||||
sql = (
|
||||
"INSERT INTO event_search (event_id, room_id, key, sender, vector)"
|
||||
" VALUES (?,?,?,?,to_tsvector('english', ?))"
|
||||
)
|
||||
|
||||
def insert(txn):
|
||||
self.postgres_store.insert_many_txn(
|
||||
txn, table, headers[1:], rows
|
||||
)
|
||||
rows_dict = [
|
||||
dict(zip(headers, row))
|
||||
for row in rows
|
||||
]
|
||||
|
||||
self.postgres_store._simple_update_one_txn(
|
||||
txn,
|
||||
table="port_from_sqlite3",
|
||||
keyvalues={"table_name": table},
|
||||
updatevalues={"rowid": next_chunk},
|
||||
)
|
||||
txn.executemany(sql, [
|
||||
(
|
||||
row["event_id"],
|
||||
row["room_id"],
|
||||
row["key"],
|
||||
row["sender"],
|
||||
row["value"],
|
||||
)
|
||||
for row in rows_dict
|
||||
])
|
||||
|
||||
self.postgres_store._simple_update_one_txn(
|
||||
txn,
|
||||
table="port_from_sqlite3",
|
||||
keyvalues={"table_name": table},
|
||||
updatevalues={"rowid": next_chunk},
|
||||
)
|
||||
else:
|
||||
self._convert_rows(table, headers, rows)
|
||||
|
||||
def insert(txn):
|
||||
self.postgres_store.insert_many_txn(
|
||||
txn, table, headers[1:], rows
|
||||
)
|
||||
|
||||
self.postgres_store._simple_update_one_txn(
|
||||
txn,
|
||||
table="port_from_sqlite3",
|
||||
keyvalues={"table_name": table},
|
||||
updatevalues={"rowid": next_chunk},
|
||||
)
|
||||
|
||||
yield self.postgres_store.execute(insert)
|
||||
|
||||
@@ -258,73 +286,6 @@ class Porter(object):
|
||||
else:
|
||||
return
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_search_table(self, postgres_size, table_size, next_chunk):
|
||||
select = (
|
||||
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
|
||||
" FROM event_search as es"
|
||||
" INNER JOIN events AS e USING (event_id, room_id)"
|
||||
" WHERE es.rowid >= ?"
|
||||
" ORDER BY es.rowid LIMIT ?"
|
||||
)
|
||||
|
||||
while True:
|
||||
def r(txn):
|
||||
txn.execute(select, (next_chunk, self.batch_size,))
|
||||
rows = txn.fetchall()
|
||||
headers = [column[0] for column in txn.description]
|
||||
|
||||
return headers, rows
|
||||
|
||||
headers, rows = yield self.sqlite_store.runInteraction("select", r)
|
||||
|
||||
if rows:
|
||||
next_chunk = rows[-1][0] + 1
|
||||
|
||||
# We have to treat event_search differently since it has a
|
||||
# different structure in the two different databases.
|
||||
def insert(txn):
|
||||
sql = (
|
||||
"INSERT INTO event_search (event_id, room_id, key,"
|
||||
" sender, vector, origin_server_ts, stream_ordering)"
|
||||
" VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
|
||||
)
|
||||
|
||||
rows_dict = [
|
||||
dict(zip(headers, row))
|
||||
for row in rows
|
||||
]
|
||||
|
||||
txn.executemany(sql, [
|
||||
(
|
||||
row["event_id"],
|
||||
row["room_id"],
|
||||
row["key"],
|
||||
row["sender"],
|
||||
row["value"],
|
||||
row["origin_server_ts"],
|
||||
row["stream_ordering"],
|
||||
)
|
||||
for row in rows_dict
|
||||
])
|
||||
|
||||
self.postgres_store._simple_update_one_txn(
|
||||
txn,
|
||||
table="port_from_sqlite3",
|
||||
keyvalues={"table_name": "event_search"},
|
||||
updatevalues={"rowid": next_chunk},
|
||||
)
|
||||
|
||||
yield self.postgres_store.execute(insert)
|
||||
|
||||
postgres_size += len(rows)
|
||||
|
||||
self.progress.update("event_search", postgres_size)
|
||||
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
def setup_db(self, db_config, database_engine):
|
||||
db_conn = database_engine.module.connect(
|
||||
**{
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This module contains classes for authenticating the user."""
|
||||
from canonicaljson import encode_canonical_json
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.sign import verify_signed_json, SignatureVerifyException
|
||||
@@ -21,7 +22,7 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
||||
from synapse.types import Requester, UserID, get_domain_from_id
|
||||
from synapse.types import Requester, RoomID, UserID, EventID
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from synapse.util.metrics import Measure
|
||||
@@ -41,20 +42,13 @@ AuthEventTypes = (
|
||||
|
||||
|
||||
class Auth(object):
|
||||
"""
|
||||
FIXME: This class contains a mix of functions for authenticating users
|
||||
of our client-server API and authenticating events added to room graphs.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.state = hs.get_state_handler()
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
|
||||
# Docs for these currently lives at
|
||||
# https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
|
||||
# In addition, we have type == delete_pusher which grants access only to
|
||||
# delete pushers.
|
||||
self._KNOWN_CAVEAT_PREFIXES = set([
|
||||
"gen = ",
|
||||
"guest = ",
|
||||
@@ -97,8 +91,8 @@ class Auth(object):
|
||||
"Room %r does not exist" % (event.room_id,)
|
||||
)
|
||||
|
||||
creating_domain = get_domain_from_id(event.room_id)
|
||||
originating_domain = get_domain_from_id(event.sender)
|
||||
creating_domain = RoomID.from_string(event.room_id).domain
|
||||
originating_domain = UserID.from_string(event.sender).domain
|
||||
if creating_domain != originating_domain:
|
||||
if not self.can_federate(event, auth_events):
|
||||
raise AuthError(
|
||||
@@ -126,24 +120,6 @@ class Auth(object):
|
||||
return allowed
|
||||
|
||||
self.check_event_sender_in_room(event, auth_events)
|
||||
|
||||
# Special case to allow m.room.third_party_invite events wherever
|
||||
# a user is allowed to issue invites. Fixes
|
||||
# https://github.com/vector-im/vector-web/issues/1208 hopefully
|
||||
if event.type == EventTypes.ThirdPartyInvite:
|
||||
user_level = self._get_user_power_level(event.user_id, auth_events)
|
||||
invite_level = self._get_named_level(auth_events, "invite", 0)
|
||||
|
||||
if user_level < invite_level:
|
||||
raise AuthError(
|
||||
403, (
|
||||
"You cannot issue a third party invite for %s." %
|
||||
(event.content.display_name,)
|
||||
)
|
||||
)
|
||||
else:
|
||||
return True
|
||||
|
||||
self._can_send_event(event, auth_events)
|
||||
|
||||
if event.type == EventTypes.PowerLevels:
|
||||
@@ -243,7 +219,7 @@ class Auth(object):
|
||||
for event in curr_state.values():
|
||||
if event.type == EventTypes.Member:
|
||||
try:
|
||||
if get_domain_from_id(event.state_key) != host:
|
||||
if UserID.from_string(event.state_key).domain != host:
|
||||
continue
|
||||
except:
|
||||
logger.warn("state_key not user_id: %s", event.state_key)
|
||||
@@ -290,8 +266,8 @@ class Auth(object):
|
||||
|
||||
target_user_id = event.state_key
|
||||
|
||||
creating_domain = get_domain_from_id(event.room_id)
|
||||
target_domain = get_domain_from_id(target_user_id)
|
||||
creating_domain = RoomID.from_string(event.room_id).domain
|
||||
target_domain = UserID.from_string(target_user_id).domain
|
||||
if creating_domain != target_domain:
|
||||
if not self.can_federate(event, auth_events):
|
||||
raise AuthError(
|
||||
@@ -531,7 +507,7 @@ class Auth(object):
|
||||
return default
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_by_req(self, request, allow_guest=False, rights="access"):
|
||||
def get_user_by_req(self, request, allow_guest=False):
|
||||
""" Get a registered user's ID.
|
||||
|
||||
Args:
|
||||
@@ -553,7 +529,7 @@ class Auth(object):
|
||||
)
|
||||
|
||||
access_token = request.args["access_token"][0]
|
||||
user_info = yield self.get_user_by_access_token(access_token, rights)
|
||||
user_info = yield self.get_user_by_access_token(access_token)
|
||||
user = user_info["user"]
|
||||
token_id = user_info["token_id"]
|
||||
is_guest = user_info["is_guest"]
|
||||
@@ -614,7 +590,7 @@ class Auth(object):
|
||||
defer.returnValue(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_by_access_token(self, token, rights="access"):
|
||||
def get_user_by_access_token(self, token):
|
||||
""" Get a registered user's ID.
|
||||
|
||||
Args:
|
||||
@@ -625,7 +601,7 @@ class Auth(object):
|
||||
AuthError if no user by that token exists or the token is invalid.
|
||||
"""
|
||||
try:
|
||||
ret = yield self.get_user_from_macaroon(token, rights)
|
||||
ret = yield self.get_user_from_macaroon(token)
|
||||
except AuthError:
|
||||
# TODO(daniel): Remove this fallback when all existing access tokens
|
||||
# have been re-issued as macaroons.
|
||||
@@ -633,11 +609,10 @@ class Auth(object):
|
||||
defer.returnValue(ret)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_from_macaroon(self, macaroon_str, rights="access"):
|
||||
def get_user_from_macaroon(self, macaroon_str):
|
||||
try:
|
||||
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
||||
|
||||
self.validate_macaroon(macaroon, rights, self.hs.config.expire_access_token)
|
||||
self.validate_macaroon(macaroon, "access", False)
|
||||
|
||||
user_prefix = "user_id = "
|
||||
user = None
|
||||
@@ -660,13 +635,6 @@ class Auth(object):
|
||||
"is_guest": True,
|
||||
"token_id": None,
|
||||
}
|
||||
elif rights == "delete_pusher":
|
||||
# We don't store these tokens in the database
|
||||
ret = {
|
||||
"user": user,
|
||||
"is_guest": False,
|
||||
"token_id": None,
|
||||
}
|
||||
else:
|
||||
# This codepath exists so that we can actually return a
|
||||
# token ID, because we use token IDs in place of device
|
||||
@@ -698,8 +666,7 @@ class Auth(object):
|
||||
|
||||
Args:
|
||||
macaroon(pymacaroons.Macaroon): The macaroon to validate
|
||||
type_string(str): The kind of token required (e.g. "access", "refresh",
|
||||
"delete_pusher")
|
||||
type_string(str): The kind of token this is (e.g. "access", "refresh")
|
||||
verify_expiry(bool): Whether to verify whether the macaroon has expired.
|
||||
This should really always be True, but no clients currently implement
|
||||
token refresh, so we can't enforce expiry yet.
|
||||
@@ -922,8 +889,8 @@ class Auth(object):
|
||||
if user_level >= redact_level:
|
||||
return False
|
||||
|
||||
redacter_domain = get_domain_from_id(event.event_id)
|
||||
redactee_domain = get_domain_from_id(event.redacts)
|
||||
redacter_domain = EventID.from_string(event.event_id).domain
|
||||
redactee_domain = EventID.from_string(event.redacts).domain
|
||||
if redacter_domain == redactee_domain:
|
||||
return True
|
||||
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.types import UserID, RoomID
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import ujson as json
|
||||
|
||||
|
||||
@@ -26,10 +24,10 @@ class Filtering(object):
|
||||
super(Filtering, self).__init__()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_filter(self, user_localpart, filter_id):
|
||||
result = yield self.store.get_user_filter(user_localpart, filter_id)
|
||||
defer.returnValue(FilterCollection(result))
|
||||
result = self.store.get_user_filter(user_localpart, filter_id)
|
||||
result.addCallback(FilterCollection)
|
||||
return result
|
||||
|
||||
def add_user_filter(self, user_localpart, user_filter):
|
||||
self.check_valid_filter(user_filter)
|
||||
|
||||
@@ -16,9 +16,12 @@
|
||||
|
||||
import synapse
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from synapse.config._base import ConfigError
|
||||
|
||||
from synapse.python_dependencies import (
|
||||
@@ -32,11 +35,18 @@ from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_d
|
||||
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
from twisted.conch.manhole import ColoredManhole
|
||||
from twisted.conch.insults import insults
|
||||
from twisted.conch import manhole_ssh
|
||||
from twisted.cred import checkers, portal
|
||||
|
||||
|
||||
from twisted.internet import reactor, task, defer
|
||||
from twisted.application import service
|
||||
from twisted.web.resource import Resource, EncodingResourceWrapper
|
||||
from twisted.web.static import File
|
||||
from twisted.web.server import GzipEncoderFactory
|
||||
from twisted.web.server import Site, GzipEncoderFactory, Request
|
||||
from synapse.http.server import RootRedirect
|
||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
||||
@@ -56,10 +66,6 @@ from synapse.federation.transport.server import TransportLayerServer
|
||||
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.manhole import manhole
|
||||
|
||||
from synapse.http.site import SynapseSite
|
||||
|
||||
from synapse import events
|
||||
|
||||
@@ -68,6 +74,9 @@ from daemonize import Daemonize
|
||||
logger = logging.getLogger("synapse.app.homeserver")
|
||||
|
||||
|
||||
ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
|
||||
|
||||
|
||||
def gz_wrap(r):
|
||||
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
|
||||
|
||||
@@ -165,12 +174,7 @@ class SynapseHomeServer(HomeServer):
|
||||
if name == "replication":
|
||||
resources[REPLICATION_PREFIX] = ReplicationResource(self)
|
||||
|
||||
if WEB_CLIENT_PREFIX in resources:
|
||||
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
|
||||
else:
|
||||
root_resource = Resource()
|
||||
|
||||
root_resource = create_resource_tree(resources, root_resource)
|
||||
root_resource = create_resource_tree(resources)
|
||||
if tls:
|
||||
reactor.listenSSL(
|
||||
port,
|
||||
@@ -203,13 +207,24 @@ class SynapseHomeServer(HomeServer):
|
||||
if listener["type"] == "http":
|
||||
self._listener_http(config, listener)
|
||||
elif listener["type"] == "manhole":
|
||||
checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
|
||||
matrix="rabbithole"
|
||||
)
|
||||
|
||||
rlm = manhole_ssh.TerminalRealm()
|
||||
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
|
||||
ColoredManhole,
|
||||
{
|
||||
"__name__": "__console__",
|
||||
"hs": self,
|
||||
}
|
||||
)
|
||||
|
||||
f = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
|
||||
|
||||
reactor.listenTCP(
|
||||
listener["port"],
|
||||
manhole(
|
||||
username="matrix",
|
||||
password="rabbithole",
|
||||
globals={"hs": self},
|
||||
),
|
||||
f,
|
||||
interface=listener.get("bind_address", '127.0.0.1')
|
||||
)
|
||||
else:
|
||||
@@ -356,6 +371,210 @@ class SynapseService(service.Service):
|
||||
return self._port.stopListening()
|
||||
|
||||
|
||||
class SynapseRequest(Request):
|
||||
def __init__(self, site, *args, **kw):
|
||||
Request.__init__(self, *args, **kw)
|
||||
self.site = site
|
||||
self.authenticated_entity = None
|
||||
self.start_time = 0
|
||||
|
||||
def __repr__(self):
|
||||
# We overwrite this so that we don't log ``access_token``
|
||||
return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
|
||||
self.__class__.__name__,
|
||||
id(self),
|
||||
self.method,
|
||||
self.get_redacted_uri(),
|
||||
self.clientproto,
|
||||
self.site.site_tag,
|
||||
)
|
||||
|
||||
def get_redacted_uri(self):
|
||||
return ACCESS_TOKEN_RE.sub(
|
||||
r'\1<redacted>\3',
|
||||
self.uri
|
||||
)
|
||||
|
||||
def get_user_agent(self):
|
||||
return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
|
||||
|
||||
def started_processing(self):
|
||||
self.site.access_logger.info(
|
||||
"%s - %s - Received request: %s %s",
|
||||
self.getClientIP(),
|
||||
self.site.site_tag,
|
||||
self.method,
|
||||
self.get_redacted_uri()
|
||||
)
|
||||
self.start_time = int(time.time() * 1000)
|
||||
|
||||
def finished_processing(self):
|
||||
|
||||
try:
|
||||
context = LoggingContext.current_context()
|
||||
ru_utime, ru_stime = context.get_resource_usage()
|
||||
db_txn_count = context.db_txn_count
|
||||
db_txn_duration = context.db_txn_duration
|
||||
except:
|
||||
ru_utime, ru_stime = (0, 0)
|
||||
db_txn_count, db_txn_duration = (0, 0)
|
||||
|
||||
self.site.access_logger.info(
|
||||
"%s - %s - {%s}"
|
||||
" Processed request: %dms (%dms, %dms) (%dms/%d)"
|
||||
" %sB %s \"%s %s %s\" \"%s\"",
|
||||
self.getClientIP(),
|
||||
self.site.site_tag,
|
||||
self.authenticated_entity,
|
||||
int(time.time() * 1000) - self.start_time,
|
||||
int(ru_utime * 1000),
|
||||
int(ru_stime * 1000),
|
||||
int(db_txn_duration * 1000),
|
||||
int(db_txn_count),
|
||||
self.sentLength,
|
||||
self.code,
|
||||
self.method,
|
||||
self.get_redacted_uri(),
|
||||
self.clientproto,
|
||||
self.get_user_agent(),
|
||||
)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def processing(self):
|
||||
self.started_processing()
|
||||
yield
|
||||
self.finished_processing()
|
||||
|
||||
|
||||
class XForwardedForRequest(SynapseRequest):
|
||||
def __init__(self, *args, **kw):
|
||||
SynapseRequest.__init__(self, *args, **kw)
|
||||
|
||||
"""
|
||||
Add a layer on top of another request that only uses the value of an
|
||||
X-Forwarded-For header as the result of C{getClientIP}.
|
||||
"""
|
||||
def getClientIP(self):
|
||||
"""
|
||||
@return: The client address (the first address) in the value of the
|
||||
I{X-Forwarded-For header}. If the header is not present, return
|
||||
C{b"-"}.
|
||||
"""
|
||||
return self.requestHeaders.getRawHeaders(
|
||||
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
|
||||
|
||||
|
||||
class SynapseRequestFactory(object):
|
||||
def __init__(self, site, x_forwarded_for):
|
||||
self.site = site
|
||||
self.x_forwarded_for = x_forwarded_for
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self.x_forwarded_for:
|
||||
return XForwardedForRequest(self.site, *args, **kwargs)
|
||||
else:
|
||||
return SynapseRequest(self.site, *args, **kwargs)
|
||||
|
||||
|
||||
class SynapseSite(Site):
|
||||
"""
|
||||
Subclass of a twisted http Site that does access logging with python's
|
||||
standard logging
|
||||
"""
|
||||
def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
|
||||
Site.__init__(self, resource, *args, **kwargs)
|
||||
|
||||
self.site_tag = site_tag
|
||||
|
||||
proxied = config.get("x_forwarded", False)
|
||||
self.requestFactory = SynapseRequestFactory(self, proxied)
|
||||
self.access_logger = logging.getLogger(logger_name)
|
||||
|
||||
def log(self, request):
|
||||
pass
|
||||
|
||||
|
||||
def create_resource_tree(desired_tree, redirect_root_to_web_client=True):
|
||||
"""Create the resource tree for this Home Server.
|
||||
|
||||
This in unduly complicated because Twisted does not support putting
|
||||
child resources more than 1 level deep at a time.
|
||||
|
||||
Args:
|
||||
web_client (bool): True to enable the web client.
|
||||
redirect_root_to_web_client (bool): True to redirect '/' to the
|
||||
location of the web client. This does nothing if web_client is not
|
||||
True.
|
||||
"""
|
||||
if redirect_root_to_web_client and WEB_CLIENT_PREFIX in desired_tree:
|
||||
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
|
||||
else:
|
||||
root_resource = Resource()
|
||||
|
||||
# ideally we'd just use getChild and putChild but getChild doesn't work
|
||||
# unless you give it a Request object IN ADDITION to the name :/ So
|
||||
# instead, we'll store a copy of this mapping so we can actually add
|
||||
# extra resources to existing nodes. See self._resource_id for the key.
|
||||
resource_mappings = {}
|
||||
for full_path, res in desired_tree.items():
|
||||
logger.info("Attaching %s to path %s", res, full_path)
|
||||
last_resource = root_resource
|
||||
for path_seg in full_path.split('/')[1:-1]:
|
||||
if path_seg not in last_resource.listNames():
|
||||
# resource doesn't exist, so make a "dummy resource"
|
||||
child_resource = Resource()
|
||||
last_resource.putChild(path_seg, child_resource)
|
||||
res_id = _resource_id(last_resource, path_seg)
|
||||
resource_mappings[res_id] = child_resource
|
||||
last_resource = child_resource
|
||||
else:
|
||||
# we have an existing Resource, use that instead.
|
||||
res_id = _resource_id(last_resource, path_seg)
|
||||
last_resource = resource_mappings[res_id]
|
||||
|
||||
# ===========================
|
||||
# now attach the actual desired resource
|
||||
last_path_seg = full_path.split('/')[-1]
|
||||
|
||||
# if there is already a resource here, thieve its children and
|
||||
# replace it
|
||||
res_id = _resource_id(last_resource, last_path_seg)
|
||||
if res_id in resource_mappings:
|
||||
# there is a dummy resource at this path already, which needs
|
||||
# to be replaced with the desired resource.
|
||||
existing_dummy_resource = resource_mappings[res_id]
|
||||
for child_name in existing_dummy_resource.listNames():
|
||||
child_res_id = _resource_id(
|
||||
existing_dummy_resource, child_name
|
||||
)
|
||||
child_resource = resource_mappings[child_res_id]
|
||||
# steal the children
|
||||
res.putChild(child_name, child_resource)
|
||||
|
||||
# finally, insert the desired resource in the right place
|
||||
last_resource.putChild(last_path_seg, res)
|
||||
res_id = _resource_id(last_resource, last_path_seg)
|
||||
resource_mappings[res_id] = res
|
||||
|
||||
return root_resource
|
||||
|
||||
|
||||
def _resource_id(resource, path_seg):
|
||||
"""Construct an arbitrary resource ID so you can retrieve the mapping
|
||||
later.
|
||||
|
||||
If you want to represent resource A putChild resource B with path C,
|
||||
the mapping should looks like _resource_id(A,C) = B.
|
||||
|
||||
Args:
|
||||
resource (Resource): The *parent* Resourceb
|
||||
path_seg (str): The name of the child Resource to be attached.
|
||||
Returns:
|
||||
str: A unique string which can be a key to the child Resource.
|
||||
"""
|
||||
return "%s-%s" % (resource, path_seg)
|
||||
|
||||
|
||||
def run(hs):
|
||||
PROFILE_SYNAPSE = False
|
||||
if PROFILE_SYNAPSE:
|
||||
|
||||
@@ -17,31 +17,19 @@
|
||||
import synapse
|
||||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util.versionstring import get_version_string
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.database import DatabaseConfig
|
||||
from synapse.config.logger import LoggingConfig
|
||||
from synapse.config.emailconfig import EmailConfig
|
||||
from synapse.config.key import KeyConfig
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.storage.roommember import RoomMemberStore
|
||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||
from synapse.replication.slave.storage.pushers import SlavedPusherStore
|
||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.storage import DataStore
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext, preserve_fn
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
from synapse.util.logcontext import (LoggingContext, preserve_fn)
|
||||
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from daemonize import Daemonize
|
||||
|
||||
import sys
|
||||
import logging
|
||||
@@ -53,114 +41,30 @@ class SlaveConfig(DatabaseConfig):
|
||||
def read_config(self, config):
|
||||
self.replication_url = config["replication_url"]
|
||||
self.server_name = config["server_name"]
|
||||
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
|
||||
"use_insecure_ssl_client_just_for_testing_do_not_use", False
|
||||
)
|
||||
self.use_insecure_ssl_client_just_for_testing_do_not_use = True
|
||||
self.user_agent_suffix = None
|
||||
self.start_pushers = True
|
||||
self.listeners = config["listeners"]
|
||||
self.soft_file_limit = config.get("soft_file_limit")
|
||||
self.daemonize = config.get("daemonize")
|
||||
self.pid_file = self.abspath(config.get("pid_file"))
|
||||
self.public_baseurl = config["public_baseurl"]
|
||||
|
||||
# some things used by the auth handler but not actually used in the
|
||||
# pusher codebase
|
||||
self.bcrypt_rounds = None
|
||||
self.ldap_enabled = None
|
||||
self.ldap_server = None
|
||||
self.ldap_port = None
|
||||
self.ldap_tls = None
|
||||
self.ldap_search_base = None
|
||||
self.ldap_search_property = None
|
||||
self.ldap_email_property = None
|
||||
self.ldap_full_name_property = None
|
||||
|
||||
# We would otherwise try to use the registration shared secret as the
|
||||
# macaroon shared secret if there was no macaroon_shared_secret, but
|
||||
# that means pulling in RegistrationConfig too. We don't need to be
|
||||
# backwards compaitible in the pusher codebase so just make people set
|
||||
# macaroon_shared_secret. We set this to None to prevent it referencing
|
||||
# an undefined key.
|
||||
self.registration_shared_secret = None
|
||||
|
||||
def default_config(self, server_name, **kwargs):
|
||||
pid_file = self.abspath("pusher.pid")
|
||||
def default_config(self, **kwargs):
|
||||
return """\
|
||||
# Slave configuration
|
||||
|
||||
# The replication listener on the synapse to talk to.
|
||||
## Slave ##
|
||||
#replication_url: https://localhost:{replication_port}/_synapse/replication
|
||||
|
||||
server_name: "%(server_name)s"
|
||||
|
||||
listeners: []
|
||||
# Enable a ssh manhole listener on the pusher.
|
||||
# - type: manhole
|
||||
# port: {manhole_port}
|
||||
# bind_address: 127.0.0.1
|
||||
# Enable a metric listener on the pusher.
|
||||
# - type: http
|
||||
# port: {metrics_port}
|
||||
# bind_address: 127.0.0.1
|
||||
# resources:
|
||||
# - names: ["metrics"]
|
||||
# compress: False
|
||||
|
||||
report_stats: False
|
||||
|
||||
daemonize: False
|
||||
|
||||
pid_file: %(pid_file)s
|
||||
|
||||
""" % locals()
|
||||
"""
|
||||
|
||||
|
||||
class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig, KeyConfig):
|
||||
class PusherSlaveConfig(SlaveConfig, LoggingConfig):
|
||||
pass
|
||||
|
||||
|
||||
class PusherSlaveStore(
|
||||
SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore,
|
||||
SlavedAccountDataStore
|
||||
SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore
|
||||
):
|
||||
update_pusher_last_stream_ordering_and_success = (
|
||||
DataStore.update_pusher_last_stream_ordering_and_success.__func__
|
||||
)
|
||||
|
||||
update_pusher_failing_since = (
|
||||
DataStore.update_pusher_failing_since.__func__
|
||||
)
|
||||
|
||||
update_pusher_last_stream_ordering = (
|
||||
DataStore.update_pusher_last_stream_ordering.__func__
|
||||
)
|
||||
|
||||
get_throttle_params_by_room = (
|
||||
DataStore.get_throttle_params_by_room.__func__
|
||||
)
|
||||
|
||||
set_throttle_params = (
|
||||
DataStore.set_throttle_params.__func__
|
||||
)
|
||||
|
||||
get_time_of_last_push_action_before = (
|
||||
DataStore.get_time_of_last_push_action_before.__func__
|
||||
)
|
||||
|
||||
get_profile_displayname = (
|
||||
DataStore.get_profile_displayname.__func__
|
||||
)
|
||||
|
||||
# XXX: This is a bit broken because we don't persist forgotten rooms
|
||||
# in a way that they can be streamed. This means that we don't have a
|
||||
# way to invalidate the forgotten rooms cache correctly.
|
||||
# For now we expire the cache every 10 minutes.
|
||||
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
|
||||
who_forgot_in_room = (
|
||||
RoomMemberStore.__dict__["who_forgot_in_room"]
|
||||
)
|
||||
|
||||
|
||||
class PusherServer(HomeServer):
|
||||
|
||||
@@ -194,53 +98,12 @@ class PusherServer(HomeServer):
|
||||
}]
|
||||
})
|
||||
|
||||
def _listen_http(self, listener_config):
|
||||
port = listener_config["port"]
|
||||
bind_address = listener_config.get("bind_address", "")
|
||||
site_tag = listener_config.get("tag", port)
|
||||
resources = {}
|
||||
for res in listener_config["resources"]:
|
||||
for name in res["names"]:
|
||||
if name == "metrics":
|
||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||
|
||||
root_resource = create_resource_tree(resources, Resource())
|
||||
reactor.listenTCP(
|
||||
port,
|
||||
SynapseSite(
|
||||
"synapse.access.http.%s" % (site_tag,),
|
||||
site_tag,
|
||||
listener_config,
|
||||
root_resource,
|
||||
),
|
||||
interface=bind_address
|
||||
)
|
||||
logger.info("Synapse pusher now listening on port %d", port)
|
||||
|
||||
def start_listening(self):
|
||||
for listener in self.config.listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listen_http(listener)
|
||||
elif listener["type"] == "manhole":
|
||||
reactor.listenTCP(
|
||||
listener["port"],
|
||||
manhole(
|
||||
username="matrix",
|
||||
password="rabbithole",
|
||||
globals={"hs": self},
|
||||
),
|
||||
interface=listener.get("bind_address", '127.0.0.1')
|
||||
)
|
||||
else:
|
||||
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def replicate(self):
|
||||
http_client = self.get_simple_http_client()
|
||||
store = self.get_datastore()
|
||||
replication_url = self.config.replication_url
|
||||
pusher_pool = self.get_pusherpool()
|
||||
clock = self.get_clock()
|
||||
|
||||
def stop_pusher(user_id, app_id, pushkey):
|
||||
key = "%s:%s" % (app_id, pushkey)
|
||||
@@ -292,21 +155,11 @@ class PusherServer(HomeServer):
|
||||
min_stream_id, max_stream_id, affected_room_ids
|
||||
)
|
||||
|
||||
def expire_broken_caches():
|
||||
store.who_forgot_in_room.invalidate_all()
|
||||
|
||||
next_expire_broken_caches_ms = 0
|
||||
while True:
|
||||
try:
|
||||
args = store.stream_positions()
|
||||
args["timeout"] = 30000
|
||||
result = yield http_client.get_json(replication_url, args=args)
|
||||
now_ms = clock.time_msec()
|
||||
if now_ms > next_expire_broken_caches_ms:
|
||||
expire_broken_caches()
|
||||
next_expire_broken_caches_ms = (
|
||||
now_ms + store.BROKEN_CACHE_EXPIRY_MS
|
||||
)
|
||||
yield store.process_replication(result)
|
||||
poke_pushers(result)
|
||||
except:
|
||||
@@ -323,9 +176,6 @@ def setup(config_options):
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
if not config:
|
||||
sys.exit(0)
|
||||
|
||||
config.setup_logging()
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
@@ -339,9 +189,6 @@ def setup(config_options):
|
||||
)
|
||||
|
||||
ps.setup()
|
||||
ps.start_listening()
|
||||
|
||||
change_resource_limit(ps.config.soft_file_limit)
|
||||
|
||||
def start():
|
||||
ps.replicate()
|
||||
@@ -356,22 +203,4 @@ def setup(config_options):
|
||||
if __name__ == '__main__':
|
||||
with LoggingContext("main"):
|
||||
ps = setup(sys.argv[1:])
|
||||
|
||||
if ps.config.daemonize:
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
change_resource_limit(ps.config.soft_file_limit)
|
||||
reactor.run()
|
||||
|
||||
daemon = Daemonize(
|
||||
app="synapse-pusher",
|
||||
pid=ps.config.pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
daemon.start()
|
||||
else:
|
||||
reactor.run()
|
||||
reactor.run()
|
||||
|
||||
@@ -1,468 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import synapse
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.database import DatabaseConfig
|
||||
from synapse.config.logger import LoggingConfig
|
||||
from synapse.config.appservice import AppServiceConfig
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.handlers.presence import PresenceHandler
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.rest.client.v2_alpha import sync
|
||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
|
||||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
|
||||
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
|
||||
from synapse.replication.slave.storage.presence import SlavedPresenceStore
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.storage.roommember import RoomMemberStore
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from daemonize import Daemonize
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import contextlib
|
||||
import ujson as json
|
||||
|
||||
logger = logging.getLogger("synapse.app.synchrotron")
|
||||
|
||||
|
||||
class SynchrotronConfig(DatabaseConfig, LoggingConfig, AppServiceConfig):
|
||||
def read_config(self, config):
|
||||
self.replication_url = config["replication_url"]
|
||||
self.server_name = config["server_name"]
|
||||
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
|
||||
"use_insecure_ssl_client_just_for_testing_do_not_use", False
|
||||
)
|
||||
self.user_agent_suffix = None
|
||||
self.listeners = config["listeners"]
|
||||
self.soft_file_limit = config.get("soft_file_limit")
|
||||
self.daemonize = config.get("daemonize")
|
||||
self.pid_file = self.abspath(config.get("pid_file"))
|
||||
self.macaroon_secret_key = config["macaroon_secret_key"]
|
||||
self.expire_access_token = config.get("expire_access_token", False)
|
||||
|
||||
def default_config(self, server_name, **kwargs):
|
||||
pid_file = self.abspath("synchroton.pid")
|
||||
return """\
|
||||
# Slave configuration
|
||||
|
||||
# The replication listener on the synapse to talk to.
|
||||
#replication_url: https://localhost:{replication_port}/_synapse/replication
|
||||
|
||||
server_name: "%(server_name)s"
|
||||
|
||||
listeners:
|
||||
# Enable a /sync listener on the synchrontron
|
||||
#- type: http
|
||||
# port: {http_port}
|
||||
# bind_address: ""
|
||||
# Enable a ssh manhole listener on the synchrotron
|
||||
# - type: manhole
|
||||
# port: {manhole_port}
|
||||
# bind_address: 127.0.0.1
|
||||
# Enable a metric listener on the synchrotron
|
||||
# - type: http
|
||||
# port: {metrics_port}
|
||||
# bind_address: 127.0.0.1
|
||||
# resources:
|
||||
# - names: ["metrics"]
|
||||
# compress: False
|
||||
|
||||
report_stats: False
|
||||
|
||||
daemonize: False
|
||||
|
||||
pid_file: %(pid_file)s
|
||||
""" % locals()
|
||||
|
||||
|
||||
class SynchrotronSlavedStore(
|
||||
SlavedPushRuleStore,
|
||||
SlavedEventStore,
|
||||
SlavedReceiptsStore,
|
||||
SlavedAccountDataStore,
|
||||
SlavedApplicationServiceStore,
|
||||
SlavedRegistrationStore,
|
||||
SlavedFilteringStore,
|
||||
SlavedPresenceStore,
|
||||
):
|
||||
def get_presence_list_accepted(self, user_localpart):
|
||||
return ()
|
||||
|
||||
def insert_client_ip(self, user, access_token, ip, user_agent):
|
||||
pass
|
||||
|
||||
# XXX: This is a bit broken because we don't persist forgotten rooms
|
||||
# in a way that they can be streamed. This means that we don't have a
|
||||
# way to invalidate the forgotten rooms cache correctly.
|
||||
# For now we expire the cache every 10 minutes.
|
||||
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
|
||||
who_forgot_in_room = (
|
||||
RoomMemberStore.__dict__["who_forgot_in_room"]
|
||||
)
|
||||
|
||||
|
||||
class SynchrotronPresence(object):
|
||||
def __init__(self, hs):
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
self.store = hs.get_datastore()
|
||||
self.user_to_num_current_syncs = {}
|
||||
self.syncing_users_url = hs.config.replication_url + "/syncing_users"
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
active_presence = self.store.take_presence_startup_info()
|
||||
self.user_to_current_state = {
|
||||
state.user_id: state
|
||||
for state in active_presence
|
||||
}
|
||||
|
||||
self.process_id = random_string(16)
|
||||
logger.info("Presence process_id is %r", self.process_id)
|
||||
|
||||
def set_state(self, user, state):
|
||||
# TODO Hows this supposed to work?
|
||||
pass
|
||||
|
||||
get_states = PresenceHandler.get_states.__func__
|
||||
current_state_for_users = PresenceHandler.current_state_for_users.__func__
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_syncing(self, user_id, affect_presence):
|
||||
if affect_presence:
|
||||
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
|
||||
self.user_to_num_current_syncs[user_id] = curr_sync + 1
|
||||
# TODO: Send this less frequently.
|
||||
# TODO: Make sure this doesn't race. Currently we can lose updates
|
||||
# if two users come online in quick sucession and the second http
|
||||
# to the master completes before the first.
|
||||
# TODO: Don't block the sync request on this HTTP hit.
|
||||
yield self._send_syncing_users()
|
||||
|
||||
def _end():
|
||||
if affect_presence:
|
||||
self.user_to_num_current_syncs[user_id] -= 1
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _user_syncing():
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_end()
|
||||
|
||||
defer.returnValue(_user_syncing())
|
||||
|
||||
def _send_syncing_users(self):
|
||||
return self.http_client.post_json_get_json(self.syncing_users_url, {
|
||||
"process_id": self.process_id,
|
||||
"syncing_users": [
|
||||
user_id for user_id, count in self.user_to_num_current_syncs.items()
|
||||
if count > 0
|
||||
],
|
||||
})
|
||||
|
||||
def process_replication(self, result):
|
||||
stream = result.get("presence", {"rows": []})
|
||||
for row in stream["rows"]:
|
||||
(
|
||||
position, user_id, state, last_active_ts,
|
||||
last_federation_update_ts, last_user_sync_ts, status_msg,
|
||||
currently_active
|
||||
) = row
|
||||
self.user_to_current_state[user_id] = UserPresenceState(
|
||||
user_id, state, last_active_ts,
|
||||
last_federation_update_ts, last_user_sync_ts, status_msg,
|
||||
currently_active
|
||||
)
|
||||
|
||||
|
||||
class SynchrotronTyping(object):
|
||||
def __init__(self, hs):
|
||||
self._latest_room_serial = 0
|
||||
self._room_serials = {}
|
||||
self._room_typing = {}
|
||||
|
||||
def stream_positions(self):
|
||||
return {"typing": self._latest_room_serial}
|
||||
|
||||
def process_replication(self, result):
|
||||
stream = result.get("typing")
|
||||
if stream:
|
||||
self._latest_room_serial = int(stream["position"])
|
||||
|
||||
for row in stream["rows"]:
|
||||
position, room_id, typing_json = row
|
||||
typing = json.loads(typing_json)
|
||||
self._room_serials[room_id] = position
|
||||
self._room_typing[room_id] = typing
|
||||
|
||||
|
||||
class SynchrotronApplicationService(object):
|
||||
def notify_interested_services(self, event):
|
||||
pass
|
||||
|
||||
|
||||
class SynchrotronServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
|
||||
logger.info("Finished setting up.")
|
||||
|
||||
def _listen_http(self, listener_config):
|
||||
port = listener_config["port"]
|
||||
bind_address = listener_config.get("bind_address", "")
|
||||
site_tag = listener_config.get("tag", port)
|
||||
resources = {}
|
||||
for res in listener_config["resources"]:
|
||||
for name in res["names"]:
|
||||
if name == "metrics":
|
||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||
elif name == "client":
|
||||
resource = JsonResource(self, canonical_json=False)
|
||||
sync.register_servlets(self, resource)
|
||||
resources.update({
|
||||
"/_matrix/client/r0": resource,
|
||||
"/_matrix/client/unstable": resource,
|
||||
"/_matrix/client/v2_alpha": resource,
|
||||
})
|
||||
|
||||
root_resource = create_resource_tree(resources, Resource())
|
||||
reactor.listenTCP(
|
||||
port,
|
||||
SynapseSite(
|
||||
"synapse.access.http.%s" % (site_tag,),
|
||||
site_tag,
|
||||
listener_config,
|
||||
root_resource,
|
||||
),
|
||||
interface=bind_address
|
||||
)
|
||||
logger.info("Synapse synchrotron now listening on port %d", port)
|
||||
|
||||
def start_listening(self):
|
||||
for listener in self.config.listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listen_http(listener)
|
||||
elif listener["type"] == "manhole":
|
||||
reactor.listenTCP(
|
||||
listener["port"],
|
||||
manhole(
|
||||
username="matrix",
|
||||
password="rabbithole",
|
||||
globals={"hs": self},
|
||||
),
|
||||
interface=listener.get("bind_address", '127.0.0.1')
|
||||
)
|
||||
else:
|
||||
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def replicate(self):
|
||||
http_client = self.get_simple_http_client()
|
||||
store = self.get_datastore()
|
||||
replication_url = self.config.replication_url
|
||||
clock = self.get_clock()
|
||||
notifier = self.get_notifier()
|
||||
presence_handler = self.get_presence_handler()
|
||||
typing_handler = self.get_typing_handler()
|
||||
|
||||
def expire_broken_caches():
|
||||
store.who_forgot_in_room.invalidate_all()
|
||||
|
||||
def notify_from_stream(
|
||||
result, stream_name, stream_key, room=None, user=None
|
||||
):
|
||||
stream = result.get(stream_name)
|
||||
if stream:
|
||||
position_index = stream["field_names"].index("position")
|
||||
if room:
|
||||
room_index = stream["field_names"].index(room)
|
||||
if user:
|
||||
user_index = stream["field_names"].index(user)
|
||||
|
||||
users = ()
|
||||
rooms = ()
|
||||
for row in stream["rows"]:
|
||||
position = row[position_index]
|
||||
|
||||
if user:
|
||||
users = (row[user_index],)
|
||||
|
||||
if room:
|
||||
rooms = (row[room_index],)
|
||||
|
||||
notifier.on_new_event(
|
||||
stream_key, position, users=users, rooms=rooms
|
||||
)
|
||||
|
||||
def notify(result):
|
||||
stream = result.get("events")
|
||||
if stream:
|
||||
max_position = stream["position"]
|
||||
for row in stream["rows"]:
|
||||
position = row[0]
|
||||
internal = json.loads(row[1])
|
||||
event_json = json.loads(row[2])
|
||||
event = FrozenEvent(event_json, internal_metadata_dict=internal)
|
||||
extra_users = ()
|
||||
if event.type == EventTypes.Member:
|
||||
extra_users = (event.state_key,)
|
||||
notifier.on_new_room_event(
|
||||
event, position, max_position, extra_users
|
||||
)
|
||||
|
||||
notify_from_stream(
|
||||
result, "push_rules", "push_rules_key", user="user_id"
|
||||
)
|
||||
notify_from_stream(
|
||||
result, "user_account_data", "account_data_key", user="user_id"
|
||||
)
|
||||
notify_from_stream(
|
||||
result, "room_account_data", "account_data_key", user="user_id"
|
||||
)
|
||||
notify_from_stream(
|
||||
result, "tag_account_data", "account_data_key", user="user_id"
|
||||
)
|
||||
notify_from_stream(
|
||||
result, "receipts", "receipt_key", room="room_id"
|
||||
)
|
||||
notify_from_stream(
|
||||
result, "typing", "typing_key", room="room_id"
|
||||
)
|
||||
|
||||
next_expire_broken_caches_ms = 0
|
||||
while True:
|
||||
try:
|
||||
args = store.stream_positions()
|
||||
args.update(typing_handler.stream_positions())
|
||||
args["timeout"] = 30000
|
||||
result = yield http_client.get_json(replication_url, args=args)
|
||||
logger.error("FENRIS %r", result)
|
||||
now_ms = clock.time_msec()
|
||||
if now_ms > next_expire_broken_caches_ms:
|
||||
expire_broken_caches()
|
||||
next_expire_broken_caches_ms = (
|
||||
now_ms + store.BROKEN_CACHE_EXPIRY_MS
|
||||
)
|
||||
yield store.process_replication(result)
|
||||
typing_handler.process_replication(result)
|
||||
presence_handler.process_replication(result)
|
||||
notify(result)
|
||||
except:
|
||||
logger.exception("Error replicating from %r", replication_url)
|
||||
sleep(5)
|
||||
|
||||
def build_presence_handler(self):
|
||||
return SynchrotronPresence(self)
|
||||
|
||||
def build_typing_handler(self):
|
||||
return SynchrotronTyping(self)
|
||||
|
||||
|
||||
def setup(config_options):
|
||||
try:
|
||||
config = SynchrotronConfig.load_config(
|
||||
"Synapse synchrotron", config_options
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
if not config:
|
||||
sys.exit(0)
|
||||
|
||||
config.setup_logging()
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
ss = SynchrotronServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
config=config,
|
||||
version_string=get_version_string("Synapse", synapse),
|
||||
database_engine=database_engine,
|
||||
application_service_handler=SynchrotronApplicationService(),
|
||||
)
|
||||
|
||||
ss.setup()
|
||||
ss.start_listening()
|
||||
|
||||
change_resource_limit(ss.config.soft_file_limit)
|
||||
|
||||
def start():
|
||||
ss.get_datastore().start_profiling()
|
||||
ss.replicate()
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
return ss
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with LoggingContext("main"):
|
||||
ps = setup(sys.argv[1:])
|
||||
|
||||
if ps.config.daemonize:
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
change_resource_limit(ps.config.soft_file_limit)
|
||||
reactor.run()
|
||||
|
||||
daemon = Daemonize(
|
||||
app="synapse-pusher",
|
||||
pid=ps.config.pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
daemon.start()
|
||||
else:
|
||||
reactor.run()
|
||||
@@ -66,10 +66,6 @@ def main():
|
||||
|
||||
config = yaml.load(open(configfile))
|
||||
pidfile = config["pid_file"]
|
||||
cache_factor = config.get("synctl_cache_factor", None)
|
||||
|
||||
if cache_factor:
|
||||
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
|
||||
|
||||
action = sys.argv[1] if sys.argv[1:] else "usage"
|
||||
if action == "start":
|
||||
|
||||
@@ -56,22 +56,22 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApplicationServiceScheduler(object):
|
||||
class AppServiceScheduler(object):
|
||||
""" Public facing API for this module. Does the required DI to tie the
|
||||
components together. This also serves as the "event_pool", which in this
|
||||
case is a simple array.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.as_api = hs.get_application_service_api()
|
||||
def __init__(self, clock, store, as_api):
|
||||
self.clock = clock
|
||||
self.store = store
|
||||
self.as_api = as_api
|
||||
|
||||
def create_recoverer(service, callback):
|
||||
return _Recoverer(self.clock, self.store, self.as_api, service, callback)
|
||||
return _Recoverer(clock, store, as_api, service, callback)
|
||||
|
||||
self.txn_ctrl = _TransactionController(
|
||||
self.clock, self.store, self.as_api, create_recoverer
|
||||
clock, store, as_api, create_recoverer
|
||||
)
|
||||
self.queuer = _ServiceQueuer(self.txn_ctrl)
|
||||
|
||||
|
||||
@@ -12,16 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.types import UserID
|
||||
|
||||
import urllib
|
||||
import yaml
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from ._base import Config
|
||||
|
||||
|
||||
class AppServiceConfig(Config):
|
||||
@@ -34,99 +25,3 @@ class AppServiceConfig(Config):
|
||||
# A list of application service config file to use
|
||||
app_service_config_files: []
|
||||
"""
|
||||
|
||||
|
||||
def load_appservices(hostname, config_files):
|
||||
"""Returns a list of Application Services from the config files."""
|
||||
if not isinstance(config_files, list):
|
||||
logger.warning(
|
||||
"Expected %s to be a list of AS config files.", config_files
|
||||
)
|
||||
return []
|
||||
|
||||
# Dicts of value -> filename
|
||||
seen_as_tokens = {}
|
||||
seen_ids = {}
|
||||
|
||||
appservices = []
|
||||
|
||||
for config_file in config_files:
|
||||
try:
|
||||
with open(config_file, 'r') as f:
|
||||
appservice = _load_appservice(
|
||||
hostname, yaml.load(f), config_file
|
||||
)
|
||||
if appservice.id in seen_ids:
|
||||
raise ConfigError(
|
||||
"Cannot reuse ID across application services: "
|
||||
"%s (files: %s, %s)" % (
|
||||
appservice.id, config_file, seen_ids[appservice.id],
|
||||
)
|
||||
)
|
||||
seen_ids[appservice.id] = config_file
|
||||
if appservice.token in seen_as_tokens:
|
||||
raise ConfigError(
|
||||
"Cannot reuse as_token across application services: "
|
||||
"%s (files: %s, %s)" % (
|
||||
appservice.token,
|
||||
config_file,
|
||||
seen_as_tokens[appservice.token],
|
||||
)
|
||||
)
|
||||
seen_as_tokens[appservice.token] = config_file
|
||||
logger.info("Loaded application service: %s", appservice)
|
||||
appservices.append(appservice)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load appservice from '%s'", config_file)
|
||||
logger.exception(e)
|
||||
raise
|
||||
return appservices
|
||||
|
||||
|
||||
def _load_appservice(hostname, as_info, config_filename):
|
||||
required_string_fields = [
|
||||
"id", "url", "as_token", "hs_token", "sender_localpart"
|
||||
]
|
||||
for field in required_string_fields:
|
||||
if not isinstance(as_info.get(field), basestring):
|
||||
raise KeyError("Required string field: '%s' (%s)" % (
|
||||
field, config_filename,
|
||||
))
|
||||
|
||||
localpart = as_info["sender_localpart"]
|
||||
if urllib.quote(localpart) != localpart:
|
||||
raise ValueError(
|
||||
"sender_localpart needs characters which are not URL encoded."
|
||||
)
|
||||
user = UserID(localpart, hostname)
|
||||
user_id = user.to_string()
|
||||
|
||||
# namespace checks
|
||||
if not isinstance(as_info.get("namespaces"), dict):
|
||||
raise KeyError("Requires 'namespaces' object.")
|
||||
for ns in ApplicationService.NS_LIST:
|
||||
# specific namespaces are optional
|
||||
if ns in as_info["namespaces"]:
|
||||
# expect a list of dicts with exclusive and regex keys
|
||||
for regex_obj in as_info["namespaces"][ns]:
|
||||
if not isinstance(regex_obj, dict):
|
||||
raise ValueError(
|
||||
"Expected namespace entry in %s to be an object,"
|
||||
" but got %s", ns, regex_obj
|
||||
)
|
||||
if not isinstance(regex_obj.get("regex"), basestring):
|
||||
raise ValueError(
|
||||
"Missing/bad type 'regex' key in %s", regex_obj
|
||||
)
|
||||
if not isinstance(regex_obj.get("exclusive"), bool):
|
||||
raise ValueError(
|
||||
"Missing/bad type 'exclusive' key in %s", regex_obj
|
||||
)
|
||||
return ApplicationService(
|
||||
token=as_info["as_token"],
|
||||
url=as_info["url"],
|
||||
namespaces=as_info["namespaces"],
|
||||
hs_token=as_info["hs_token"],
|
||||
sender=user_id,
|
||||
id=as_info["id"],
|
||||
)
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This file can't be called email.py because if it is, we cannot:
|
||||
import email.utils
|
||||
|
||||
from ._base import Config
|
||||
|
||||
|
||||
class EmailConfig(Config):
|
||||
def read_config(self, config):
|
||||
self.email_enable_notifs = False
|
||||
|
||||
email_config = config.get("email", {})
|
||||
self.email_enable_notifs = email_config.get("enable_notifs", False)
|
||||
|
||||
if self.email_enable_notifs:
|
||||
# make sure we can import the required deps
|
||||
import jinja2
|
||||
import bleach
|
||||
# prevent unused warnings
|
||||
jinja2
|
||||
bleach
|
||||
|
||||
required = [
|
||||
"smtp_host",
|
||||
"smtp_port",
|
||||
"notif_from",
|
||||
"template_dir",
|
||||
"notif_template_html",
|
||||
"notif_template_text",
|
||||
]
|
||||
|
||||
missing = []
|
||||
for k in required:
|
||||
if k not in email_config:
|
||||
missing.append(k)
|
||||
|
||||
if (len(missing) > 0):
|
||||
raise RuntimeError(
|
||||
"email.enable_notifs is True but required keys are missing: %s" %
|
||||
(", ".join(["email." + k for k in missing]),)
|
||||
)
|
||||
|
||||
if config.get("public_baseurl") is None:
|
||||
raise RuntimeError(
|
||||
"email.enable_notifs is True but no public_baseurl is set"
|
||||
)
|
||||
|
||||
self.email_smtp_host = email_config["smtp_host"]
|
||||
self.email_smtp_port = email_config["smtp_port"]
|
||||
self.email_notif_from = email_config["notif_from"]
|
||||
self.email_template_dir = email_config["template_dir"]
|
||||
self.email_notif_template_html = email_config["notif_template_html"]
|
||||
self.email_notif_template_text = email_config["notif_template_text"]
|
||||
self.email_notif_for_new_users = email_config.get(
|
||||
"notif_for_new_users", True
|
||||
)
|
||||
if "app_name" in email_config:
|
||||
self.email_app_name = email_config["app_name"]
|
||||
else:
|
||||
self.email_app_name = "Matrix"
|
||||
|
||||
# make sure it's valid
|
||||
parsed = email.utils.parseaddr(self.email_notif_from)
|
||||
if parsed[1] == '':
|
||||
raise RuntimeError("Invalid notif_from address")
|
||||
else:
|
||||
self.email_enable_notifs = False
|
||||
# Not much point setting defaults for the rest: it would be an
|
||||
# error for them to be used.
|
||||
|
||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||
return """
|
||||
# Enable sending emails for notification events
|
||||
#email:
|
||||
# enable_notifs: false
|
||||
# smtp_host: "localhost"
|
||||
# smtp_port: 25
|
||||
# notif_from: Your Friendly Matrix Home Server <noreply@example.com>
|
||||
# app_name: Matrix
|
||||
# template_dir: res/templates
|
||||
# notif_template_html: notif_mail.html
|
||||
# notif_template_text: notif_mail.txt
|
||||
# notif_for_new_users: True
|
||||
"""
|
||||
@@ -31,14 +31,13 @@ from .cas import CasConfig
|
||||
from .password import PasswordConfig
|
||||
from .jwt import JWTConfig
|
||||
from .ldap import LDAPConfig
|
||||
from .emailconfig import EmailConfig
|
||||
|
||||
|
||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
|
||||
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
|
||||
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
|
||||
JWTConfig, LDAPConfig, PasswordConfig, EmailConfig,):
|
||||
JWTConfig, LDAPConfig, PasswordConfig,):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -13,16 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
|
||||
MISSING_JWT = (
|
||||
"""Missing jwt library. This is required for jwt login.
|
||||
|
||||
Install by running:
|
||||
pip install pyjwt
|
||||
"""
|
||||
)
|
||||
from ._base import Config
|
||||
|
||||
|
||||
class JWTConfig(Config):
|
||||
@@ -32,12 +23,6 @@ class JWTConfig(Config):
|
||||
self.jwt_enabled = jwt_config.get("enabled", False)
|
||||
self.jwt_secret = jwt_config["secret"]
|
||||
self.jwt_algorithm = jwt_config["algorithm"]
|
||||
|
||||
try:
|
||||
import jwt
|
||||
jwt # To stop unused lint.
|
||||
except ImportError:
|
||||
raise ConfigError(MISSING_JWT)
|
||||
else:
|
||||
self.jwt_enabled = False
|
||||
self.jwt_secret = None
|
||||
@@ -45,8 +30,6 @@ class JWTConfig(Config):
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
return """\
|
||||
# The JWT needs to contain a globally unique "sub" (subject) claim.
|
||||
#
|
||||
# jwt_config:
|
||||
# enabled: true
|
||||
# secret: "a secret"
|
||||
|
||||
@@ -57,8 +57,6 @@ class KeyConfig(Config):
|
||||
seed = self.signing_key[0].seed
|
||||
self.macaroon_secret_key = hashlib.sha256(seed)
|
||||
|
||||
self.expire_access_token = config.get("expire_access_token", False)
|
||||
|
||||
def default_config(self, config_dir_path, server_name, is_generating_file=False,
|
||||
**kwargs):
|
||||
base_key_name = os.path.join(config_dir_path, server_name)
|
||||
@@ -71,9 +69,6 @@ class KeyConfig(Config):
|
||||
return """\
|
||||
macaroon_secret_key: "%(macaroon_secret_key)s"
|
||||
|
||||
# Used to enable access token expiration.
|
||||
expire_access_token: False
|
||||
|
||||
## Signing Keys ##
|
||||
|
||||
# Path to the signing key to sign messages with
|
||||
|
||||
@@ -32,7 +32,6 @@ class RegistrationConfig(Config):
|
||||
)
|
||||
|
||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||
self.user_creation_max_duration = int(config["user_creation_max_duration"])
|
||||
|
||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
|
||||
@@ -55,11 +54,6 @@ class RegistrationConfig(Config):
|
||||
# secret, even if registration is otherwise disabled.
|
||||
registration_shared_secret: "%(registration_shared_secret)s"
|
||||
|
||||
# Sets the expiry for the short term user creation in
|
||||
# milliseconds. For instance the bellow duration is two weeks
|
||||
# in milliseconds.
|
||||
user_creation_max_duration: 1209600000
|
||||
|
||||
# Set the number of bcrypt rounds used to generate password hash.
|
||||
# Larger numbers increase the work factor needed to generate the hash.
|
||||
# The default number of rounds is 12.
|
||||
|
||||
@@ -100,13 +100,8 @@ class ContentRepositoryConfig(Config):
|
||||
"to work"
|
||||
)
|
||||
|
||||
self.url_preview_ip_range_whitelist = IPSet(
|
||||
config.get("url_preview_ip_range_whitelist", ())
|
||||
)
|
||||
|
||||
self.url_preview_url_blacklist = config.get(
|
||||
"url_preview_url_blacklist", ()
|
||||
)
|
||||
if "url_preview_url_blacklist" in config:
|
||||
self.url_preview_url_blacklist = config["url_preview_url_blacklist"]
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
media_store = self.default_path("media_store")
|
||||
@@ -167,15 +162,6 @@ class ContentRepositoryConfig(Config):
|
||||
# - '10.0.0.0/8'
|
||||
# - '172.16.0.0/12'
|
||||
# - '192.168.0.0/16'
|
||||
#
|
||||
# List of IP address CIDR ranges that the URL preview spider is allowed
|
||||
# to access even if they are specified in url_preview_ip_range_blacklist.
|
||||
# This is useful for specifying exceptions to wide-ranging blacklisted
|
||||
# target IP ranges - e.g. for enabling URL previews for a specific private
|
||||
# website only visible in your network.
|
||||
#
|
||||
# url_preview_ip_range_whitelist:
|
||||
# - '192.168.1.1'
|
||||
|
||||
# Optional list of URL matches that the URL preview spider is
|
||||
# denied from accessing. You should use url_preview_ip_range_blacklist
|
||||
|
||||
@@ -28,12 +28,6 @@ class ServerConfig(Config):
|
||||
self.print_pidfile = config.get("print_pidfile")
|
||||
self.user_agent_suffix = config.get("user_agent_suffix")
|
||||
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
|
||||
self.public_baseurl = config.get("public_baseurl")
|
||||
self.secondary_directory_servers = config.get("secondary_directory_servers", [])
|
||||
|
||||
if self.public_baseurl is not None:
|
||||
if self.public_baseurl[-1] != '/':
|
||||
self.public_baseurl += '/'
|
||||
self.start_pushers = config.get("start_pushers", True)
|
||||
|
||||
self.listeners = config.get("listeners", [])
|
||||
@@ -149,23 +143,11 @@ class ServerConfig(Config):
|
||||
# Whether to serve a web client from the HTTP/HTTPS root resource.
|
||||
web_client: True
|
||||
|
||||
# The public-facing base URL for the client API (not including _matrix/...)
|
||||
# public_baseurl: https://example.com:8448/
|
||||
|
||||
# Set the soft limit on the number of file descriptors synapse can use
|
||||
# Zero is used to indicate synapse should set the soft limit to the
|
||||
# hard limit.
|
||||
soft_file_limit: 0
|
||||
|
||||
# A list of other Home Servers to fetch the public room directory from
|
||||
# and include in the public room directory of this home server
|
||||
# This is a temporary stopgap solution to populate new server with a
|
||||
# list of rooms until there exists a good solution of a decentralized
|
||||
# room directory.
|
||||
# secondary_directory_servers:
|
||||
# - matrix.org
|
||||
# - vector.im
|
||||
|
||||
# List of ports that Synapse should listen on, their purpose and their
|
||||
# configuration.
|
||||
listeners:
|
||||
|
||||
@@ -24,7 +24,6 @@ from synapse.api.errors import (
|
||||
CodeMessageException, HttpResponseException, SynapseError,
|
||||
)
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.events import FrozenEvent
|
||||
@@ -551,25 +550,6 @@ class FederationClient(FederationBase):
|
||||
|
||||
raise RuntimeError("Failed to send to any server.")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_public_rooms(self, destinations):
|
||||
results_by_server = {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_result(s):
|
||||
if s == self.server_name:
|
||||
defer.returnValue()
|
||||
|
||||
try:
|
||||
result = yield self.transport_layer.get_public_rooms(s)
|
||||
results_by_server[s] = result
|
||||
except:
|
||||
logger.exception("Error getting room list from server %r", s)
|
||||
|
||||
yield concurrently_execute(_get_result, destinations, 3)
|
||||
|
||||
defer.returnValue(results_by_server)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_auth(self, destination, room_id, event_id, local_auth):
|
||||
"""
|
||||
|
||||
@@ -387,11 +387,6 @@ class FederationServer(FederationBase):
|
||||
"events": [ev.get_pdu_json(time_now) for ev in missing_events],
|
||||
})
|
||||
|
||||
@log_function
|
||||
def on_openid_userinfo(self, token):
|
||||
ts_now_ms = self._clock.time_msec()
|
||||
return self.store.get_user_id_for_open_id_token(token, ts_now_ms)
|
||||
|
||||
@log_function
|
||||
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
|
||||
""" Get a PDU from the database with given origin and id.
|
||||
|
||||
@@ -20,7 +20,6 @@ from .persistence import TransactionActions
|
||||
from .units import Transaction
|
||||
|
||||
from synapse.api.errors import HttpResponseException
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.util.retryutils import (
|
||||
@@ -200,8 +199,6 @@ class TransactionQueue(object):
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _attempt_new_transaction(self, destination):
|
||||
yield run_on_reactor()
|
||||
|
||||
# list of (pending_pdu, deferred, order)
|
||||
if destination in self.pending_transactions:
|
||||
# XXX: pending_transactions can get stuck on by a never-ending
|
||||
|
||||
@@ -224,18 +224,6 @@ class TransportLayerClient(object):
|
||||
|
||||
defer.returnValue(response)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_public_rooms(self, remote_server):
|
||||
path = PREFIX + "/publicRooms"
|
||||
|
||||
response = yield self.client.get_json(
|
||||
destination=remote_server,
|
||||
path=path,
|
||||
)
|
||||
|
||||
defer.returnValue(response)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def exchange_third_party_invite(self, destination, room_id, event_dict):
|
||||
|
||||
@@ -18,7 +18,7 @@ from twisted.internet import defer
|
||||
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.http.servlet import parse_json_object_from_request, parse_string
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
|
||||
import functools
|
||||
@@ -134,12 +134,10 @@ class Authenticator(object):
|
||||
|
||||
|
||||
class BaseFederationServlet(object):
|
||||
def __init__(self, handler, authenticator, ratelimiter, server_name,
|
||||
room_list_handler):
|
||||
def __init__(self, handler, authenticator, ratelimiter, server_name):
|
||||
self.handler = handler
|
||||
self.authenticator = authenticator
|
||||
self.ratelimiter = ratelimiter
|
||||
self.room_list_handler = room_list_handler
|
||||
|
||||
def _wrap(self, code):
|
||||
authenticator = self.authenticator
|
||||
@@ -325,7 +323,7 @@ class FederationSendLeaveServlet(BaseFederationServlet):
|
||||
|
||||
|
||||
class FederationEventAuthServlet(BaseFederationServlet):
|
||||
PATH = "/event_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
|
||||
PATH = "/event_auth(?P<context>[^/]*)/(?P<event_id>[^/]*)"
|
||||
|
||||
def on_GET(self, origin, content, query, context, event_id):
|
||||
return self.handler.on_event_auth(origin, context, event_id)
|
||||
@@ -450,94 +448,6 @@ class On3pidBindServlet(BaseFederationServlet):
|
||||
return code
|
||||
|
||||
|
||||
class OpenIdUserInfo(BaseFederationServlet):
|
||||
"""
|
||||
Exchange a bearer token for information about a user.
|
||||
|
||||
The response format should be compatible with:
|
||||
http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
|
||||
|
||||
GET /openid/userinfo?access_token=ABDEFGH HTTP/1.1
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"sub": "@userpart:example.org",
|
||||
}
|
||||
"""
|
||||
|
||||
PATH = "/openid/userinfo"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
token = parse_string(request, "access_token")
|
||||
if token is None:
|
||||
defer.returnValue((401, {
|
||||
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
|
||||
}))
|
||||
return
|
||||
|
||||
user_id = yield self.handler.on_openid_userinfo(token)
|
||||
|
||||
if user_id is None:
|
||||
defer.returnValue((401, {
|
||||
"errcode": "M_UNKNOWN_TOKEN",
|
||||
"error": "Access Token unknown or expired"
|
||||
}))
|
||||
|
||||
defer.returnValue((200, {"sub": user_id}))
|
||||
|
||||
# Avoid doing remote HS authorization checks which are done by default by
|
||||
# BaseFederationServlet.
|
||||
def _wrap(self, code):
|
||||
return code
|
||||
|
||||
|
||||
class PublicRoomList(BaseFederationServlet):
|
||||
"""
|
||||
Fetch the public room list for this server.
|
||||
|
||||
This API returns information in the same format as /publicRooms on the
|
||||
client API, but will only ever include local public rooms and hence is
|
||||
intended for consumption by other home servers.
|
||||
|
||||
GET /publicRooms HTTP/1.1
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"chunk": [
|
||||
{
|
||||
"aliases": [
|
||||
"#test:localhost"
|
||||
],
|
||||
"guest_can_join": false,
|
||||
"name": "test room",
|
||||
"num_joined_members": 3,
|
||||
"room_id": "!whkydVegtvatLfXmPN:localhost",
|
||||
"world_readable": false
|
||||
}
|
||||
],
|
||||
"end": "END",
|
||||
"start": "START"
|
||||
}
|
||||
"""
|
||||
|
||||
PATH = "/publicRooms"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
data = yield self.room_list_handler.get_local_public_room_list()
|
||||
defer.returnValue((200, data))
|
||||
|
||||
# Avoid doing remote HS authorization checks which are done by default by
|
||||
# BaseFederationServlet.
|
||||
def _wrap(self, code):
|
||||
return code
|
||||
|
||||
|
||||
SERVLET_CLASSES = (
|
||||
FederationSendServlet,
|
||||
FederationPullServlet,
|
||||
@@ -558,8 +468,6 @@ SERVLET_CLASSES = (
|
||||
FederationClientKeysClaimServlet,
|
||||
FederationThirdPartyInviteExchangeServlet,
|
||||
On3pidBindServlet,
|
||||
OpenIdUserInfo,
|
||||
PublicRoomList,
|
||||
)
|
||||
|
||||
|
||||
@@ -570,5 +478,4 @@ def register_servlets(hs, resource, authenticator, ratelimiter):
|
||||
authenticator=authenticator,
|
||||
ratelimiter=ratelimiter,
|
||||
server_name=hs.hostname,
|
||||
room_list_handler=hs.get_room_list_handler(),
|
||||
).register(resource)
|
||||
|
||||
@@ -13,17 +13,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.appservice.scheduler import AppServiceScheduler
|
||||
from synapse.appservice.api import ApplicationServiceApi
|
||||
from .register import RegistrationHandler
|
||||
from .room import (
|
||||
RoomCreationHandler, RoomContextHandler,
|
||||
RoomCreationHandler, RoomListHandler, RoomContextHandler,
|
||||
)
|
||||
from .room_member import RoomMemberHandler
|
||||
from .message import MessageHandler
|
||||
from .events import EventStreamHandler, EventHandler
|
||||
from .federation import FederationHandler
|
||||
from .profile import ProfileHandler
|
||||
from .presence import PresenceHandler
|
||||
from .directory import DirectoryHandler
|
||||
from .typing import TypingNotificationHandler
|
||||
from .admin import AdminHandler
|
||||
from .appservice import ApplicationServicesHandler
|
||||
from .sync import SyncHandler
|
||||
from .auth import AuthHandler
|
||||
from .identity import IdentityHandler
|
||||
from .receipts import ReceiptsHandler
|
||||
from .search import SearchHandler
|
||||
@@ -46,9 +53,22 @@ class Handlers(object):
|
||||
self.event_handler = EventHandler(hs)
|
||||
self.federation_handler = FederationHandler(hs)
|
||||
self.profile_handler = ProfileHandler(hs)
|
||||
self.presence_handler = PresenceHandler(hs)
|
||||
self.room_list_handler = RoomListHandler(hs)
|
||||
self.directory_handler = DirectoryHandler(hs)
|
||||
self.typing_notification_handler = TypingNotificationHandler(hs)
|
||||
self.admin_handler = AdminHandler(hs)
|
||||
self.receipts_handler = ReceiptsHandler(hs)
|
||||
asapi = ApplicationServiceApi(hs)
|
||||
self.appservice_handler = ApplicationServicesHandler(
|
||||
hs, asapi, AppServiceScheduler(
|
||||
clock=hs.get_clock(),
|
||||
store=hs.get_datastore(),
|
||||
as_api=asapi
|
||||
)
|
||||
)
|
||||
self.sync_handler = SyncHandler(hs)
|
||||
self.auth_handler = AuthHandler(hs)
|
||||
self.identity_handler = IdentityHandler(hs)
|
||||
self.search_handler = SearchHandler(hs)
|
||||
self.room_context_handler = RoomContextHandler(hs)
|
||||
|
||||
@@ -15,10 +15,13 @@
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import LimitExceededError
|
||||
from synapse.api.errors import LimitExceededError, SynapseError, AuthError
|
||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
from synapse.api.constants import Membership, EventTypes
|
||||
from synapse.types import UserID, Requester
|
||||
from synapse.types import UserID, RoomAlias, Requester
|
||||
from synapse.push.action_generator import ActionGenerator
|
||||
|
||||
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
|
||||
|
||||
import logging
|
||||
|
||||
@@ -26,6 +29,23 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
VISIBILITY_PRIORITY = (
|
||||
"world_readable",
|
||||
"shared",
|
||||
"invited",
|
||||
"joined",
|
||||
)
|
||||
|
||||
|
||||
MEMBERSHIP_PRIORITY = (
|
||||
Membership.JOIN,
|
||||
Membership.INVITE,
|
||||
Membership.KNOCK,
|
||||
Membership.LEAVE,
|
||||
Membership.BAN,
|
||||
)
|
||||
|
||||
|
||||
class BaseHandler(object):
|
||||
"""
|
||||
Common base class for the event handlers.
|
||||
@@ -45,10 +65,161 @@ class BaseHandler(object):
|
||||
self.clock = hs.get_clock()
|
||||
self.hs = hs
|
||||
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
self.server_name = hs.hostname
|
||||
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def filter_events_for_clients(self, user_tuples, events, event_id_to_state):
|
||||
""" Returns dict of user_id -> list of events that user is allowed to
|
||||
see.
|
||||
|
||||
Args:
|
||||
user_tuples (str, bool): (user id, is_peeking) for each user to be
|
||||
checked. is_peeking should be true if:
|
||||
* the user is not currently a member of the room, and:
|
||||
* the user has not been a member of the room since the
|
||||
given events
|
||||
events ([synapse.events.EventBase]): list of events to filter
|
||||
"""
|
||||
forgotten = yield defer.gatherResults([
|
||||
self.store.who_forgot_in_room(
|
||||
room_id,
|
||||
)
|
||||
for room_id in frozenset(e.room_id for e in events)
|
||||
], consumeErrors=True)
|
||||
|
||||
# Set of membership event_ids that have been forgotten
|
||||
event_id_forgotten = frozenset(
|
||||
row["event_id"] for rows in forgotten for row in rows
|
||||
)
|
||||
|
||||
def allowed(event, user_id, is_peeking):
|
||||
"""
|
||||
Args:
|
||||
event (synapse.events.EventBase): event to check
|
||||
user_id (str)
|
||||
is_peeking (bool)
|
||||
"""
|
||||
state = event_id_to_state[event.event_id]
|
||||
|
||||
# get the room_visibility at the time of the event.
|
||||
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
|
||||
if visibility_event:
|
||||
visibility = visibility_event.content.get("history_visibility", "shared")
|
||||
else:
|
||||
visibility = "shared"
|
||||
|
||||
if visibility not in VISIBILITY_PRIORITY:
|
||||
visibility = "shared"
|
||||
|
||||
# if it was world_readable, it's easy: everyone can read it
|
||||
if visibility == "world_readable":
|
||||
return True
|
||||
|
||||
# Always allow history visibility events on boundaries. This is done
|
||||
# by setting the effective visibility to the least restrictive
|
||||
# of the old vs new.
|
||||
if event.type == EventTypes.RoomHistoryVisibility:
|
||||
prev_content = event.unsigned.get("prev_content", {})
|
||||
prev_visibility = prev_content.get("history_visibility", None)
|
||||
|
||||
if prev_visibility not in VISIBILITY_PRIORITY:
|
||||
prev_visibility = "shared"
|
||||
|
||||
new_priority = VISIBILITY_PRIORITY.index(visibility)
|
||||
old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
|
||||
if old_priority < new_priority:
|
||||
visibility = prev_visibility
|
||||
|
||||
# likewise, if the event is the user's own membership event, use
|
||||
# the 'most joined' membership
|
||||
membership = None
|
||||
if event.type == EventTypes.Member and event.state_key == user_id:
|
||||
membership = event.content.get("membership", None)
|
||||
if membership not in MEMBERSHIP_PRIORITY:
|
||||
membership = "leave"
|
||||
|
||||
prev_content = event.unsigned.get("prev_content", {})
|
||||
prev_membership = prev_content.get("membership", None)
|
||||
if prev_membership not in MEMBERSHIP_PRIORITY:
|
||||
prev_membership = "leave"
|
||||
|
||||
new_priority = MEMBERSHIP_PRIORITY.index(membership)
|
||||
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
|
||||
if old_priority < new_priority:
|
||||
membership = prev_membership
|
||||
|
||||
# otherwise, get the user's membership at the time of the event.
|
||||
if membership is None:
|
||||
membership_event = state.get((EventTypes.Member, user_id), None)
|
||||
if membership_event:
|
||||
if membership_event.event_id not in event_id_forgotten:
|
||||
membership = membership_event.membership
|
||||
|
||||
# if the user was a member of the room at the time of the event,
|
||||
# they can see it.
|
||||
if membership == Membership.JOIN:
|
||||
return True
|
||||
|
||||
if visibility == "joined":
|
||||
# we weren't a member at the time of the event, so we can't
|
||||
# see this event.
|
||||
return False
|
||||
|
||||
elif visibility == "invited":
|
||||
# user can also see the event if they were *invited* at the time
|
||||
# of the event.
|
||||
return membership == Membership.INVITE
|
||||
|
||||
else:
|
||||
# visibility is shared: user can also see the event if they have
|
||||
# become a member since the event
|
||||
#
|
||||
# XXX: if the user has subsequently joined and then left again,
|
||||
# ideally we would share history up to the point they left. But
|
||||
# we don't know when they left.
|
||||
return not is_peeking
|
||||
|
||||
defer.returnValue({
|
||||
user_id: [
|
||||
event
|
||||
for event in events
|
||||
if allowed(event, user_id, is_peeking)
|
||||
]
|
||||
for user_id, is_peeking in user_tuples
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _filter_events_for_client(self, user_id, events, is_peeking=False):
|
||||
"""
|
||||
Check which events a user is allowed to see
|
||||
|
||||
Args:
|
||||
user_id(str): user id to be checked
|
||||
events([synapse.events.EventBase]): list of events to be checked
|
||||
is_peeking(bool): should be True if:
|
||||
* the user is not currently a member of the room, and:
|
||||
* the user has not been a member of the room since the given
|
||||
events
|
||||
|
||||
Returns:
|
||||
[synapse.events.EventBase]
|
||||
"""
|
||||
types = (
|
||||
(EventTypes.RoomHistoryVisibility, ""),
|
||||
(EventTypes.Member, user_id),
|
||||
)
|
||||
event_id_to_state = yield self.store.get_state_for_events(
|
||||
frozenset(e.event_id for e in events),
|
||||
types=types
|
||||
)
|
||||
res = yield self.filter_events_for_clients(
|
||||
[(user_id, is_peeking)], events, event_id_to_state
|
||||
)
|
||||
defer.returnValue(res.get(user_id, []))
|
||||
|
||||
def ratelimit(self, requester):
|
||||
time_now = self.clock.time()
|
||||
allowed, time_allowed = self.ratelimiter.send_message(
|
||||
@@ -61,6 +232,56 @@ class BaseHandler(object):
|
||||
retry_after_ms=int(1000 * (time_allowed - time_now)),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_new_client_event(self, builder, prev_event_ids=None):
|
||||
if prev_event_ids:
|
||||
prev_events = yield self.store.add_event_hashes(prev_event_ids)
|
||||
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
|
||||
depth = prev_max_depth + 1
|
||||
else:
|
||||
latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
|
||||
builder.room_id,
|
||||
)
|
||||
|
||||
if latest_ret:
|
||||
depth = max([d for _, _, d in latest_ret]) + 1
|
||||
else:
|
||||
depth = 1
|
||||
|
||||
prev_events = [
|
||||
(event_id, prev_hashes)
|
||||
for event_id, prev_hashes, _ in latest_ret
|
||||
]
|
||||
|
||||
builder.prev_events = prev_events
|
||||
builder.depth = depth
|
||||
|
||||
state_handler = self.state_handler
|
||||
|
||||
context = yield state_handler.compute_event_context(builder)
|
||||
|
||||
if builder.is_state():
|
||||
builder.prev_state = yield self.store.add_event_hashes(
|
||||
context.prev_state_events
|
||||
)
|
||||
|
||||
yield self.auth.add_auth_events(builder, context)
|
||||
|
||||
add_hashes_and_signatures(
|
||||
builder, self.server_name, self.signing_key
|
||||
)
|
||||
|
||||
event = builder.build()
|
||||
|
||||
logger.debug(
|
||||
"Created event %s with current state: %s",
|
||||
event.event_id, context.current_state,
|
||||
)
|
||||
|
||||
defer.returnValue(
|
||||
(event, context,)
|
||||
)
|
||||
|
||||
def is_host_in_room(self, current_state):
|
||||
room_members = [
|
||||
(state_key, event.membership)
|
||||
@@ -75,12 +296,153 @@ class BaseHandler(object):
|
||||
return True
|
||||
for (state_key, membership) in room_members:
|
||||
if (
|
||||
self.hs.is_mine_id(state_key)
|
||||
UserID.from_string(state_key).domain == self.hs.hostname
|
||||
and membership == Membership.JOIN
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_new_client_event(
|
||||
self,
|
||||
requester,
|
||||
event,
|
||||
context,
|
||||
ratelimit=True,
|
||||
extra_users=[]
|
||||
):
|
||||
# We now need to go and hit out to wherever we need to hit out to.
|
||||
|
||||
if ratelimit:
|
||||
self.ratelimit(requester)
|
||||
|
||||
try:
|
||||
self.auth.check(event, auth_events=context.current_state)
|
||||
except AuthError as err:
|
||||
logger.warn("Denying new event %r because %s", event, err)
|
||||
raise err
|
||||
|
||||
yield self.maybe_kick_guest_users(event, context.current_state.values())
|
||||
|
||||
if event.type == EventTypes.CanonicalAlias:
|
||||
# Check the alias is acually valid (at this time at least)
|
||||
room_alias_str = event.content.get("alias", None)
|
||||
if room_alias_str:
|
||||
room_alias = RoomAlias.from_string(room_alias_str)
|
||||
directory_handler = self.hs.get_handlers().directory_handler
|
||||
mapping = yield directory_handler.get_association(room_alias)
|
||||
|
||||
if mapping["room_id"] != event.room_id:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Room alias %s does not point to the room" % (
|
||||
room_alias_str,
|
||||
)
|
||||
)
|
||||
|
||||
federation_handler = self.hs.get_handlers().federation_handler
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
if event.content["membership"] == Membership.INVITE:
|
||||
def is_inviter_member_event(e):
|
||||
return (
|
||||
e.type == EventTypes.Member and
|
||||
e.sender == event.sender
|
||||
)
|
||||
|
||||
event.unsigned["invite_room_state"] = [
|
||||
{
|
||||
"type": e.type,
|
||||
"state_key": e.state_key,
|
||||
"content": e.content,
|
||||
"sender": e.sender,
|
||||
}
|
||||
for k, e in context.current_state.items()
|
||||
if e.type in self.hs.config.room_invite_state_types
|
||||
or is_inviter_member_event(e)
|
||||
]
|
||||
|
||||
invitee = UserID.from_string(event.state_key)
|
||||
if not self.hs.is_mine(invitee):
|
||||
# TODO: Can we add signature from remote server in a nicer
|
||||
# way? If we have been invited by a remote server, we need
|
||||
# to get them to sign the event.
|
||||
|
||||
returned_invite = yield federation_handler.send_invite(
|
||||
invitee.domain,
|
||||
event,
|
||||
)
|
||||
|
||||
event.unsigned.pop("room_state", None)
|
||||
|
||||
# TODO: Make sure the signatures actually are correct.
|
||||
event.signatures.update(
|
||||
returned_invite.signatures
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Redaction:
|
||||
if self.auth.check_redaction(event, auth_events=context.current_state):
|
||||
original_event = yield self.store.get_event(
|
||||
event.redacts,
|
||||
check_redacted=False,
|
||||
get_prev_content=False,
|
||||
allow_rejected=False,
|
||||
allow_none=False
|
||||
)
|
||||
if event.user_id != original_event.user_id:
|
||||
raise AuthError(
|
||||
403,
|
||||
"You don't have permission to redact events"
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Create and context.current_state:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Changing the room create event is forbidden",
|
||||
)
|
||||
|
||||
action_generator = ActionGenerator(self.hs)
|
||||
yield action_generator.handle_push_actions_for_event(
|
||||
event, context, self
|
||||
)
|
||||
|
||||
(event_stream_id, max_stream_id) = yield self.store.persist_event(
|
||||
event, context=context
|
||||
)
|
||||
|
||||
# this intentionally does not yield: we don't care about the result
|
||||
# and don't need to wait for it.
|
||||
preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
|
||||
event_stream_id, max_stream_id
|
||||
)
|
||||
|
||||
destinations = set()
|
||||
for k, s in context.current_state.items():
|
||||
try:
|
||||
if k[0] == EventTypes.Member:
|
||||
if s.content["membership"] == Membership.JOIN:
|
||||
destinations.add(
|
||||
UserID.from_string(s.state_key).domain
|
||||
)
|
||||
except SynapseError:
|
||||
logger.warn(
|
||||
"Failed to get destination from event %s", s.event_id
|
||||
)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
# Don't block waiting on waking up all the listeners.
|
||||
self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id,
|
||||
extra_users=extra_users
|
||||
)
|
||||
|
||||
# If invite, remove room_state from unsigned before sending.
|
||||
event.unsigned.pop("invite_room_state", None)
|
||||
|
||||
federation_handler.handle_new_event(
|
||||
event, destinations=destinations,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def maybe_kick_guest_users(self, event, current_state):
|
||||
# Technically this function invalidates current_state by changing it.
|
||||
|
||||
@@ -17,6 +17,7 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.types import UserID
|
||||
|
||||
import logging
|
||||
|
||||
@@ -34,13 +35,16 @@ def log_failure(failure):
|
||||
)
|
||||
|
||||
|
||||
# NB: Purposefully not inheriting BaseHandler since that contains way too much
|
||||
# setup code which this handler does not need or use. This makes testing a lot
|
||||
# easier.
|
||||
class ApplicationServicesHandler(object):
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs, appservice_api, appservice_scheduler):
|
||||
self.store = hs.get_datastore()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.appservice_api = hs.get_application_service_api()
|
||||
self.scheduler = hs.get_application_service_scheduler()
|
||||
self.hs = hs
|
||||
self.appservice_api = appservice_api
|
||||
self.scheduler = appservice_scheduler
|
||||
self.started_scheduler = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -165,7 +169,8 @@ class ApplicationServicesHandler(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _is_unknown_user(self, user_id):
|
||||
if not self.is_mine_id(user_id):
|
||||
user = UserID.from_string(user_id)
|
||||
if not self.hs.is_mine(user):
|
||||
# we don't know if they are unknown or not since it isn't one of our
|
||||
# users. We can't poke ASes.
|
||||
defer.returnValue(False)
|
||||
|
||||
@@ -18,7 +18,7 @@ from twisted.internet import defer
|
||||
from ._base import BaseHandler
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.types import UserID
|
||||
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
|
||||
from synapse.api.errors import AuthError, LoginError, Codes
|
||||
from synapse.util.async import run_on_reactor
|
||||
|
||||
from twisted.web.client import PartialDownloadError
|
||||
@@ -521,19 +521,14 @@ class AuthHandler(BaseHandler):
|
||||
))
|
||||
return m.serialize()
|
||||
|
||||
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
||||
def generate_short_term_login_token(self, user_id):
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = login")
|
||||
now = self.hs.get_clock().time_msec()
|
||||
expiry = now + duration_in_ms
|
||||
expiry = now + (2 * 60 * 1000)
|
||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
return macaroon.serialize()
|
||||
|
||||
def generate_delete_pusher_token(self, user_id):
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = delete_pusher")
|
||||
return macaroon.serialize()
|
||||
|
||||
def validate_short_term_login_token_and_get_user_id(self, login_token):
|
||||
try:
|
||||
macaroon = pymacaroons.Macaroon.deserialize(login_token)
|
||||
@@ -568,12 +563,7 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
except_access_token_ids = [requester.access_token_id] if requester else []
|
||||
|
||||
try:
|
||||
yield self.store.user_set_password_hash(user_id, password_hash)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
|
||||
raise e
|
||||
yield self.store.user_set_password_hash(user_id, password_hash)
|
||||
yield self.store.user_delete_access_tokens(
|
||||
user_id, except_access_token_ids
|
||||
)
|
||||
@@ -625,7 +615,4 @@ class AuthHandler(BaseHandler):
|
||||
Returns:
|
||||
Whether self.hash(password) == stored_hash (bool).
|
||||
"""
|
||||
if stored_hash:
|
||||
return bcrypt.hashpw(password, stored_hash) == stored_hash
|
||||
else:
|
||||
return False
|
||||
return bcrypt.hashpw(password, stored_hash) == stored_hash
|
||||
|
||||
@@ -33,7 +33,6 @@ class DirectoryHandler(BaseHandler):
|
||||
super(DirectoryHandler, self).__init__(hs)
|
||||
|
||||
self.state = hs.get_state_handler()
|
||||
self.appservice_handler = hs.get_application_service_handler()
|
||||
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.federation.register_query_handler(
|
||||
@@ -282,7 +281,7 @@ class DirectoryHandler(BaseHandler):
|
||||
)
|
||||
if not result:
|
||||
# Query AS to see if it exists
|
||||
as_handler = self.appservice_handler
|
||||
as_handler = self.hs.get_handlers().appservice_handler
|
||||
result = yield as_handler.query_room_alias_exists(room_alias)
|
||||
defer.returnValue(result)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ class EventStreamHandler(BaseHandler):
|
||||
If `only_keys` is not None, events from keys will be sent down.
|
||||
"""
|
||||
auth_user = UserID.from_string(auth_user_id)
|
||||
presence_handler = self.hs.get_presence_handler()
|
||||
presence_handler = self.hs.get_handlers().presence_handler
|
||||
|
||||
context = yield presence_handler.user_syncing(
|
||||
auth_user_id, affect_presence=affect_presence,
|
||||
|
||||
@@ -33,7 +33,7 @@ from synapse.util.frozenutils import unfreeze
|
||||
from synapse.crypto.event_signing import (
|
||||
compute_event_signature, add_hashes_and_signatures,
|
||||
)
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from synapse.types import UserID
|
||||
|
||||
from synapse.events.utils import prune_event
|
||||
|
||||
@@ -453,7 +453,7 @@ class FederationHandler(BaseHandler):
|
||||
joined_domains = {}
|
||||
for u, d in joined_users:
|
||||
try:
|
||||
dom = get_domain_from_id(u)
|
||||
dom = UserID.from_string(u).domain
|
||||
old_d = joined_domains.get(dom)
|
||||
if old_d:
|
||||
joined_domains[dom] = min(d, old_d)
|
||||
@@ -682,8 +682,7 @@ class FederationHandler(BaseHandler):
|
||||
})
|
||||
|
||||
try:
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
event, context = yield message_handler._create_new_client_event(
|
||||
event, context = yield self._create_new_client_event(
|
||||
builder=builder,
|
||||
)
|
||||
except AuthError as e:
|
||||
@@ -744,7 +743,9 @@ class FederationHandler(BaseHandler):
|
||||
try:
|
||||
if k[0] == EventTypes.Member:
|
||||
if s.content["membership"] == Membership.JOIN:
|
||||
destinations.add(get_domain_from_id(s.state_key))
|
||||
destinations.add(
|
||||
UserID.from_string(s.state_key).domain
|
||||
)
|
||||
except:
|
||||
logger.warn(
|
||||
"Failed to get destination from event %s", s.event_id
|
||||
@@ -914,8 +915,7 @@ class FederationHandler(BaseHandler):
|
||||
"state_key": user_id,
|
||||
})
|
||||
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
event, context = yield message_handler._create_new_client_event(
|
||||
event, context = yield self._create_new_client_event(
|
||||
builder=builder,
|
||||
)
|
||||
|
||||
@@ -970,7 +970,9 @@ class FederationHandler(BaseHandler):
|
||||
try:
|
||||
if k[0] == EventTypes.Member:
|
||||
if s.content["membership"] == Membership.LEAVE:
|
||||
destinations.add(get_domain_from_id(s.state_key))
|
||||
destinations.add(
|
||||
UserID.from_string(s.state_key).domain
|
||||
)
|
||||
except:
|
||||
logger.warn(
|
||||
"Failed to get destination from event %s", s.event_id
|
||||
@@ -1113,7 +1115,7 @@ class FederationHandler(BaseHandler):
|
||||
if not event.internal_metadata.is_outlier():
|
||||
action_generator = ActionGenerator(self.hs)
|
||||
yield action_generator.handle_push_actions_for_event(
|
||||
event, context
|
||||
event, context, self
|
||||
)
|
||||
|
||||
event_stream_id, max_stream_id = yield self.store.persist_event(
|
||||
@@ -1690,10 +1692,7 @@ class FederationHandler(BaseHandler):
|
||||
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
|
||||
builder = self.event_builder_factory.new(event_dict)
|
||||
EventValidator().validate_new(builder)
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
event, context = yield message_handler._create_new_client_event(
|
||||
builder=builder
|
||||
)
|
||||
event, context = yield self._create_new_client_event(builder=builder)
|
||||
|
||||
event, context = yield self.add_display_name_to_third_party_invite(
|
||||
event_dict, event, context
|
||||
@@ -1721,8 +1720,7 @@ class FederationHandler(BaseHandler):
|
||||
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
|
||||
builder = self.event_builder_factory.new(event_dict)
|
||||
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
event, context = yield message_handler._create_new_client_event(
|
||||
event, context = yield self._create_new_client_event(
|
||||
builder=builder,
|
||||
)
|
||||
|
||||
@@ -1761,8 +1759,7 @@ class FederationHandler(BaseHandler):
|
||||
event_dict["content"]["third_party_invite"]["display_name"] = display_name
|
||||
builder = self.event_builder_factory.new(event_dict)
|
||||
EventValidator().validate_new(builder)
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
event, context = yield message_handler._create_new_client_event(builder=builder)
|
||||
event, context = yield self._create_new_client_event(builder=builder)
|
||||
defer.returnValue((event, context))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
||||
@@ -17,19 +17,13 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import AuthError, Codes, SynapseError
|
||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.push.action_generator import ActionGenerator
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import (
|
||||
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
|
||||
)
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.caches.snapshot_cache import SnapshotCache
|
||||
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
|
||||
from synapse.visibility import filter_events_for_client
|
||||
from synapse.types import UserID, RoomStreamToken, StreamToken
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
@@ -129,8 +123,7 @@ class MessageHandler(BaseHandler):
|
||||
"end": next_token.to_string(),
|
||||
})
|
||||
|
||||
events = yield filter_events_for_client(
|
||||
self.store,
|
||||
events = yield self._filter_events_for_client(
|
||||
user_id,
|
||||
events,
|
||||
is_peeking=(member_event_id is None),
|
||||
@@ -236,7 +229,7 @@ class MessageHandler(BaseHandler):
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Message:
|
||||
presence = self.hs.get_presence_handler()
|
||||
presence = self.hs.get_handlers().presence_handler
|
||||
yield presence.bump_presence_active_time(user)
|
||||
|
||||
def deduplicate_state_event(self, event, context):
|
||||
@@ -490,8 +483,8 @@ class MessageHandler(BaseHandler):
|
||||
]
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
messages = yield filter_events_for_client(
|
||||
self.store, user_id, messages
|
||||
messages = yield self._filter_events_for_client(
|
||||
user_id, messages
|
||||
)
|
||||
|
||||
start_token = now_token.copy_and_replace("room_key", token[0])
|
||||
@@ -626,8 +619,8 @@ class MessageHandler(BaseHandler):
|
||||
end_token=stream_token
|
||||
)
|
||||
|
||||
messages = yield filter_events_for_client(
|
||||
self.store, user_id, messages, is_peeking=is_peeking
|
||||
messages = yield self._filter_events_for_client(
|
||||
user_id, messages, is_peeking=is_peeking
|
||||
)
|
||||
|
||||
start_token = StreamToken.START.copy_and_replace("room_key", token[0])
|
||||
@@ -674,7 +667,7 @@ class MessageHandler(BaseHandler):
|
||||
and m.content["membership"] == Membership.JOIN
|
||||
]
|
||||
|
||||
presence_handler = self.hs.get_presence_handler()
|
||||
presence_handler = self.hs.get_handlers().presence_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_presence():
|
||||
@@ -707,8 +700,8 @@ class MessageHandler(BaseHandler):
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
messages = yield filter_events_for_client(
|
||||
self.store, user_id, messages, is_peeking=is_peeking,
|
||||
messages = yield self._filter_events_for_client(
|
||||
user_id, messages, is_peeking=is_peeking,
|
||||
)
|
||||
|
||||
start_token = now_token.copy_and_replace("room_key", token[0])
|
||||
@@ -731,193 +724,3 @@ class MessageHandler(BaseHandler):
|
||||
ret["membership"] = membership
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_new_client_event(self, builder, prev_event_ids=None):
|
||||
if prev_event_ids:
|
||||
prev_events = yield self.store.add_event_hashes(prev_event_ids)
|
||||
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
|
||||
depth = prev_max_depth + 1
|
||||
else:
|
||||
latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
|
||||
builder.room_id,
|
||||
)
|
||||
|
||||
if latest_ret:
|
||||
depth = max([d for _, _, d in latest_ret]) + 1
|
||||
else:
|
||||
depth = 1
|
||||
|
||||
prev_events = [
|
||||
(event_id, prev_hashes)
|
||||
for event_id, prev_hashes, _ in latest_ret
|
||||
]
|
||||
|
||||
builder.prev_events = prev_events
|
||||
builder.depth = depth
|
||||
|
||||
state_handler = self.state_handler
|
||||
|
||||
context = yield state_handler.compute_event_context(builder)
|
||||
|
||||
if builder.is_state():
|
||||
builder.prev_state = yield self.store.add_event_hashes(
|
||||
context.prev_state_events
|
||||
)
|
||||
|
||||
yield self.auth.add_auth_events(builder, context)
|
||||
|
||||
signing_key = self.hs.config.signing_key[0]
|
||||
add_hashes_and_signatures(
|
||||
builder, self.server_name, signing_key
|
||||
)
|
||||
|
||||
event = builder.build()
|
||||
|
||||
logger.debug(
|
||||
"Created event %s with current state: %s",
|
||||
event.event_id, context.current_state,
|
||||
)
|
||||
|
||||
defer.returnValue(
|
||||
(event, context,)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_new_client_event(
|
||||
self,
|
||||
requester,
|
||||
event,
|
||||
context,
|
||||
ratelimit=True,
|
||||
extra_users=[]
|
||||
):
|
||||
# We now need to go and hit out to wherever we need to hit out to.
|
||||
|
||||
if ratelimit:
|
||||
self.ratelimit(requester)
|
||||
|
||||
try:
|
||||
self.auth.check(event, auth_events=context.current_state)
|
||||
except AuthError as err:
|
||||
logger.warn("Denying new event %r because %s", event, err)
|
||||
raise err
|
||||
|
||||
yield self.maybe_kick_guest_users(event, context.current_state.values())
|
||||
|
||||
if event.type == EventTypes.CanonicalAlias:
|
||||
# Check the alias is acually valid (at this time at least)
|
||||
room_alias_str = event.content.get("alias", None)
|
||||
if room_alias_str:
|
||||
room_alias = RoomAlias.from_string(room_alias_str)
|
||||
directory_handler = self.hs.get_handlers().directory_handler
|
||||
mapping = yield directory_handler.get_association(room_alias)
|
||||
|
||||
if mapping["room_id"] != event.room_id:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Room alias %s does not point to the room" % (
|
||||
room_alias_str,
|
||||
)
|
||||
)
|
||||
|
||||
federation_handler = self.hs.get_handlers().federation_handler
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
if event.content["membership"] == Membership.INVITE:
|
||||
def is_inviter_member_event(e):
|
||||
return (
|
||||
e.type == EventTypes.Member and
|
||||
e.sender == event.sender
|
||||
)
|
||||
|
||||
event.unsigned["invite_room_state"] = [
|
||||
{
|
||||
"type": e.type,
|
||||
"state_key": e.state_key,
|
||||
"content": e.content,
|
||||
"sender": e.sender,
|
||||
}
|
||||
for k, e in context.current_state.items()
|
||||
if e.type in self.hs.config.room_invite_state_types
|
||||
or is_inviter_member_event(e)
|
||||
]
|
||||
|
||||
invitee = UserID.from_string(event.state_key)
|
||||
if not self.hs.is_mine(invitee):
|
||||
# TODO: Can we add signature from remote server in a nicer
|
||||
# way? If we have been invited by a remote server, we need
|
||||
# to get them to sign the event.
|
||||
|
||||
returned_invite = yield federation_handler.send_invite(
|
||||
invitee.domain,
|
||||
event,
|
||||
)
|
||||
|
||||
event.unsigned.pop("room_state", None)
|
||||
|
||||
# TODO: Make sure the signatures actually are correct.
|
||||
event.signatures.update(
|
||||
returned_invite.signatures
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Redaction:
|
||||
if self.auth.check_redaction(event, auth_events=context.current_state):
|
||||
original_event = yield self.store.get_event(
|
||||
event.redacts,
|
||||
check_redacted=False,
|
||||
get_prev_content=False,
|
||||
allow_rejected=False,
|
||||
allow_none=False
|
||||
)
|
||||
if event.user_id != original_event.user_id:
|
||||
raise AuthError(
|
||||
403,
|
||||
"You don't have permission to redact events"
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Create and context.current_state:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Changing the room create event is forbidden",
|
||||
)
|
||||
|
||||
action_generator = ActionGenerator(self.hs)
|
||||
yield action_generator.handle_push_actions_for_event(
|
||||
event, context
|
||||
)
|
||||
|
||||
(event_stream_id, max_stream_id) = yield self.store.persist_event(
|
||||
event, context=context
|
||||
)
|
||||
|
||||
# this intentionally does not yield: we don't care about the result
|
||||
# and don't need to wait for it.
|
||||
preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
|
||||
event_stream_id, max_stream_id
|
||||
)
|
||||
|
||||
destinations = set()
|
||||
for k, s in context.current_state.items():
|
||||
try:
|
||||
if k[0] == EventTypes.Member:
|
||||
if s.content["membership"] == Membership.JOIN:
|
||||
destinations.add(get_domain_from_id(s.state_key))
|
||||
except SynapseError:
|
||||
logger.warn(
|
||||
"Failed to get destination from event %s", s.event_id
|
||||
)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
# Don't block waiting on waking up all the listeners.
|
||||
self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id,
|
||||
extra_users=extra_users
|
||||
)
|
||||
|
||||
# If invite, remove room_state from unsigned before sending.
|
||||
event.unsigned.pop("invite_room_state", None)
|
||||
|
||||
federation_handler.handle_new_event(
|
||||
event, destinations=destinations,
|
||||
)
|
||||
|
||||
@@ -33,9 +33,11 @@ from synapse.util.logcontext import preserve_fn
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.util.wheel_timer import WheelTimer
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from synapse.types import UserID
|
||||
import synapse.metrics
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
@@ -68,18 +70,14 @@ FEDERATION_TIMEOUT = 30 * 60 * 1000
|
||||
# How often to resend presence to remote servers
|
||||
FEDERATION_PING_INTERVAL = 25 * 60 * 1000
|
||||
|
||||
# How long we will wait before assuming that the syncs from an external process
|
||||
# are dead.
|
||||
EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000
|
||||
|
||||
assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
|
||||
|
||||
|
||||
class PresenceHandler(object):
|
||||
class PresenceHandler(BaseHandler):
|
||||
|
||||
def __init__(self, hs):
|
||||
self.is_mine = hs.is_mine
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
super(PresenceHandler, self).__init__(hs)
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.wheel_timer = WheelTimer()
|
||||
@@ -140,7 +138,7 @@ class PresenceHandler(object):
|
||||
obj=state.user_id,
|
||||
then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
|
||||
)
|
||||
if self.is_mine_id(state.user_id):
|
||||
if self.hs.is_mine_id(state.user_id):
|
||||
self.wheel_timer.insert(
|
||||
now=now,
|
||||
obj=state.user_id,
|
||||
@@ -162,26 +160,15 @@ class PresenceHandler(object):
|
||||
self.serial_to_user = {}
|
||||
self._next_serial = 1
|
||||
|
||||
# Keeps track of the number of *ongoing* syncs on this process. While
|
||||
# this is non zero a user will never go offline.
|
||||
# Keeps track of the number of *ongoing* syncs. While this is non zero
|
||||
# a user will never go offline.
|
||||
self.user_to_num_current_syncs = {}
|
||||
|
||||
# Keeps track of the number of *ongoing* syncs on other processes.
|
||||
# While any sync is ongoing on another process the user will never
|
||||
# go offline.
|
||||
# Each process has a unique identifier and an update frequency. If
|
||||
# no update is received from that process within the update period then
|
||||
# we assume that all the sync requests on that process have stopped.
|
||||
# Stored as a dict from process_id to set of user_id, and a dict of
|
||||
# process_id to millisecond timestamp last updated.
|
||||
self.external_process_to_current_syncs = {}
|
||||
self.external_process_last_updated_ms = {}
|
||||
|
||||
# Start a LoopingCall in 30s that fires every 5s.
|
||||
# The initial delay is to allow disconnected clients a chance to
|
||||
# reconnect before we treat them as offline.
|
||||
self.clock.call_later(
|
||||
30 * 1000,
|
||||
0 * 1000,
|
||||
self.clock.looping_call,
|
||||
self._handle_timeouts,
|
||||
5000,
|
||||
@@ -241,7 +228,7 @@ class PresenceHandler(object):
|
||||
|
||||
new_state, should_notify, should_ping = handle_update(
|
||||
prev_state, new_state,
|
||||
is_mine=self.is_mine_id(user_id),
|
||||
is_mine=self.hs.is_mine_id(user_id),
|
||||
wheel_timer=self.wheel_timer,
|
||||
now=now
|
||||
)
|
||||
@@ -287,34 +274,21 @@ class PresenceHandler(object):
|
||||
# Fetch the list of users that *may* have timed out. Things may have
|
||||
# changed since the timeout was set, so we won't necessarily have to
|
||||
# take any action.
|
||||
users_to_check = set(self.wheel_timer.fetch(now))
|
||||
|
||||
# Check whether the lists of syncing processes from an external
|
||||
# process have expired.
|
||||
expired_process_ids = [
|
||||
process_id for process_id, last_update
|
||||
in self.external_process_last_update.items()
|
||||
if now - last_update > EXTERNAL_PROCESS_EXPIRY
|
||||
]
|
||||
for process_id in expired_process_ids:
|
||||
users_to_check.update(
|
||||
self.external_process_to_current_syncs.pop(process_id, ())
|
||||
)
|
||||
self.external_process_last_update.pop(process_id)
|
||||
users_to_check = self.wheel_timer.fetch(now)
|
||||
|
||||
states = [
|
||||
self.user_to_current_state.get(
|
||||
user_id, UserPresenceState.default(user_id)
|
||||
)
|
||||
for user_id in users_to_check
|
||||
for user_id in set(users_to_check)
|
||||
]
|
||||
|
||||
timers_fired_counter.inc_by(len(states))
|
||||
|
||||
changes = handle_timeouts(
|
||||
states,
|
||||
is_mine_fn=self.is_mine_id,
|
||||
syncing_users=self.get_syncing_users(),
|
||||
is_mine_fn=self.hs.is_mine_id,
|
||||
user_to_num_current_syncs=self.user_to_num_current_syncs,
|
||||
now=now,
|
||||
)
|
||||
|
||||
@@ -391,73 +365,6 @@ class PresenceHandler(object):
|
||||
|
||||
defer.returnValue(_user_syncing())
|
||||
|
||||
def get_currently_syncing_users(self):
|
||||
"""Get the set of user ids that are currently syncing on this HS.
|
||||
Returns:
|
||||
set(str): A set of user_id strings.
|
||||
"""
|
||||
syncing_user_ids = {
|
||||
user_id for user_id, count in self.user_to_num_current_syncs.items()
|
||||
if count
|
||||
}
|
||||
syncing_user_ids.update(self.external_process_to_current_syncs.values())
|
||||
return syncing_user_ids
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_external_syncs(self, process_id, syncing_user_ids):
|
||||
"""Update the syncing users for an external process
|
||||
|
||||
Args:
|
||||
process_id(str): An identifier for the process the users are
|
||||
syncing against. This allows synapse to process updates
|
||||
as user start and stop syncing against a given process.
|
||||
syncing_user_ids(set(str)): The set of user_ids that are
|
||||
currently syncing on that server.
|
||||
"""
|
||||
|
||||
# Grab the previous list of user_ids that were syncing on that process
|
||||
prev_syncing_user_ids = (
|
||||
self.external_process_to_current_syncs.get(process_id, set())
|
||||
)
|
||||
# Grab the current presence state for both the users that are syncing
|
||||
# now and the users that were syncing before this update.
|
||||
prev_states = yield self.current_state_for_users(
|
||||
syncing_user_ids | prev_syncing_user_ids
|
||||
)
|
||||
updates = []
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
# For each new user that is syncing check if we need to mark them as
|
||||
# being online.
|
||||
for new_user_id in syncing_user_ids - prev_syncing_user_ids:
|
||||
prev_state = prev_states[new_user_id]
|
||||
if prev_state.state == PresenceState.OFFLINE:
|
||||
updates.append(prev_state.copy_and_replace(
|
||||
state=PresenceState.ONLINE,
|
||||
last_active_ts=time_now_ms,
|
||||
last_user_sync_ts=time_now_ms,
|
||||
))
|
||||
else:
|
||||
updates.append(prev_state.copy_and_replace(
|
||||
last_user_sync_ts=time_now_ms,
|
||||
))
|
||||
|
||||
# For each user that is still syncing or stopped syncing update the
|
||||
# last sync time so that we will correctly apply the grace period when
|
||||
# they stop syncing.
|
||||
for old_user_id in prev_syncing_user_ids:
|
||||
prev_state = prev_states[old_user_id]
|
||||
updates.append(prev_state.copy_and_replace(
|
||||
last_user_sync_ts=time_now_ms,
|
||||
))
|
||||
|
||||
yield self._update_states(updates)
|
||||
|
||||
# Update the last updated time for the process. We expire the entries
|
||||
# if we don't receive an update in the given timeframe.
|
||||
self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
|
||||
self.external_process_to_current_syncs[process_id] = syncing_user_ids
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def current_state_for_user(self, user_id):
|
||||
"""Get the current presence state for a user.
|
||||
@@ -520,7 +427,7 @@ class PresenceHandler(object):
|
||||
|
||||
hosts_to_states = {}
|
||||
for room_id, states in room_ids_to_states.items():
|
||||
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
|
||||
local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states)
|
||||
if not local_states:
|
||||
continue
|
||||
|
||||
@@ -529,11 +436,11 @@ class PresenceHandler(object):
|
||||
hosts_to_states.setdefault(host, []).extend(local_states)
|
||||
|
||||
for user_id, states in users_to_states.items():
|
||||
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
|
||||
local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states)
|
||||
if not local_states:
|
||||
continue
|
||||
|
||||
host = get_domain_from_id(user_id)
|
||||
host = UserID.from_string(user_id).domain
|
||||
hosts_to_states.setdefault(host, []).extend(local_states)
|
||||
|
||||
# TODO: de-dup hosts_to_states, as a single host might have multiple
|
||||
@@ -704,14 +611,14 @@ class PresenceHandler(object):
|
||||
# don't need to send to local clients here, as that is done as part
|
||||
# of the event stream/sync.
|
||||
# TODO: Only send to servers not already in the room.
|
||||
if self.is_mine(user):
|
||||
if self.hs.is_mine(user):
|
||||
state = yield self.current_state_for_user(user.to_string())
|
||||
|
||||
hosts = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
self._push_to_remotes({host: (state,) for host in hosts})
|
||||
else:
|
||||
user_ids = yield self.store.get_users_in_room(room_id)
|
||||
user_ids = filter(self.is_mine_id, user_ids)
|
||||
user_ids = filter(self.hs.is_mine_id, user_ids)
|
||||
|
||||
states = yield self.current_state_for_users(user_ids)
|
||||
|
||||
@@ -721,7 +628,7 @@ class PresenceHandler(object):
|
||||
def get_presence_list(self, observer_user, accepted=None):
|
||||
"""Returns the presence for all users in their presence list.
|
||||
"""
|
||||
if not self.is_mine(observer_user):
|
||||
if not self.hs.is_mine(observer_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
presence_list = yield self.store.get_presence_list(
|
||||
@@ -752,7 +659,7 @@ class PresenceHandler(object):
|
||||
observer_user.localpart, observed_user.to_string()
|
||||
)
|
||||
|
||||
if self.is_mine(observed_user):
|
||||
if self.hs.is_mine(observed_user):
|
||||
yield self.invite_presence(observed_user, observer_user)
|
||||
else:
|
||||
yield self.federation.send_edu(
|
||||
@@ -768,11 +675,11 @@ class PresenceHandler(object):
|
||||
def invite_presence(self, observed_user, observer_user):
|
||||
"""Handles new presence invites.
|
||||
"""
|
||||
if not self.is_mine(observed_user):
|
||||
if not self.hs.is_mine(observed_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
# TODO: Don't auto accept
|
||||
if self.is_mine(observer_user):
|
||||
if self.hs.is_mine(observer_user):
|
||||
yield self.accept_presence(observed_user, observer_user)
|
||||
else:
|
||||
self.federation.send_edu(
|
||||
@@ -835,7 +742,7 @@ class PresenceHandler(object):
|
||||
Returns:
|
||||
A Deferred.
|
||||
"""
|
||||
if not self.is_mine(observer_user):
|
||||
if not self.hs.is_mine(observer_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
yield self.store.del_presence_list(
|
||||
@@ -927,11 +834,7 @@ def _format_user_presence_state(state, now):
|
||||
|
||||
class PresenceEventSource(object):
|
||||
def __init__(self, hs):
|
||||
# We can't call get_presence_handler here because there's a cycle:
|
||||
#
|
||||
# Presence -> Notifier -> PresenceEventSource -> Presence
|
||||
#
|
||||
self.get_presence_handler = hs.get_presence_handler
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@@ -957,7 +860,7 @@ class PresenceEventSource(object):
|
||||
from_key = int(from_key)
|
||||
room_ids = room_ids or []
|
||||
|
||||
presence = self.get_presence_handler()
|
||||
presence = self.hs.get_handlers().presence_handler
|
||||
stream_change_cache = self.store.presence_stream_cache
|
||||
|
||||
if not room_ids:
|
||||
@@ -1030,14 +933,15 @@ class PresenceEventSource(object):
|
||||
return self.get_new_events(user, from_key=None, include_offline=False)
|
||||
|
||||
|
||||
def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
|
||||
def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now):
|
||||
"""Checks the presence of users that have timed out and updates as
|
||||
appropriate.
|
||||
|
||||
Args:
|
||||
user_states(list): List of UserPresenceState's to check.
|
||||
is_mine_fn (fn): Function that returns if a user_id is ours
|
||||
syncing_user_ids (set): Set of user_ids with active syncs.
|
||||
user_to_num_current_syncs (dict): Mapping of user_id to number of currently
|
||||
active syncs.
|
||||
now (int): Current time in ms.
|
||||
|
||||
Returns:
|
||||
@@ -1048,20 +952,21 @@ def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
|
||||
for state in user_states:
|
||||
is_mine = is_mine_fn(state.user_id)
|
||||
|
||||
new_state = handle_timeout(state, is_mine, syncing_user_ids, now)
|
||||
new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now)
|
||||
if new_state:
|
||||
changes[state.user_id] = new_state
|
||||
|
||||
return changes.values()
|
||||
|
||||
|
||||
def handle_timeout(state, is_mine, syncing_user_ids, now):
|
||||
def handle_timeout(state, is_mine, user_to_num_current_syncs, now):
|
||||
"""Checks the presence of the user to see if any of the timers have elapsed
|
||||
|
||||
Args:
|
||||
state (UserPresenceState)
|
||||
is_mine (bool): Whether the user is ours
|
||||
syncing_user_ids (set): Set of user_ids with active syncs.
|
||||
user_to_num_current_syncs (dict): Mapping of user_id to number of currently
|
||||
active syncs.
|
||||
now (int): Current time in ms.
|
||||
|
||||
Returns:
|
||||
@@ -1095,7 +1000,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
|
||||
|
||||
# If there are have been no sync for a while (and none ongoing),
|
||||
# set presence to offline
|
||||
if user_id not in syncing_user_ids:
|
||||
if not user_to_num_current_syncs.get(user_id, 0):
|
||||
if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT:
|
||||
state = state.copy_and_replace(
|
||||
state=PresenceState.OFFLINE,
|
||||
|
||||
@@ -29,8 +29,6 @@ class ReceiptsHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
super(ReceiptsHandler, self).__init__(hs)
|
||||
|
||||
self.server_name = hs.config.server_name
|
||||
self.store = hs.get_datastore()
|
||||
self.hs = hs
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.federation.register_edu_handler(
|
||||
@@ -133,9 +131,12 @@ class ReceiptsHandler(BaseHandler):
|
||||
event_ids = receipt["event_ids"]
|
||||
data = receipt["data"]
|
||||
|
||||
remotedomains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
remotedomains = remotedomains.copy()
|
||||
remotedomains.discard(self.server_name)
|
||||
remotedomains = set()
|
||||
|
||||
rm_handler = self.hs.get_handlers().room_member_handler
|
||||
yield rm_handler.fetch_room_distributions_into(
|
||||
room_id, localusers=None, remotedomains=remotedomains
|
||||
)
|
||||
|
||||
logger.debug("Sending receipt to: %r", remotedomains)
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
"""Contains functions for registering clients."""
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.types import UserID, Requester
|
||||
from synapse.types import UserID
|
||||
from synapse.api.errors import (
|
||||
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
||||
)
|
||||
@@ -358,62 +358,8 @@ class RegistrationHandler(BaseHandler):
|
||||
)
|
||||
defer.returnValue(data)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_or_create_user(self, localpart, displayname, duration_seconds):
|
||||
"""Creates a new user if the user does not exist,
|
||||
else revokes all previous access tokens and generates a new one.
|
||||
|
||||
Args:
|
||||
localpart : The local part of the user ID to register. If None,
|
||||
one will be randomly generated.
|
||||
Returns:
|
||||
A tuple of (user_id, access_token).
|
||||
Raises:
|
||||
RegistrationError if there was a problem registering.
|
||||
"""
|
||||
yield run_on_reactor()
|
||||
|
||||
if localpart is None:
|
||||
raise SynapseError(400, "Request must include user id")
|
||||
|
||||
need_register = True
|
||||
|
||||
try:
|
||||
yield self.check_username(localpart)
|
||||
except SynapseError as e:
|
||||
if e.errcode == Codes.USER_IN_USE:
|
||||
need_register = False
|
||||
else:
|
||||
raise
|
||||
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
auth_handler = self.hs.get_handlers().auth_handler
|
||||
token = auth_handler.generate_short_term_login_token(user_id, duration_seconds)
|
||||
|
||||
if need_register:
|
||||
yield self.store.register(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=None
|
||||
)
|
||||
|
||||
yield registered_user(self.distributor, user)
|
||||
else:
|
||||
yield self.store.user_delete_access_tokens(user_id=user_id)
|
||||
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
|
||||
|
||||
if displayname is not None:
|
||||
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||
profile_handler = self.hs.get_handlers().profile_handler
|
||||
yield profile_handler.set_displayname(
|
||||
user, Requester(user, token, False), displayname
|
||||
)
|
||||
|
||||
defer.returnValue((user_id, token))
|
||||
|
||||
def auth_handler(self):
|
||||
return self.hs.get_auth_handler()
|
||||
return self.hs.get_handlers().auth_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def guest_access_token_for(self, medium, address, inviter_user_id):
|
||||
|
||||
@@ -26,7 +26,6 @@ from synapse.api.errors import AuthError, StoreError, SynapseError
|
||||
from synapse.util import stringutils
|
||||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
@@ -36,8 +35,6 @@ import string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
|
||||
|
||||
id_server_scheme = "https://"
|
||||
|
||||
|
||||
@@ -346,14 +343,8 @@ class RoomListHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
super(RoomListHandler, self).__init__(hs)
|
||||
self.response_cache = ResponseCache()
|
||||
self.remote_list_request_cache = ResponseCache()
|
||||
self.remote_list_cache = {}
|
||||
self.fetch_looping_call = hs.get_clock().looping_call(
|
||||
self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL
|
||||
)
|
||||
self.fetch_all_remote_lists()
|
||||
|
||||
def get_local_public_room_list(self):
|
||||
def get_public_room_list(self):
|
||||
result = self.response_cache.get(())
|
||||
if not result:
|
||||
result = self.response_cache.set((), self._get_public_room_list())
|
||||
@@ -435,55 +426,6 @@ class RoomListHandler(BaseHandler):
|
||||
# FIXME (erikj): START is no longer a valid value
|
||||
defer.returnValue({"start": "START", "end": "END", "chunk": results})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fetch_all_remote_lists(self):
|
||||
deferred = self.hs.get_replication_layer().get_public_rooms(
|
||||
self.hs.config.secondary_directory_servers
|
||||
)
|
||||
self.remote_list_request_cache.set((), deferred)
|
||||
self.remote_list_cache = yield deferred
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_aggregated_public_room_list(self):
|
||||
"""
|
||||
Get the public room list from this server and the servers
|
||||
specified in the secondary_directory_servers config option.
|
||||
XXX: Pagination...
|
||||
"""
|
||||
# We return the results from out cache which is updated by a looping call,
|
||||
# unless we're missing a cache entry, in which case wait for the result
|
||||
# of the fetch if there's one in progress. If not, omit that server.
|
||||
wait = False
|
||||
for s in self.hs.config.secondary_directory_servers:
|
||||
if s not in self.remote_list_cache:
|
||||
logger.warn("No cached room list from %s: waiting for fetch", s)
|
||||
wait = True
|
||||
break
|
||||
|
||||
if wait and self.remote_list_request_cache.get(()):
|
||||
yield self.remote_list_request_cache.get(())
|
||||
|
||||
public_rooms = yield self.get_local_public_room_list()
|
||||
|
||||
# keep track of which room IDs we've seen so we can de-dup
|
||||
room_ids = set()
|
||||
|
||||
# tag all the ones in our list with our server name.
|
||||
# Also add the them to the de-deping set
|
||||
for room in public_rooms['chunk']:
|
||||
room["server_name"] = self.hs.hostname
|
||||
room_ids.add(room["room_id"])
|
||||
|
||||
# Now add the results from federation
|
||||
for server_name, server_result in self.remote_list_cache.items():
|
||||
for room in server_result["chunk"]:
|
||||
if room["room_id"] not in room_ids:
|
||||
room["server_name"] = server_name
|
||||
public_rooms["chunk"].append(room)
|
||||
room_ids.add(room["room_id"])
|
||||
|
||||
defer.returnValue(public_rooms)
|
||||
|
||||
|
||||
class RoomContextHandler(BaseHandler):
|
||||
@defer.inlineCallbacks
|
||||
@@ -507,12 +449,10 @@ class RoomContextHandler(BaseHandler):
|
||||
now_token = yield self.hs.get_event_sources().get_current_token()
|
||||
|
||||
def filter_evts(events):
|
||||
return filter_events_for_client(
|
||||
self.store,
|
||||
return self._filter_events_for_client(
|
||||
user.to_string(),
|
||||
events,
|
||||
is_peeking=is_guest
|
||||
)
|
||||
is_peeking=is_guest)
|
||||
|
||||
event = yield self.store.get_event(event_id, get_prev_content=True,
|
||||
allow_none=True)
|
||||
|
||||
@@ -55,6 +55,35 @@ class RoomMemberHandler(BaseHandler):
|
||||
self.distributor.declare("user_joined_room")
|
||||
self.distributor.declare("user_left_room")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_members(self, room_id):
|
||||
users = yield self.store.get_users_in_room(room_id)
|
||||
|
||||
defer.returnValue([UserID.from_string(u) for u in users])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fetch_room_distributions_into(self, room_id, localusers=None,
|
||||
remotedomains=None, ignore_user=None):
|
||||
"""Fetch the distribution of a room, adding elements to either
|
||||
'localusers' or 'remotedomains', which should be a set() if supplied.
|
||||
If ignore_user is set, ignore that user.
|
||||
|
||||
This function returns nothing; its result is performed by the
|
||||
side-effect on the two passed sets. This allows easy accumulation of
|
||||
member lists of multiple rooms at once if required.
|
||||
"""
|
||||
members = yield self.get_room_members(room_id)
|
||||
for member in members:
|
||||
if ignore_user is not None and member == ignore_user:
|
||||
continue
|
||||
|
||||
if self.hs.is_mine(member):
|
||||
if localusers is not None:
|
||||
localusers.add(member)
|
||||
else:
|
||||
if remotedomains is not None:
|
||||
remotedomains.add(member.domain)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _local_membership_update(
|
||||
self, requester, target, room_id, membership,
|
||||
@@ -84,7 +113,7 @@ class RoomMemberHandler(BaseHandler):
|
||||
prev_event_ids=prev_event_ids,
|
||||
)
|
||||
|
||||
yield msg_handler.handle_new_client_event(
|
||||
yield self.handle_new_client_event(
|
||||
requester,
|
||||
event,
|
||||
context,
|
||||
@@ -203,7 +232,7 @@ class RoomMemberHandler(BaseHandler):
|
||||
if old_membership == "ban" and action != "unban":
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Cannot %s user who was banned" % (action,),
|
||||
"Cannot %s user who was is banned" % (action,),
|
||||
errcode=Codes.BAD_STATE
|
||||
)
|
||||
|
||||
@@ -328,7 +357,7 @@ class RoomMemberHandler(BaseHandler):
|
||||
# so don't really fit into the general auth process.
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
|
||||
yield message_handler.handle_new_client_event(
|
||||
yield self.handle_new_client_event(
|
||||
requester,
|
||||
event,
|
||||
context,
|
||||
@@ -397,6 +426,21 @@ class RoomMemberHandler(BaseHandler):
|
||||
if invite:
|
||||
defer.returnValue(UserID.from_string(invite.sender))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_joined_rooms_for_user(self, user):
|
||||
"""Returns a list of roomids that the user has any of the given
|
||||
membership states in."""
|
||||
|
||||
rooms = yield self.store.get_rooms_for_user(
|
||||
user.to_string(),
|
||||
)
|
||||
|
||||
# For some reason the list of events contains duplicates
|
||||
# TODO(paul): work out why because I really don't think it should
|
||||
room_ids = set(r.room_id for r in rooms)
|
||||
|
||||
defer.returnValue(room_ids)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_3pid_invite(
|
||||
self,
|
||||
@@ -413,7 +457,8 @@ class RoomMemberHandler(BaseHandler):
|
||||
)
|
||||
|
||||
if invitee:
|
||||
yield self.update_membership(
|
||||
handler = self.hs.get_handlers().room_member_handler
|
||||
yield handler.update_membership(
|
||||
requester,
|
||||
UserID.from_string(invitee),
|
||||
room_id,
|
||||
|
||||
@@ -21,7 +21,6 @@ from synapse.api.constants import Membership, EventTypes
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
from unpaddedbase64 import decode_base64, encode_base64
|
||||
|
||||
@@ -173,8 +172,8 @@ class SearchHandler(BaseHandler):
|
||||
|
||||
filtered_events = search_filter.filter([r["event"] for r in results])
|
||||
|
||||
events = yield filter_events_for_client(
|
||||
self.store, user.to_string(), filtered_events
|
||||
events = yield self._filter_events_for_client(
|
||||
user.to_string(), filtered_events
|
||||
)
|
||||
|
||||
events.sort(key=lambda e: -rank_map[e.event_id])
|
||||
@@ -224,8 +223,8 @@ class SearchHandler(BaseHandler):
|
||||
r["event"] for r in results
|
||||
])
|
||||
|
||||
events = yield filter_events_for_client(
|
||||
self.store, user.to_string(), filtered_events
|
||||
events = yield self._filter_events_for_client(
|
||||
user.to_string(), filtered_events
|
||||
)
|
||||
|
||||
room_events.extend(events)
|
||||
@@ -282,12 +281,12 @@ class SearchHandler(BaseHandler):
|
||||
event.room_id, event.event_id, before_limit, after_limit
|
||||
)
|
||||
|
||||
res["events_before"] = yield filter_events_for_client(
|
||||
self.store, user.to_string(), res["events_before"]
|
||||
res["events_before"] = yield self._filter_events_for_client(
|
||||
user.to_string(), res["events_before"]
|
||||
)
|
||||
|
||||
res["events_after"] = yield filter_events_for_client(
|
||||
self.store, user.to_string(), res["events_after"]
|
||||
res["events_after"] = yield self._filter_events_for_client(
|
||||
user.to_string(), res["events_after"]
|
||||
)
|
||||
|
||||
res["start"] = now_token.copy_and_replace(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -15,6 +15,8 @@
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
from synapse.api.errors import SynapseError, AuthError
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.util.metrics import Measure
|
||||
@@ -30,16 +32,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# A tiny object useful for storing a user's membership in a room, as a mapping
|
||||
# key
|
||||
RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
|
||||
RoomMember = namedtuple("RoomMember", ("room_id", "user"))
|
||||
|
||||
|
||||
class TypingHandler(object):
|
||||
class TypingNotificationHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.server_name = hs.config.server_name
|
||||
self.auth = hs.get_auth()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.notifier = hs.get_notifier()
|
||||
super(TypingNotificationHandler, self).__init__(hs)
|
||||
|
||||
self.homeserver = hs
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@@ -67,23 +67,20 @@ class TypingHandler(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def started_typing(self, target_user, auth_user, room_id, timeout):
|
||||
target_user_id = target_user.to_string()
|
||||
auth_user_id = auth_user.to_string()
|
||||
|
||||
if not self.is_mine_id(target_user_id):
|
||||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
if target_user_id != auth_user_id:
|
||||
if target_user != auth_user:
|
||||
raise AuthError(400, "Cannot set another user's typing state")
|
||||
|
||||
yield self.auth.check_joined_room(room_id, target_user_id)
|
||||
yield self.auth.check_joined_room(room_id, target_user.to_string())
|
||||
|
||||
logger.debug(
|
||||
"%s has started typing in %s", target_user_id, room_id
|
||||
"%s has started typing in %s", target_user.to_string(), room_id
|
||||
)
|
||||
|
||||
until = self.clock.time_msec() + timeout
|
||||
member = RoomMember(room_id=room_id, user_id=target_user_id)
|
||||
member = RoomMember(room_id=room_id, user=target_user)
|
||||
|
||||
was_present = member in self._member_typing_until
|
||||
|
||||
@@ -107,28 +104,25 @@ class TypingHandler(object):
|
||||
|
||||
yield self._push_update(
|
||||
room_id=room_id,
|
||||
user_id=target_user_id,
|
||||
user=target_user,
|
||||
typing=True,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def stopped_typing(self, target_user, auth_user, room_id):
|
||||
target_user_id = target_user.to_string()
|
||||
auth_user_id = auth_user.to_string()
|
||||
|
||||
if not self.is_mine_id(target_user_id):
|
||||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
if target_user_id != auth_user_id:
|
||||
if target_user != auth_user:
|
||||
raise AuthError(400, "Cannot set another user's typing state")
|
||||
|
||||
yield self.auth.check_joined_room(room_id, target_user_id)
|
||||
yield self.auth.check_joined_room(room_id, target_user.to_string())
|
||||
|
||||
logger.debug(
|
||||
"%s has stopped typing in %s", target_user_id, room_id
|
||||
"%s has stopped typing in %s", target_user.to_string(), room_id
|
||||
)
|
||||
|
||||
member = RoomMember(room_id=room_id, user_id=target_user_id)
|
||||
member = RoomMember(room_id=room_id, user=target_user)
|
||||
|
||||
if member in self._member_typing_timer:
|
||||
self.clock.cancel_call_later(self._member_typing_timer[member])
|
||||
@@ -138,9 +132,8 @@ class TypingHandler(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_left_room(self, user, room_id):
|
||||
user_id = user.to_string()
|
||||
if self.is_mine_id(user_id):
|
||||
member = RoomMember(room_id=room_id, user=user_id)
|
||||
if self.hs.is_mine(user):
|
||||
member = RoomMember(room_id=room_id, user=user)
|
||||
yield self._stopped_typing(member)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -151,7 +144,7 @@ class TypingHandler(object):
|
||||
|
||||
yield self._push_update(
|
||||
room_id=member.room_id,
|
||||
user_id=member.user_id,
|
||||
user=member.user,
|
||||
typing=False,
|
||||
)
|
||||
|
||||
@@ -163,53 +156,61 @@ class TypingHandler(object):
|
||||
del self._member_typing_timer[member]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _push_update(self, room_id, user_id, typing):
|
||||
domains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
def _push_update(self, room_id, user, typing):
|
||||
localusers = set()
|
||||
remotedomains = set()
|
||||
|
||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||
yield rm_handler.fetch_room_distributions_into(
|
||||
room_id, localusers=localusers, remotedomains=remotedomains
|
||||
)
|
||||
|
||||
if localusers:
|
||||
self._push_update_local(
|
||||
room_id=room_id,
|
||||
user=user,
|
||||
typing=typing
|
||||
)
|
||||
|
||||
deferreds = []
|
||||
for domain in domains:
|
||||
if domain == self.server_name:
|
||||
self._push_update_local(
|
||||
room_id=room_id,
|
||||
user_id=user_id,
|
||||
typing=typing
|
||||
)
|
||||
else:
|
||||
deferreds.append(self.federation.send_edu(
|
||||
destination=domain,
|
||||
edu_type="m.typing",
|
||||
content={
|
||||
"room_id": room_id,
|
||||
"user_id": user_id,
|
||||
"typing": typing,
|
||||
},
|
||||
))
|
||||
for domain in remotedomains:
|
||||
deferreds.append(self.federation.send_edu(
|
||||
destination=domain,
|
||||
edu_type="m.typing",
|
||||
content={
|
||||
"room_id": room_id,
|
||||
"user_id": user.to_string(),
|
||||
"typing": typing,
|
||||
},
|
||||
))
|
||||
|
||||
yield defer.DeferredList(deferreds, consumeErrors=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _recv_edu(self, origin, content):
|
||||
room_id = content["room_id"]
|
||||
user_id = content["user_id"]
|
||||
user = UserID.from_string(content["user_id"])
|
||||
|
||||
# Check that the string is a valid user id
|
||||
UserID.from_string(user_id)
|
||||
localusers = set()
|
||||
|
||||
domains = yield self.store.get_joined_hosts_for_room(room_id)
|
||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||
yield rm_handler.fetch_room_distributions_into(
|
||||
room_id, localusers=localusers
|
||||
)
|
||||
|
||||
if self.server_name in domains:
|
||||
if localusers:
|
||||
self._push_update_local(
|
||||
room_id=room_id,
|
||||
user_id=user_id,
|
||||
user=user,
|
||||
typing=content["typing"]
|
||||
)
|
||||
|
||||
def _push_update_local(self, room_id, user_id, typing):
|
||||
def _push_update_local(self, room_id, user, typing):
|
||||
room_set = self._room_typing.setdefault(room_id, set())
|
||||
if typing:
|
||||
room_set.add(user_id)
|
||||
room_set.add(user)
|
||||
else:
|
||||
room_set.discard(user_id)
|
||||
room_set.discard(user)
|
||||
|
||||
self._latest_room_serial += 1
|
||||
self._room_serials[room_id] = self._latest_room_serial
|
||||
@@ -225,7 +226,9 @@ class TypingHandler(object):
|
||||
for room_id, serial in self._room_serials.items():
|
||||
if last_id < serial and serial <= current_id:
|
||||
typing = self._room_typing[room_id]
|
||||
typing_bytes = json.dumps(list(typing), ensure_ascii=False)
|
||||
typing_bytes = json.dumps([
|
||||
u.to_string() for u in typing
|
||||
], ensure_ascii=False)
|
||||
rows.append((serial, room_id, typing_bytes))
|
||||
rows.sort()
|
||||
return rows
|
||||
@@ -235,26 +238,34 @@ class TypingNotificationEventSource(object):
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
# We can't call get_typing_handler here because there's a cycle:
|
||||
#
|
||||
# Typing -> Notifier -> TypingNotificationEventSource -> Typing
|
||||
#
|
||||
self.get_typing_handler = hs.get_typing_handler
|
||||
self._handler = None
|
||||
self._room_member_handler = None
|
||||
|
||||
def handler(self):
|
||||
# Avoid cyclic dependency in handler setup
|
||||
if not self._handler:
|
||||
self._handler = self.hs.get_handlers().typing_notification_handler
|
||||
return self._handler
|
||||
|
||||
def room_member_handler(self):
|
||||
if not self._room_member_handler:
|
||||
self._room_member_handler = self.hs.get_handlers().room_member_handler
|
||||
return self._room_member_handler
|
||||
|
||||
def _make_event_for(self, room_id):
|
||||
typing = self.get_typing_handler()._room_typing[room_id]
|
||||
typing = self.handler()._room_typing[room_id]
|
||||
return {
|
||||
"type": "m.typing",
|
||||
"room_id": room_id,
|
||||
"content": {
|
||||
"user_ids": list(typing),
|
||||
"user_ids": [u.to_string() for u in typing],
|
||||
},
|
||||
}
|
||||
|
||||
def get_new_events(self, from_key, room_ids, **kwargs):
|
||||
with Measure(self.clock, "typing.get_new_events"):
|
||||
from_key = int(from_key)
|
||||
handler = self.get_typing_handler()
|
||||
handler = self.handler()
|
||||
|
||||
events = []
|
||||
for room_id in room_ids:
|
||||
@@ -268,7 +279,7 @@ class TypingNotificationEventSource(object):
|
||||
return events, handler._latest_room_serial
|
||||
|
||||
def get_current_key(self):
|
||||
return self.get_typing_handler()._latest_room_serial
|
||||
return self.handler()._latest_room_serial
|
||||
|
||||
def get_pagination_rows(self, user, pagination_config, key):
|
||||
return ([], pagination_config.from_key)
|
||||
|
||||
@@ -380,14 +380,13 @@ class CaptchaServerHttpClient(SimpleHttpClient):
|
||||
class SpiderEndpointFactory(object):
|
||||
def __init__(self, hs):
|
||||
self.blacklist = hs.config.url_preview_ip_range_blacklist
|
||||
self.whitelist = hs.config.url_preview_ip_range_whitelist
|
||||
self.policyForHTTPS = hs.get_http_client_context_factory()
|
||||
|
||||
def endpointForURI(self, uri):
|
||||
logger.info("Getting endpoint for %s", uri.toBytes())
|
||||
if uri.scheme == "http":
|
||||
return SpiderEndpoint(
|
||||
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
|
||||
reactor, uri.host, uri.port, self.blacklist,
|
||||
endpoint=TCP4ClientEndpoint,
|
||||
endpoint_kw_args={
|
||||
'timeout': 15
|
||||
@@ -396,7 +395,7 @@ class SpiderEndpointFactory(object):
|
||||
elif uri.scheme == "https":
|
||||
tlsPolicy = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
|
||||
return SpiderEndpoint(
|
||||
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
|
||||
reactor, uri.host, uri.port, self.blacklist,
|
||||
endpoint=SSL4ClientEndpoint,
|
||||
endpoint_kw_args={
|
||||
'sslContextFactory': tlsPolicy,
|
||||
|
||||
@@ -79,13 +79,12 @@ class SpiderEndpoint(object):
|
||||
"""An endpoint which refuses to connect to blacklisted IP addresses
|
||||
Implements twisted.internet.interfaces.IStreamClientEndpoint.
|
||||
"""
|
||||
def __init__(self, reactor, host, port, blacklist, whitelist,
|
||||
def __init__(self, reactor, host, port, blacklist,
|
||||
endpoint=TCP4ClientEndpoint, endpoint_kw_args={}):
|
||||
self.reactor = reactor
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.blacklist = blacklist
|
||||
self.whitelist = whitelist
|
||||
self.endpoint = endpoint
|
||||
self.endpoint_kw_args = endpoint_kw_args
|
||||
|
||||
@@ -94,13 +93,10 @@ class SpiderEndpoint(object):
|
||||
address = yield self.reactor.resolve(self.host)
|
||||
|
||||
from netaddr import IPAddress
|
||||
ip_address = IPAddress(address)
|
||||
|
||||
if ip_address in self.blacklist:
|
||||
if self.whitelist is None or ip_address not in self.whitelist:
|
||||
raise ConnectError(
|
||||
"Refusing to spider blacklisted IP address %s" % address
|
||||
)
|
||||
if IPAddress(address) in self.blacklist:
|
||||
raise ConnectError(
|
||||
"Refusing to spider blacklisted IP address %s" % address
|
||||
)
|
||||
|
||||
logger.info("Connecting to %s:%s", address, self.port)
|
||||
endpoint = self.endpoint(
|
||||
|
||||
@@ -74,12 +74,7 @@ response_db_txn_duration = metrics.register_distribution(
|
||||
_next_request_id = 0
|
||||
|
||||
|
||||
def request_handler(report_metrics=True):
|
||||
"""Decorator for ``wrap_request_handler``"""
|
||||
return lambda request_handler: wrap_request_handler(request_handler, report_metrics)
|
||||
|
||||
|
||||
def wrap_request_handler(request_handler, report_metrics):
|
||||
def request_handler(request_handler):
|
||||
"""Wraps a method that acts as a request handler with the necessary logging
|
||||
and exception handling.
|
||||
|
||||
@@ -101,12 +96,7 @@ def wrap_request_handler(request_handler, report_metrics):
|
||||
global _next_request_id
|
||||
request_id = "%s-%s" % (request.method, _next_request_id)
|
||||
_next_request_id += 1
|
||||
|
||||
with LoggingContext(request_id) as request_context:
|
||||
if report_metrics:
|
||||
request_metrics = RequestMetrics()
|
||||
request_metrics.start(self.clock)
|
||||
|
||||
request_context.request = request_id
|
||||
with request.processing():
|
||||
try:
|
||||
@@ -143,14 +133,6 @@ def wrap_request_handler(request_handler, report_metrics):
|
||||
},
|
||||
send_cors=True
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
if report_metrics:
|
||||
request_metrics.stop(
|
||||
self.clock, request, self.__class__.__name__
|
||||
)
|
||||
except:
|
||||
pass
|
||||
return wrapped_request_handler
|
||||
|
||||
|
||||
@@ -215,23 +197,19 @@ class JsonResource(HttpServer, resource.Resource):
|
||||
self._async_render(request)
|
||||
return server.NOT_DONE_YET
|
||||
|
||||
# Disable metric reporting because _async_render does its own metrics.
|
||||
# It does its own metric reporting because _async_render dispatches to
|
||||
# a callback and it's the class name of that callback we want to report
|
||||
# against rather than the JsonResource itself.
|
||||
@request_handler(report_metrics=False)
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render(self, request):
|
||||
""" This gets called from render() every time someone sends us a request.
|
||||
This checks if anyone has registered a callback for that method and
|
||||
path.
|
||||
"""
|
||||
start = self.clock.time_msec()
|
||||
if request.method == "OPTIONS":
|
||||
self._send_response(request, 200, {})
|
||||
return
|
||||
|
||||
request_metrics = RequestMetrics()
|
||||
request_metrics.start(self.clock)
|
||||
start_context = LoggingContext.current_context()
|
||||
|
||||
# Loop through all the registered callbacks to check if the method
|
||||
# and path regex match
|
||||
@@ -263,7 +241,40 @@ class JsonResource(HttpServer, resource.Resource):
|
||||
self._send_response(request, code, response)
|
||||
|
||||
try:
|
||||
request_metrics.stop(self.clock, request, servlet_classname)
|
||||
context = LoggingContext.current_context()
|
||||
|
||||
tag = ""
|
||||
if context:
|
||||
tag = context.tag
|
||||
|
||||
if context != start_context:
|
||||
logger.warn(
|
||||
"Context have unexpectedly changed %r, %r",
|
||||
context, self.start_context
|
||||
)
|
||||
return
|
||||
|
||||
incoming_requests_counter.inc(request.method, servlet_classname, tag)
|
||||
|
||||
response_timer.inc_by(
|
||||
self.clock.time_msec() - start, request.method,
|
||||
servlet_classname, tag
|
||||
)
|
||||
|
||||
ru_utime, ru_stime = context.get_resource_usage()
|
||||
|
||||
response_ru_utime.inc_by(
|
||||
ru_utime, request.method, servlet_classname, tag
|
||||
)
|
||||
response_ru_stime.inc_by(
|
||||
ru_stime, request.method, servlet_classname, tag
|
||||
)
|
||||
response_db_txn_count.inc_by(
|
||||
context.db_txn_count, request.method, servlet_classname, tag
|
||||
)
|
||||
response_db_txn_duration.inc_by(
|
||||
context.db_txn_duration, request.method, servlet_classname, tag
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
@@ -296,48 +307,6 @@ class JsonResource(HttpServer, resource.Resource):
|
||||
)
|
||||
|
||||
|
||||
class RequestMetrics(object):
|
||||
def start(self, clock):
|
||||
self.start = clock.time_msec()
|
||||
self.start_context = LoggingContext.current_context()
|
||||
|
||||
def stop(self, clock, request, servlet_classname):
|
||||
context = LoggingContext.current_context()
|
||||
|
||||
tag = ""
|
||||
if context:
|
||||
tag = context.tag
|
||||
|
||||
if context != self.start_context:
|
||||
logger.warn(
|
||||
"Context have unexpectedly changed %r, %r",
|
||||
context, self.start_context
|
||||
)
|
||||
return
|
||||
|
||||
incoming_requests_counter.inc(request.method, servlet_classname, tag)
|
||||
|
||||
response_timer.inc_by(
|
||||
clock.time_msec() - self.start, request.method,
|
||||
servlet_classname, tag
|
||||
)
|
||||
|
||||
ru_utime, ru_stime = context.get_resource_usage()
|
||||
|
||||
response_ru_utime.inc_by(
|
||||
ru_utime, request.method, servlet_classname, tag
|
||||
)
|
||||
response_ru_stime.inc_by(
|
||||
ru_stime, request.method, servlet_classname, tag
|
||||
)
|
||||
response_db_txn_count.inc_by(
|
||||
context.db_txn_count, request.method, servlet_classname, tag
|
||||
)
|
||||
response_db_txn_duration.inc_by(
|
||||
context.db_txn_duration, request.method, servlet_classname, tag
|
||||
)
|
||||
|
||||
|
||||
class RootRedirect(resource.Resource):
|
||||
"""Redirects the root '/' path to another path."""
|
||||
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from twisted.web.server import Site, Request
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
|
||||
ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
|
||||
|
||||
|
||||
class SynapseRequest(Request):
|
||||
def __init__(self, site, *args, **kw):
|
||||
Request.__init__(self, *args, **kw)
|
||||
self.site = site
|
||||
self.authenticated_entity = None
|
||||
self.start_time = 0
|
||||
|
||||
def __repr__(self):
|
||||
# We overwrite this so that we don't log ``access_token``
|
||||
return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
|
||||
self.__class__.__name__,
|
||||
id(self),
|
||||
self.method,
|
||||
self.get_redacted_uri(),
|
||||
self.clientproto,
|
||||
self.site.site_tag,
|
||||
)
|
||||
|
||||
def get_redacted_uri(self):
|
||||
return ACCESS_TOKEN_RE.sub(
|
||||
r'\1<redacted>\3',
|
||||
self.uri
|
||||
)
|
||||
|
||||
def get_user_agent(self):
|
||||
return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
|
||||
|
||||
def started_processing(self):
|
||||
self.site.access_logger.info(
|
||||
"%s - %s - Received request: %s %s",
|
||||
self.getClientIP(),
|
||||
self.site.site_tag,
|
||||
self.method,
|
||||
self.get_redacted_uri()
|
||||
)
|
||||
self.start_time = int(time.time() * 1000)
|
||||
|
||||
def finished_processing(self):
|
||||
|
||||
try:
|
||||
context = LoggingContext.current_context()
|
||||
ru_utime, ru_stime = context.get_resource_usage()
|
||||
db_txn_count = context.db_txn_count
|
||||
db_txn_duration = context.db_txn_duration
|
||||
except:
|
||||
ru_utime, ru_stime = (0, 0)
|
||||
db_txn_count, db_txn_duration = (0, 0)
|
||||
|
||||
self.site.access_logger.info(
|
||||
"%s - %s - {%s}"
|
||||
" Processed request: %dms (%dms, %dms) (%dms/%d)"
|
||||
" %sB %s \"%s %s %s\" \"%s\"",
|
||||
self.getClientIP(),
|
||||
self.site.site_tag,
|
||||
self.authenticated_entity,
|
||||
int(time.time() * 1000) - self.start_time,
|
||||
int(ru_utime * 1000),
|
||||
int(ru_stime * 1000),
|
||||
int(db_txn_duration * 1000),
|
||||
int(db_txn_count),
|
||||
self.sentLength,
|
||||
self.code,
|
||||
self.method,
|
||||
self.get_redacted_uri(),
|
||||
self.clientproto,
|
||||
self.get_user_agent(),
|
||||
)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def processing(self):
|
||||
self.started_processing()
|
||||
yield
|
||||
self.finished_processing()
|
||||
|
||||
|
||||
class XForwardedForRequest(SynapseRequest):
|
||||
def __init__(self, *args, **kw):
|
||||
SynapseRequest.__init__(self, *args, **kw)
|
||||
|
||||
"""
|
||||
Add a layer on top of another request that only uses the value of an
|
||||
X-Forwarded-For header as the result of C{getClientIP}.
|
||||
"""
|
||||
def getClientIP(self):
|
||||
"""
|
||||
@return: The client address (the first address) in the value of the
|
||||
I{X-Forwarded-For header}. If the header is not present, return
|
||||
C{b"-"}.
|
||||
"""
|
||||
return self.requestHeaders.getRawHeaders(
|
||||
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
|
||||
|
||||
|
||||
class SynapseRequestFactory(object):
|
||||
def __init__(self, site, x_forwarded_for):
|
||||
self.site = site
|
||||
self.x_forwarded_for = x_forwarded_for
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self.x_forwarded_for:
|
||||
return XForwardedForRequest(self.site, *args, **kwargs)
|
||||
else:
|
||||
return SynapseRequest(self.site, *args, **kwargs)
|
||||
|
||||
|
||||
class SynapseSite(Site):
|
||||
"""
|
||||
Subclass of a twisted http Site that does access logging with python's
|
||||
standard logging
|
||||
"""
|
||||
def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
|
||||
Site.__init__(self, resource, *args, **kwargs)
|
||||
|
||||
self.site_tag = site_tag
|
||||
|
||||
proxied = config.get("x_forwarded", False)
|
||||
self.requestFactory = SynapseRequestFactory(self, proxied)
|
||||
self.access_logger = logging.getLogger(logger_name)
|
||||
|
||||
def log(self, request):
|
||||
pass
|
||||
@@ -21,7 +21,6 @@ from synapse.util.logutils import log_function
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.types import StreamToken
|
||||
from synapse.visibility import filter_events_for_client
|
||||
import synapse.metrics
|
||||
|
||||
from collections import namedtuple
|
||||
@@ -140,6 +139,8 @@ class Notifier(object):
|
||||
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
|
||||
self.user_to_user_stream = {}
|
||||
self.room_to_user_streams = {}
|
||||
self.appservice_to_user_streams = {}
|
||||
@@ -149,8 +150,6 @@ class Notifier(object):
|
||||
self.pending_new_room_events = []
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
self.appservice_handler = hs.get_application_service_handler()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
|
||||
hs.get_distributor().observe(
|
||||
"user_joined_room", self._user_joined_room
|
||||
@@ -232,7 +231,9 @@ class Notifier(object):
|
||||
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
|
||||
"""Notify any user streams that are interested in this room event"""
|
||||
# poke any interested application service.
|
||||
self.appservice_handler.notify_interested_services(event)
|
||||
self.hs.get_handlers().appservice_handler.notify_interested_services(
|
||||
event
|
||||
)
|
||||
|
||||
app_streams = set()
|
||||
|
||||
@@ -397,8 +398,8 @@ class Notifier(object):
|
||||
)
|
||||
|
||||
if name == "room":
|
||||
new_events = yield filter_events_for_client(
|
||||
self.store,
|
||||
room_member_handler = self.hs.get_handlers().room_member_handler
|
||||
new_events = yield room_member_handler._filter_events_for_client(
|
||||
user.to_string(),
|
||||
new_events,
|
||||
is_peeking=is_peeking,
|
||||
@@ -447,7 +448,7 @@ class Notifier(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _is_world_readable(self, room_id):
|
||||
state = yield self.state_handler.get_current_state(
|
||||
state = yield self.hs.get_state_handler().get_current_state(
|
||||
room_id,
|
||||
EventTypes.RoomHistoryVisibility
|
||||
)
|
||||
|
||||
@@ -37,14 +37,14 @@ class ActionGenerator:
|
||||
# tag (ie. we just need all the users).
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_push_actions_for_event(self, event, context):
|
||||
def handle_push_actions_for_event(self, event, context, handler):
|
||||
with Measure(self.clock, "handle_push_actions_for_event"):
|
||||
bulk_evaluator = yield evaluator_for_event(
|
||||
event, self.hs, self.store
|
||||
)
|
||||
|
||||
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
|
||||
event, context.current_state
|
||||
event, handler, context.current_state
|
||||
)
|
||||
|
||||
context.push_actions = [
|
||||
|
||||
@@ -22,14 +22,12 @@ from .baserules import list_with_base_rules
|
||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.visibility import filter_events_for_clients
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def decode_rule_json(rule):
|
||||
rule = dict(rule)
|
||||
rule['conditions'] = json.loads(rule['conditions'])
|
||||
rule['actions'] = json.loads(rule['actions'])
|
||||
return rule
|
||||
@@ -40,8 +38,6 @@ def _get_rules(room_id, user_ids, store):
|
||||
rules_by_user = yield store.bulk_get_push_rules(user_ids)
|
||||
rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)
|
||||
|
||||
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
|
||||
|
||||
rules_by_user = {
|
||||
uid: list_with_base_rules([
|
||||
decode_rule_json(rule_list)
|
||||
@@ -54,10 +50,11 @@ def _get_rules(room_id, user_ids, store):
|
||||
# fetch disabled rules, but this won't account for any server default
|
||||
# rules the user has disabled, so we need to do this too.
|
||||
for uid in user_ids:
|
||||
user_enabled_map = rules_enabled_by_user.get(uid)
|
||||
if not user_enabled_map:
|
||||
if uid not in rules_enabled_by_user:
|
||||
continue
|
||||
|
||||
user_enabled_map = rules_enabled_by_user[uid]
|
||||
|
||||
for i, rule in enumerate(rules_by_user[uid]):
|
||||
rule_id = rule['rule_id']
|
||||
|
||||
@@ -129,7 +126,7 @@ class BulkPushRuleEvaluator:
|
||||
self.store = store
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def action_for_event_by_user(self, event, current_state):
|
||||
def action_for_event_by_user(self, event, handler, current_state):
|
||||
actions_by_user = {}
|
||||
|
||||
# None of these users can be peeking since this list of users comes
|
||||
@@ -139,8 +136,8 @@ class BulkPushRuleEvaluator:
|
||||
(u, False) for u in self.rules_by_user.keys()
|
||||
]
|
||||
|
||||
filtered_by_user = yield filter_events_for_clients(
|
||||
self.store, user_tuples, [event], {event.event_id: current_state}
|
||||
filtered_by_user = yield handler.filter_events_for_clients(
|
||||
user_tuples, [event], {event.event_id: current_state}
|
||||
)
|
||||
|
||||
room_members = yield self.store.get_users_in_room(self.room_id)
|
||||
|
||||
@@ -1,283 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
|
||||
import logging
|
||||
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
from mailer import Mailer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The amount of time we always wait before ever emailing about a notification
|
||||
# (to give the user a chance to respond to other push or notice the window)
|
||||
DELAY_BEFORE_MAIL_MS = 10 * 60 * 1000
|
||||
|
||||
# THROTTLE is the minimum time between mail notifications sent for a given room.
|
||||
# Each room maintains its own throttle counter, but each new mail notification
|
||||
# sends the pending notifications for all rooms.
|
||||
THROTTLE_START_MS = 10 * 60 * 1000
|
||||
THROTTLE_MAX_MS = 24 * 60 * 60 * 1000 # 24h
|
||||
# THROTTLE_MULTIPLIER = 6 # 10 mins, 1 hour, 6 hours, 24 hours
|
||||
THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day
|
||||
|
||||
# If no event triggers a notification for this long after the previous,
|
||||
# the throttle is released.
|
||||
# 12 hours - a gap of 12 hours in conversation is surely enough to merit a new
|
||||
# notification when things get going again...
|
||||
THROTTLE_RESET_AFTER_MS = (12 * 60 * 60 * 1000)
|
||||
|
||||
# does each email include all unread notifs, or just the ones which have happened
|
||||
# since the last mail?
|
||||
# XXX: this is currently broken as it includes ones from parted rooms(!)
|
||||
INCLUDE_ALL_UNREAD_NOTIFS = False
|
||||
|
||||
|
||||
class EmailPusher(object):
|
||||
"""
|
||||
A pusher that sends email notifications about events (approximately)
|
||||
when they happen.
|
||||
This shares quite a bit of code with httpusher: it would be good to
|
||||
factor out the common parts
|
||||
"""
|
||||
def __init__(self, hs, pusherdict):
|
||||
self.hs = hs
|
||||
self.store = self.hs.get_datastore()
|
||||
self.clock = self.hs.get_clock()
|
||||
self.pusher_id = pusherdict['id']
|
||||
self.user_id = pusherdict['user_name']
|
||||
self.app_id = pusherdict['app_id']
|
||||
self.email = pusherdict['pushkey']
|
||||
self.last_stream_ordering = pusherdict['last_stream_ordering']
|
||||
self.timed_call = None
|
||||
self.throttle_params = None
|
||||
|
||||
# See httppusher
|
||||
self.max_stream_ordering = None
|
||||
|
||||
self.processing = False
|
||||
|
||||
if self.hs.config.email_enable_notifs:
|
||||
if 'data' in pusherdict and 'brand' in pusherdict['data']:
|
||||
app_name = pusherdict['data']['brand']
|
||||
else:
|
||||
app_name = self.hs.config.email_app_name
|
||||
|
||||
self.mailer = Mailer(self.hs, app_name)
|
||||
else:
|
||||
self.mailer = None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_started(self):
|
||||
if self.mailer is not None:
|
||||
self.throttle_params = yield self.store.get_throttle_params_by_room(
|
||||
self.pusher_id
|
||||
)
|
||||
yield self._process()
|
||||
|
||||
def on_stop(self):
|
||||
if self.timed_call:
|
||||
self.timed_call.cancel()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
|
||||
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
|
||||
yield self._process()
|
||||
|
||||
def on_new_receipts(self, min_stream_id, max_stream_id):
|
||||
# We could wake up and cancel the timer but there tend to be quite a
|
||||
# lot of read receipts so it's probably less work to just let the
|
||||
# timer fire
|
||||
return defer.succeed(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_timer(self):
|
||||
self.timed_call = None
|
||||
yield self._process()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _process(self):
|
||||
if self.processing:
|
||||
return
|
||||
|
||||
with LoggingContext("emailpush._process"):
|
||||
with Measure(self.clock, "emailpush._process"):
|
||||
try:
|
||||
self.processing = True
|
||||
# if the max ordering changes while we're running _unsafe_process,
|
||||
# call it again, and so on until we've caught up.
|
||||
while True:
|
||||
starting_max_ordering = self.max_stream_ordering
|
||||
try:
|
||||
yield self._unsafe_process()
|
||||
except:
|
||||
logger.exception("Exception processing notifs")
|
||||
if self.max_stream_ordering == starting_max_ordering:
|
||||
break
|
||||
finally:
|
||||
self.processing = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _unsafe_process(self):
|
||||
"""
|
||||
Main logic of the push loop without the wrapper function that sets
|
||||
up logging, measures and guards against multiple instances of it
|
||||
being run.
|
||||
"""
|
||||
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
|
||||
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
|
||||
self.user_id, start, self.max_stream_ordering
|
||||
)
|
||||
|
||||
soonest_due_at = None
|
||||
|
||||
for push_action in unprocessed:
|
||||
received_at = push_action['received_ts']
|
||||
if received_at is None:
|
||||
received_at = 0
|
||||
notif_ready_at = received_at + DELAY_BEFORE_MAIL_MS
|
||||
|
||||
room_ready_at = self.room_ready_to_notify_at(
|
||||
push_action['room_id']
|
||||
)
|
||||
|
||||
should_notify_at = max(notif_ready_at, room_ready_at)
|
||||
|
||||
if should_notify_at < self.clock.time_msec():
|
||||
# one of our notifications is ready for sending, so we send
|
||||
# *one* email updating the user on their notifications,
|
||||
# we then consider all previously outstanding notifications
|
||||
# to be delivered.
|
||||
|
||||
reason = {
|
||||
'room_id': push_action['room_id'],
|
||||
'now': self.clock.time_msec(),
|
||||
'received_at': received_at,
|
||||
'delay_before_mail_ms': DELAY_BEFORE_MAIL_MS,
|
||||
'last_sent_ts': self.get_room_last_sent_ts(push_action['room_id']),
|
||||
'throttle_ms': self.get_room_throttle_ms(push_action['room_id']),
|
||||
}
|
||||
|
||||
yield self.send_notification(unprocessed, reason)
|
||||
|
||||
yield self.save_last_stream_ordering_and_success(max([
|
||||
ea['stream_ordering'] for ea in unprocessed
|
||||
]))
|
||||
|
||||
# we update the throttle on all the possible unprocessed push actions
|
||||
for ea in unprocessed:
|
||||
yield self.sent_notif_update_throttle(
|
||||
ea['room_id'], ea
|
||||
)
|
||||
break
|
||||
else:
|
||||
if soonest_due_at is None or should_notify_at < soonest_due_at:
|
||||
soonest_due_at = should_notify_at
|
||||
|
||||
if self.timed_call is not None:
|
||||
self.timed_call.cancel()
|
||||
self.timed_call = None
|
||||
|
||||
if soonest_due_at is not None:
|
||||
self.timed_call = reactor.callLater(
|
||||
self.seconds_until(soonest_due_at), self.on_timer
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def save_last_stream_ordering_and_success(self, last_stream_ordering):
|
||||
self.last_stream_ordering = last_stream_ordering
|
||||
yield self.store.update_pusher_last_stream_ordering_and_success(
|
||||
self.app_id, self.email, self.user_id,
|
||||
last_stream_ordering, self.clock.time_msec()
|
||||
)
|
||||
|
||||
def seconds_until(self, ts_msec):
|
||||
return (ts_msec - self.clock.time_msec()) / 1000
|
||||
|
||||
def get_room_throttle_ms(self, room_id):
|
||||
if room_id in self.throttle_params:
|
||||
return self.throttle_params[room_id]["throttle_ms"]
|
||||
else:
|
||||
return 0
|
||||
|
||||
def get_room_last_sent_ts(self, room_id):
|
||||
if room_id in self.throttle_params:
|
||||
return self.throttle_params[room_id]["last_sent_ts"]
|
||||
else:
|
||||
return 0
|
||||
|
||||
def room_ready_to_notify_at(self, room_id):
|
||||
"""
|
||||
Determines whether throttling should prevent us from sending an email
|
||||
for the given room
|
||||
Returns: The timestamp when we are next allowed to send an email notif
|
||||
for this room
|
||||
"""
|
||||
last_sent_ts = self.get_room_last_sent_ts(room_id)
|
||||
throttle_ms = self.get_room_throttle_ms(room_id)
|
||||
|
||||
may_send_at = last_sent_ts + throttle_ms
|
||||
return may_send_at
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def sent_notif_update_throttle(self, room_id, notified_push_action):
|
||||
# We have sent a notification, so update the throttle accordingly.
|
||||
# If the event that triggered the notif happened more than
|
||||
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a
|
||||
# notif, we release the throttle. Otherwise, the throttle is increased.
|
||||
time_of_previous_notifs = yield self.store.get_time_of_last_push_action_before(
|
||||
notified_push_action['stream_ordering']
|
||||
)
|
||||
|
||||
time_of_this_notifs = notified_push_action['received_ts']
|
||||
|
||||
if time_of_previous_notifs is not None and time_of_this_notifs is not None:
|
||||
gap = time_of_this_notifs - time_of_previous_notifs
|
||||
else:
|
||||
# if we don't know the arrival time of one of the notifs (it was not
|
||||
# stored prior to email notification code) then assume a gap of
|
||||
# zero which will just not reset the throttle
|
||||
gap = 0
|
||||
|
||||
current_throttle_ms = self.get_room_throttle_ms(room_id)
|
||||
|
||||
if gap > THROTTLE_RESET_AFTER_MS:
|
||||
new_throttle_ms = THROTTLE_START_MS
|
||||
else:
|
||||
if current_throttle_ms == 0:
|
||||
new_throttle_ms = THROTTLE_START_MS
|
||||
else:
|
||||
new_throttle_ms = min(
|
||||
current_throttle_ms * THROTTLE_MULTIPLIER,
|
||||
THROTTLE_MAX_MS
|
||||
)
|
||||
self.throttle_params[room_id] = {
|
||||
"last_sent_ts": self.clock.time_msec(),
|
||||
"throttle_ms": new_throttle_ms
|
||||
}
|
||||
yield self.store.set_throttle_params(
|
||||
self.pusher_id, room_id, self.throttle_params[room_id]
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_notification(self, push_actions, reason):
|
||||
logger.info("Sending notif email for user %r", self.user_id)
|
||||
|
||||
yield self.mailer.send_notification_mail(
|
||||
self.app_id, self.user_id, self.email, push_actions, reason
|
||||
)
|
||||
@@ -1,507 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.mail.smtp import sendmail
|
||||
|
||||
import email.utils
|
||||
import email.mime.multipart
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
from synapse.util.async import concurrently_execute
|
||||
from synapse.util.presentable_names import (
|
||||
calculate_room_name, name_from_member_event, descriptor_from_member_events
|
||||
)
|
||||
from synapse.types import UserID
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
import jinja2
|
||||
import bleach
|
||||
|
||||
import time
|
||||
import urllib
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \
|
||||
"in the %(room)s room..."
|
||||
MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..."
|
||||
MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..."
|
||||
MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..."
|
||||
MESSAGES_IN_ROOM_AND_OTHERS = \
|
||||
"You have messages on %(app)s in the %(room)s room and others..."
|
||||
MESSAGES_FROM_PERSON_AND_OTHERS = \
|
||||
"You have messages on %(app)s from %(person)s and others..."
|
||||
INVITE_FROM_PERSON_TO_ROOM = "%(person)s has invited you to join the " \
|
||||
"%(room)s room on %(app)s..."
|
||||
INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..."
|
||||
|
||||
CONTEXT_BEFORE = 1
|
||||
CONTEXT_AFTER = 1
|
||||
|
||||
# From https://github.com/matrix-org/matrix-react-sdk/blob/master/src/HtmlUtils.js
|
||||
ALLOWED_TAGS = [
|
||||
'font', # custom to matrix for IRC-style font coloring
|
||||
'del', # for markdown
|
||||
# deliberately no h1/h2 to stop people shouting.
|
||||
'h3', 'h4', 'h5', 'h6', 'blockquote', 'p', 'a', 'ul', 'ol',
|
||||
'nl', 'li', 'b', 'i', 'u', 'strong', 'em', 'strike', 'code', 'hr', 'br', 'div',
|
||||
'table', 'thead', 'caption', 'tbody', 'tr', 'th', 'td', 'pre'
|
||||
]
|
||||
ALLOWED_ATTRS = {
|
||||
# custom ones first:
|
||||
"font": ["color"], # custom to matrix
|
||||
"a": ["href", "name", "target"], # remote target: custom to matrix
|
||||
# We don't currently allow img itself by default, but this
|
||||
# would make sense if we did
|
||||
"img": ["src"],
|
||||
}
|
||||
# When bleach release a version with this option, we can specify schemes
|
||||
# ALLOWED_SCHEMES = ["http", "https", "ftp", "mailto"]
|
||||
|
||||
|
||||
class Mailer(object):
|
||||
def __init__(self, hs, app_name):
|
||||
self.hs = hs
|
||||
self.store = self.hs.get_datastore()
|
||||
self.auth_handler = self.hs.get_auth_handler()
|
||||
self.state_handler = self.hs.get_state_handler()
|
||||
loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
|
||||
self.app_name = app_name
|
||||
logger.info("Created Mailer for app_name %s" % app_name)
|
||||
env = jinja2.Environment(loader=loader)
|
||||
env.filters["format_ts"] = format_ts_filter
|
||||
env.filters["mxc_to_http"] = self.mxc_to_http_filter
|
||||
self.notif_template_html = env.get_template(
|
||||
self.hs.config.email_notif_template_html
|
||||
)
|
||||
self.notif_template_text = env.get_template(
|
||||
self.hs.config.email_notif_template_text
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_notification_mail(self, app_id, user_id, email_address,
|
||||
push_actions, reason):
|
||||
raw_from = email.utils.parseaddr(self.hs.config.email_notif_from)[1]
|
||||
raw_to = email.utils.parseaddr(email_address)[1]
|
||||
|
||||
if raw_to == '':
|
||||
raise RuntimeError("Invalid 'to' address")
|
||||
|
||||
rooms_in_order = deduped_ordered_list(
|
||||
[pa['room_id'] for pa in push_actions]
|
||||
)
|
||||
|
||||
notif_events = yield self.store.get_events(
|
||||
[pa['event_id'] for pa in push_actions]
|
||||
)
|
||||
|
||||
notifs_by_room = {}
|
||||
for pa in push_actions:
|
||||
notifs_by_room.setdefault(pa["room_id"], []).append(pa)
|
||||
|
||||
# collect the current state for all the rooms in which we have
|
||||
# notifications
|
||||
state_by_room = {}
|
||||
|
||||
try:
|
||||
user_display_name = yield self.store.get_profile_displayname(
|
||||
UserID.from_string(user_id).localpart
|
||||
)
|
||||
if user_display_name is None:
|
||||
user_display_name = user_id
|
||||
except StoreError:
|
||||
user_display_name = user_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _fetch_room_state(room_id):
|
||||
room_state = yield self.state_handler.get_current_state(room_id)
|
||||
state_by_room[room_id] = room_state
|
||||
|
||||
# Run at most 3 of these at once: sync does 10 at a time but email
|
||||
# notifs are much less realtime than sync so we can afford to wait a bit.
|
||||
yield concurrently_execute(_fetch_room_state, rooms_in_order, 3)
|
||||
|
||||
# actually sort our so-called rooms_in_order list, most recent room first
|
||||
rooms_in_order.sort(
|
||||
key=lambda r: -(notifs_by_room[r][-1]['received_ts'] or 0)
|
||||
)
|
||||
|
||||
rooms = []
|
||||
|
||||
for r in rooms_in_order:
|
||||
roomvars = yield self.get_room_vars(
|
||||
r, user_id, notifs_by_room[r], notif_events, state_by_room[r]
|
||||
)
|
||||
rooms.append(roomvars)
|
||||
|
||||
reason['room_name'] = calculate_room_name(
|
||||
state_by_room[reason['room_id']], user_id, fallback_to_members=True
|
||||
)
|
||||
|
||||
summary_text = self.make_summary_text(
|
||||
notifs_by_room, state_by_room, notif_events, user_id, reason
|
||||
)
|
||||
|
||||
template_vars = {
|
||||
"user_display_name": user_display_name,
|
||||
"unsubscribe_link": self.make_unsubscribe_link(
|
||||
user_id, app_id, email_address
|
||||
),
|
||||
"summary_text": summary_text,
|
||||
"app_name": self.app_name,
|
||||
"rooms": rooms,
|
||||
"reason": reason,
|
||||
}
|
||||
|
||||
html_text = self.notif_template_html.render(**template_vars)
|
||||
html_part = MIMEText(html_text, "html", "utf8")
|
||||
|
||||
plain_text = self.notif_template_text.render(**template_vars)
|
||||
text_part = MIMEText(plain_text, "plain", "utf8")
|
||||
|
||||
multipart_msg = MIMEMultipart('alternative')
|
||||
multipart_msg['Subject'] = "[%s] %s" % (self.app_name, summary_text)
|
||||
multipart_msg['From'] = self.hs.config.email_notif_from
|
||||
multipart_msg['To'] = email_address
|
||||
multipart_msg['Date'] = email.utils.formatdate()
|
||||
multipart_msg['Message-ID'] = email.utils.make_msgid()
|
||||
multipart_msg.attach(text_part)
|
||||
multipart_msg.attach(html_part)
|
||||
|
||||
logger.info("Sending email push notification to %s" % email_address)
|
||||
# logger.debug(html_text)
|
||||
|
||||
yield sendmail(
|
||||
self.hs.config.email_smtp_host,
|
||||
raw_from, raw_to, multipart_msg.as_string(),
|
||||
port=self.hs.config.email_smtp_port
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state):
|
||||
my_member_event = room_state[("m.room.member", user_id)]
|
||||
is_invite = my_member_event.content["membership"] == "invite"
|
||||
|
||||
room_vars = {
|
||||
"title": calculate_room_name(room_state, user_id),
|
||||
"hash": string_ordinal_total(room_id), # See sender avatar hash
|
||||
"notifs": [],
|
||||
"invite": is_invite,
|
||||
"link": self.make_room_link(room_id),
|
||||
}
|
||||
|
||||
if not is_invite:
|
||||
for n in notifs:
|
||||
notifvars = yield self.get_notif_vars(
|
||||
n, user_id, notif_events[n['event_id']], room_state
|
||||
)
|
||||
|
||||
# merge overlapping notifs together.
|
||||
# relies on the notifs being in chronological order.
|
||||
merge = False
|
||||
if room_vars['notifs'] and 'messages' in room_vars['notifs'][-1]:
|
||||
prev_messages = room_vars['notifs'][-1]['messages']
|
||||
for message in notifvars['messages']:
|
||||
pm = filter(lambda pm: pm['id'] == message['id'], prev_messages)
|
||||
if pm:
|
||||
if not message["is_historical"]:
|
||||
pm[0]["is_historical"] = False
|
||||
merge = True
|
||||
elif merge:
|
||||
# we're merging, so append any remaining messages
|
||||
# in this notif to the previous one
|
||||
prev_messages.append(message)
|
||||
|
||||
if not merge:
|
||||
room_vars['notifs'].append(notifvars)
|
||||
|
||||
defer.returnValue(room_vars)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_notif_vars(self, notif, user_id, notif_event, room_state):
|
||||
results = yield self.store.get_events_around(
|
||||
notif['room_id'], notif['event_id'],
|
||||
before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER
|
||||
)
|
||||
|
||||
ret = {
|
||||
"link": self.make_notif_link(notif),
|
||||
"ts": notif['received_ts'],
|
||||
"messages": [],
|
||||
}
|
||||
|
||||
the_events = yield filter_events_for_client(
|
||||
self.store, user_id, results["events_before"]
|
||||
)
|
||||
the_events.append(notif_event)
|
||||
|
||||
for event in the_events:
|
||||
messagevars = self.get_message_vars(notif, event, room_state)
|
||||
if messagevars is not None:
|
||||
ret['messages'].append(messagevars)
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
def get_message_vars(self, notif, event, room_state):
|
||||
if event.type != EventTypes.Message:
|
||||
return None
|
||||
|
||||
sender_state_event = room_state[("m.room.member", event.sender)]
|
||||
sender_name = name_from_member_event(sender_state_event)
|
||||
sender_avatar_url = None
|
||||
if "avatar_url" in sender_state_event.content:
|
||||
sender_avatar_url = sender_state_event.content["avatar_url"]
|
||||
|
||||
# 'hash' for deterministically picking default images: use
|
||||
# sender_hash % the number of default images to choose from
|
||||
sender_hash = string_ordinal_total(event.sender)
|
||||
|
||||
ret = {
|
||||
"msgtype": event.content["msgtype"],
|
||||
"is_historical": event.event_id != notif['event_id'],
|
||||
"id": event.event_id,
|
||||
"ts": event.origin_server_ts,
|
||||
"sender_name": sender_name,
|
||||
"sender_avatar_url": sender_avatar_url,
|
||||
"sender_hash": sender_hash,
|
||||
}
|
||||
|
||||
if event.content["msgtype"] == "m.text":
|
||||
self.add_text_message_vars(ret, event)
|
||||
elif event.content["msgtype"] == "m.image":
|
||||
self.add_image_message_vars(ret, event)
|
||||
|
||||
if "body" in event.content:
|
||||
ret["body_text_plain"] = event.content["body"]
|
||||
|
||||
return ret
|
||||
|
||||
def add_text_message_vars(self, messagevars, event):
|
||||
if "format" in event.content:
|
||||
msgformat = event.content["format"]
|
||||
else:
|
||||
msgformat = None
|
||||
messagevars["format"] = msgformat
|
||||
|
||||
if msgformat == "org.matrix.custom.html":
|
||||
messagevars["body_text_html"] = safe_markup(event.content["formatted_body"])
|
||||
else:
|
||||
messagevars["body_text_html"] = safe_text(event.content["body"])
|
||||
|
||||
return messagevars
|
||||
|
||||
def add_image_message_vars(self, messagevars, event):
|
||||
messagevars["image_url"] = event.content["url"]
|
||||
|
||||
return messagevars
|
||||
|
||||
def make_summary_text(self, notifs_by_room, state_by_room,
|
||||
notif_events, user_id, reason):
|
||||
if len(notifs_by_room) == 1:
|
||||
# Only one room has new stuff
|
||||
room_id = notifs_by_room.keys()[0]
|
||||
|
||||
# If the room has some kind of name, use it, but we don't
|
||||
# want the generated-from-names one here otherwise we'll
|
||||
# end up with, "new message from Bob in the Bob room"
|
||||
room_name = calculate_room_name(
|
||||
state_by_room[room_id], user_id, fallback_to_members=False
|
||||
)
|
||||
|
||||
my_member_event = state_by_room[room_id][("m.room.member", user_id)]
|
||||
if my_member_event.content["membership"] == "invite":
|
||||
inviter_member_event = state_by_room[room_id][
|
||||
("m.room.member", my_member_event.sender)
|
||||
]
|
||||
inviter_name = name_from_member_event(inviter_member_event)
|
||||
|
||||
if room_name is None:
|
||||
return INVITE_FROM_PERSON % {
|
||||
"person": inviter_name,
|
||||
"app": self.app_name
|
||||
}
|
||||
else:
|
||||
return INVITE_FROM_PERSON_TO_ROOM % {
|
||||
"person": inviter_name,
|
||||
"room": room_name,
|
||||
"app": self.app_name,
|
||||
}
|
||||
|
||||
sender_name = None
|
||||
if len(notifs_by_room[room_id]) == 1:
|
||||
# There is just the one notification, so give some detail
|
||||
event = notif_events[notifs_by_room[room_id][0]["event_id"]]
|
||||
if ("m.room.member", event.sender) in state_by_room[room_id]:
|
||||
state_event = state_by_room[room_id][("m.room.member", event.sender)]
|
||||
sender_name = name_from_member_event(state_event)
|
||||
|
||||
if sender_name is not None and room_name is not None:
|
||||
return MESSAGE_FROM_PERSON_IN_ROOM % {
|
||||
"person": sender_name,
|
||||
"room": room_name,
|
||||
"app": self.app_name,
|
||||
}
|
||||
elif sender_name is not None:
|
||||
return MESSAGE_FROM_PERSON % {
|
||||
"person": sender_name,
|
||||
"app": self.app_name,
|
||||
}
|
||||
else:
|
||||
# There's more than one notification for this room, so just
|
||||
# say there are several
|
||||
if room_name is not None:
|
||||
return MESSAGES_IN_ROOM % {
|
||||
"room": room_name,
|
||||
"app": self.app_name,
|
||||
}
|
||||
else:
|
||||
# If the room doesn't have a name, say who the messages
|
||||
# are from explicitly to avoid, "messages in the Bob room"
|
||||
sender_ids = list(set([
|
||||
notif_events[n['event_id']].sender
|
||||
for n in notifs_by_room[room_id]
|
||||
]))
|
||||
|
||||
return MESSAGES_FROM_PERSON % {
|
||||
"person": descriptor_from_member_events([
|
||||
state_by_room[room_id][("m.room.member", s)]
|
||||
for s in sender_ids
|
||||
]),
|
||||
"app": self.app_name,
|
||||
}
|
||||
else:
|
||||
# Stuff's happened in multiple different rooms
|
||||
|
||||
# ...but we still refer to the 'reason' room which triggered the mail
|
||||
if reason['room_name'] is not None:
|
||||
return MESSAGES_IN_ROOM_AND_OTHERS % {
|
||||
"room": reason['room_name'],
|
||||
"app": self.app_name,
|
||||
}
|
||||
else:
|
||||
# If the reason room doesn't have a name, say who the messages
|
||||
# are from explicitly to avoid, "messages in the Bob room"
|
||||
sender_ids = list(set([
|
||||
notif_events[n['event_id']].sender
|
||||
for n in notifs_by_room[reason['room_id']]
|
||||
]))
|
||||
|
||||
return MESSAGES_FROM_PERSON_AND_OTHERS % {
|
||||
"person": descriptor_from_member_events([
|
||||
state_by_room[reason['room_id']][("m.room.member", s)]
|
||||
for s in sender_ids
|
||||
]),
|
||||
"app": self.app_name,
|
||||
}
|
||||
|
||||
def make_room_link(self, room_id):
|
||||
# need /beta for Universal Links to work on iOS
|
||||
if self.app_name == "Vector":
|
||||
return "https://vector.im/beta/#/room/%s" % (room_id,)
|
||||
else:
|
||||
return "https://matrix.to/#/%s" % (room_id,)
|
||||
|
||||
def make_notif_link(self, notif):
|
||||
# need /beta for Universal Links to work on iOS
|
||||
if self.app_name == "Vector":
|
||||
return "https://vector.im/beta/#/room/%s/%s" % (
|
||||
notif['room_id'], notif['event_id']
|
||||
)
|
||||
else:
|
||||
return "https://matrix.to/#/%s/%s" % (
|
||||
notif['room_id'], notif['event_id']
|
||||
)
|
||||
|
||||
def make_unsubscribe_link(self, user_id, app_id, email_address):
|
||||
params = {
|
||||
"access_token": self.auth_handler.generate_delete_pusher_token(user_id),
|
||||
"app_id": app_id,
|
||||
"pushkey": email_address,
|
||||
}
|
||||
|
||||
# XXX: make r0 once API is stable
|
||||
return "%s_matrix/client/unstable/pushers/remove?%s" % (
|
||||
self.hs.config.public_baseurl,
|
||||
urllib.urlencode(params),
|
||||
)
|
||||
|
||||
def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
|
||||
if value[0:6] != "mxc://":
|
||||
return ""
|
||||
|
||||
serverAndMediaId = value[6:]
|
||||
fragment = None
|
||||
if '#' in serverAndMediaId:
|
||||
(serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
|
||||
fragment = "#" + fragment
|
||||
|
||||
params = {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"method": resize_method,
|
||||
}
|
||||
return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
|
||||
self.hs.config.public_baseurl,
|
||||
serverAndMediaId,
|
||||
urllib.urlencode(params),
|
||||
fragment or "",
|
||||
)
|
||||
|
||||
|
||||
def safe_markup(raw_html):
|
||||
return jinja2.Markup(bleach.linkify(bleach.clean(
|
||||
raw_html, tags=ALLOWED_TAGS, attributes=ALLOWED_ATTRS,
|
||||
# bleach master has this, but it isn't released yet
|
||||
# protocols=ALLOWED_SCHEMES,
|
||||
strip=True
|
||||
)))
|
||||
|
||||
|
||||
def safe_text(raw_text):
|
||||
"""
|
||||
Process text: treat it as HTML but escape any tags (ie. just escape the
|
||||
HTML) then linkify it.
|
||||
"""
|
||||
return jinja2.Markup(bleach.linkify(bleach.clean(
|
||||
raw_text, tags=[], attributes={},
|
||||
strip=False
|
||||
)))
|
||||
|
||||
|
||||
def deduped_ordered_list(l):
|
||||
seen = set()
|
||||
ret = []
|
||||
for item in l:
|
||||
if item not in seen:
|
||||
seen.add(item)
|
||||
ret.append(item)
|
||||
return ret
|
||||
|
||||
|
||||
def string_ordinal_total(s):
|
||||
tot = 0
|
||||
for c in s:
|
||||
tot += ord(c)
|
||||
return tot
|
||||
|
||||
|
||||
def format_ts_filter(value, format):
|
||||
return time.strftime(format, time.localtime(value / 1000))
|
||||
@@ -38,9 +38,7 @@ def get_badge_count(store, user_id):
|
||||
r.room_id, user_id, last_unread_event_id
|
||||
)
|
||||
)
|
||||
# return one badge count per conversation, as count per
|
||||
# message is so noisy as to be almost useless
|
||||
badge += 1 if notifs["notify_count"] else 0
|
||||
badge += notifs["notify_count"]
|
||||
defer.returnValue(badge)
|
||||
|
||||
|
||||
|
||||
@@ -1,47 +1,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from httppusher import HttpPusher
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# We try importing this if we can (it will fail if we don't
|
||||
# have the optional email dependencies installed). We don't
|
||||
# yet have the config to know if we need the email pusher,
|
||||
# but importing this after daemonizing seems to fail
|
||||
# (even though a simple test of importing from a daemonized
|
||||
# process works fine)
|
||||
try:
|
||||
from synapse.push.emailpusher import EmailPusher
|
||||
except:
|
||||
pass
|
||||
PUSHER_TYPES = {
|
||||
'http': HttpPusher
|
||||
}
|
||||
|
||||
|
||||
def create_pusher(hs, pusherdict):
|
||||
logger.info("trying to create_pusher for %r", pusherdict)
|
||||
|
||||
PUSHER_TYPES = {
|
||||
"http": HttpPusher,
|
||||
}
|
||||
|
||||
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
|
||||
if hs.config.email_enable_notifs:
|
||||
PUSHER_TYPES["email"] = EmailPusher
|
||||
logger.info("defined email pusher type")
|
||||
|
||||
if pusherdict['kind'] in PUSHER_TYPES:
|
||||
logger.info("found pusher")
|
||||
return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict)
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
from twisted.internet import defer
|
||||
|
||||
import pusher
|
||||
from synapse.push import PusherConfigException
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.util.async import run_on_reactor
|
||||
|
||||
@@ -49,7 +50,6 @@ class PusherPool:
|
||||
# recreated, added and started: this means we have only one
|
||||
# code path adding pushers.
|
||||
pusher.create_pusher(self.hs, {
|
||||
"id": None,
|
||||
"user_name": user_id,
|
||||
"kind": kind,
|
||||
"app_id": app_id,
|
||||
@@ -185,8 +185,8 @@ class PusherPool:
|
||||
for pusherdict in pushers:
|
||||
try:
|
||||
p = pusher.create_pusher(self.hs, pusherdict)
|
||||
except:
|
||||
logger.exception("Couldn't start a pusher: caught Exception")
|
||||
except PusherConfigException:
|
||||
logger.exception("Couldn't start a pusher: caught PusherConfigException")
|
||||
continue
|
||||
if p:
|
||||
appid_pushkey = "%s:%s" % (
|
||||
|
||||
@@ -36,6 +36,7 @@ REQUIREMENTS = {
|
||||
"blist": ["blist"],
|
||||
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
|
||||
"pymacaroons-pynacl": ["pymacaroons"],
|
||||
"pyjwt": ["jwt"],
|
||||
}
|
||||
CONDITIONAL_REQUIREMENTS = {
|
||||
"web_client": {
|
||||
@@ -44,10 +45,6 @@ CONDITIONAL_REQUIREMENTS = {
|
||||
"preview_url": {
|
||||
"netaddr>=0.7.18": ["netaddr"],
|
||||
},
|
||||
"email.enable_notifs": {
|
||||
"Jinja2>=2.8": ["Jinja2>=2.8"],
|
||||
"bleach>=1.4.2": ["bleach>=1.4.2"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.http.server import respond_with_json_bytes, request_handler
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
|
||||
|
||||
class PresenceResource(Resource):
|
||||
"""
|
||||
HTTP endpoint for marking users as syncing.
|
||||
|
||||
POST /_synapse/replication/presence HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"process_id": "<process_id>",
|
||||
"syncing_users": ["<user_id>"]
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
Resource.__init__(self) # Resource is old-style, so no super()
|
||||
|
||||
self.version_string = hs.version_string
|
||||
self.clock = hs.get_clock()
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
|
||||
def render_POST(self, request):
|
||||
self._async_render_POST(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@request_handler()
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_POST(self, request):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
process_id = content["process_id"]
|
||||
syncing_user_ids = content["syncing_users"]
|
||||
|
||||
yield self.presence_handler.update_external_syncs(
|
||||
process_id, set(syncing_user_ids)
|
||||
)
|
||||
|
||||
respond_with_json_bytes(request, 200, "{}")
|
||||
@@ -31,13 +31,12 @@ class PusherResource(Resource):
|
||||
self.version_string = hs.version_string
|
||||
self.store = hs.get_datastore()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
def render_POST(self, request):
|
||||
self._async_render_POST(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@request_handler()
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_POST(self, request):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
from synapse.http.servlet import parse_integer, parse_string
|
||||
from synapse.http.server import request_handler, finish_request
|
||||
from synapse.replication.pusher_resource import PusherResource
|
||||
from synapse.replication.presence_resource import PresenceResource
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
@@ -110,13 +109,11 @@ class ReplicationResource(Resource):
|
||||
self.version_string = hs.version_string
|
||||
self.store = hs.get_datastore()
|
||||
self.sources = hs.get_event_sources()
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.typing_handler = hs.get_typing_handler()
|
||||
self.presence_handler = hs.get_handlers().presence_handler
|
||||
self.typing_handler = hs.get_handlers().typing_notification_handler
|
||||
self.notifier = hs.notifier
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
self.putChild("remove_pushers", PusherResource(hs))
|
||||
self.putChild("syncing_users", PresenceResource(hs))
|
||||
|
||||
def render_GET(self, request):
|
||||
self._async_render_GET(request)
|
||||
@@ -142,7 +139,7 @@ class ReplicationResource(Resource):
|
||||
state_token,
|
||||
))
|
||||
|
||||
@request_handler()
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
limit = parse_integer(request, "limit", 100)
|
||||
@@ -161,15 +158,6 @@ class ReplicationResource(Resource):
|
||||
|
||||
result = yield self.notifier.wait_for_replication(replicate, timeout)
|
||||
|
||||
for stream_name, stream_content in result.items():
|
||||
logger.info(
|
||||
"Replicating %d rows of %s from %s -> %s",
|
||||
len(stream_content["rows"]),
|
||||
stream_name,
|
||||
request_streams.get(stream_name),
|
||||
stream_content["position"],
|
||||
)
|
||||
|
||||
request.write(json.dumps(result, ensure_ascii=False))
|
||||
finish_request(request)
|
||||
|
||||
@@ -394,7 +382,7 @@ class _Writer(object):
|
||||
position = rows[-1][0]
|
||||
|
||||
self.streams[name] = {
|
||||
"position": position if type(position) is int else str(position),
|
||||
"position": str(position),
|
||||
"field_names": fields,
|
||||
"rows": rows,
|
||||
}
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.account_data import AccountDataStore
|
||||
from synapse.storage.tags import TagsStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
|
||||
class SlavedAccountDataStore(BaseSlavedStore):
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedAccountDataStore, self).__init__(db_conn, hs)
|
||||
self._account_data_id_gen = SlavedIdTracker(
|
||||
db_conn, "account_data_max_stream_id", "stream_id",
|
||||
)
|
||||
self._account_data_stream_cache = StreamChangeCache(
|
||||
"AccountDataAndTagsChangeCache",
|
||||
self._account_data_id_gen.get_current_token(),
|
||||
)
|
||||
|
||||
get_account_data_for_user = (
|
||||
AccountDataStore.__dict__["get_account_data_for_user"]
|
||||
)
|
||||
|
||||
get_global_account_data_by_type_for_users = (
|
||||
AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
|
||||
)
|
||||
|
||||
get_global_account_data_by_type_for_user = (
|
||||
AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
|
||||
)
|
||||
|
||||
get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
|
||||
|
||||
get_updated_tags = DataStore.get_updated_tags.__func__
|
||||
get_updated_account_data_for_user = (
|
||||
DataStore.get_updated_account_data_for_user.__func__
|
||||
)
|
||||
|
||||
def get_max_account_data_stream_id(self):
|
||||
return self._account_data_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedAccountDataStore, self).stream_positions()
|
||||
position = self._account_data_id_gen.get_current_token()
|
||||
result["user_account_data"] = position
|
||||
result["room_account_data"] = position
|
||||
result["tag_account_data"] = position
|
||||
return result
|
||||
|
||||
def process_replication(self, result):
|
||||
stream = result.get("user_account_data")
|
||||
if stream:
|
||||
self._account_data_id_gen.advance(int(stream["position"]))
|
||||
for row in stream["rows"]:
|
||||
position, user_id, data_type = row[:3]
|
||||
self.get_global_account_data_by_type_for_user.invalidate(
|
||||
(data_type, user_id,)
|
||||
)
|
||||
self.get_account_data_for_user.invalidate((user_id,))
|
||||
self._account_data_stream_cache.entity_has_changed(
|
||||
user_id, position
|
||||
)
|
||||
|
||||
stream = result.get("room_account_data")
|
||||
if stream:
|
||||
self._account_data_id_gen.advance(int(stream["position"]))
|
||||
for row in stream["rows"]:
|
||||
position, user_id = row[:2]
|
||||
self.get_account_data_for_user.invalidate((user_id,))
|
||||
self._account_data_stream_cache.entity_has_changed(
|
||||
user_id, position
|
||||
)
|
||||
|
||||
stream = result.get("tag_account_data")
|
||||
if stream:
|
||||
self._account_data_id_gen.advance(int(stream["position"]))
|
||||
for row in stream["rows"]:
|
||||
position, user_id = row[:2]
|
||||
self.get_tags_for_user.invalidate((user_id,))
|
||||
self._account_data_stream_cache.entity_has_changed(
|
||||
user_id, position
|
||||
)
|
||||
|
||||
return super(SlavedAccountDataStore, self).process_replication(result)
|
||||
@@ -1,30 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from synapse.storage import DataStore
|
||||
from synapse.config.appservice import load_appservices
|
||||
|
||||
|
||||
class SlavedApplicationServiceStore(BaseSlavedStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedApplicationServiceStore, self).__init__(db_conn, hs)
|
||||
self.services_cache = load_appservices(
|
||||
hs.config.server_name,
|
||||
hs.config.app_service_config_files
|
||||
)
|
||||
|
||||
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
|
||||
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
|
||||
@@ -23,7 +23,6 @@ from synapse.storage.roommember import RoomMemberStore
|
||||
from synapse.storage.event_federation import EventFederationStore
|
||||
from synapse.storage.event_push_actions import EventPushActionsStore
|
||||
from synapse.storage.state import StateStore
|
||||
from synapse.storage.stream import StreamStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
import ujson as json
|
||||
@@ -58,9 +57,6 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
"EventsRoomStreamChangeCache", min_event_val,
|
||||
prefilled_cache=event_cache_prefill,
|
||||
)
|
||||
self._membership_stream_cache = StreamChangeCache(
|
||||
"MembershipStreamChangeCache", events_max,
|
||||
)
|
||||
|
||||
# Cached functions can't be accessed through a class instance so we need
|
||||
# to reach inside the __dict__ to extract them.
|
||||
@@ -79,21 +75,6 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
get_unread_event_push_actions_by_room_for_user = (
|
||||
EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
|
||||
)
|
||||
_get_state_group_for_events = (
|
||||
StateStore.__dict__["_get_state_group_for_events"]
|
||||
)
|
||||
_get_state_group_for_event = (
|
||||
StateStore.__dict__["_get_state_group_for_event"]
|
||||
)
|
||||
_get_state_groups_from_groups = (
|
||||
StateStore.__dict__["_get_state_groups_from_groups"]
|
||||
)
|
||||
_get_state_group_from_group = (
|
||||
StateStore.__dict__["_get_state_group_from_group"]
|
||||
)
|
||||
get_recent_event_ids_for_room = (
|
||||
StreamStore.__dict__["get_recent_event_ids_for_room"]
|
||||
)
|
||||
|
||||
get_unread_push_actions_for_user_in_range = (
|
||||
DataStore.get_unread_push_actions_for_user_in_range.__func__
|
||||
@@ -102,7 +83,6 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
DataStore.get_push_action_users_in_range.__func__
|
||||
)
|
||||
get_event = DataStore.get_event.__func__
|
||||
get_events = DataStore.get_events.__func__
|
||||
get_current_state = DataStore.get_current_state.__func__
|
||||
get_current_state_for_key = DataStore.get_current_state_for_key.__func__
|
||||
get_rooms_for_user_where_membership_is = (
|
||||
@@ -115,17 +95,8 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
get_room_events_stream_for_room = (
|
||||
DataStore.get_room_events_stream_for_room.__func__
|
||||
)
|
||||
get_events_around = DataStore.get_events_around.__func__
|
||||
get_state_for_event = DataStore.get_state_for_event.__func__
|
||||
get_state_for_events = DataStore.get_state_for_events.__func__
|
||||
get_state_groups = DataStore.get_state_groups.__func__
|
||||
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
|
||||
get_room_events_stream_for_rooms = (
|
||||
DataStore.get_room_events_stream_for_rooms.__func__
|
||||
)
|
||||
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
|
||||
|
||||
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
|
||||
_set_before_and_after = DataStore._set_before_and_after
|
||||
|
||||
_get_events = DataStore._get_events.__func__
|
||||
_get_events_from_cache = DataStore._get_events_from_cache.__func__
|
||||
@@ -133,7 +104,6 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
_invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
|
||||
_parse_events_txn = DataStore._parse_events_txn.__func__
|
||||
_get_events_txn = DataStore._get_events_txn.__func__
|
||||
_get_event_txn = DataStore._get_event_txn.__func__
|
||||
_enqueue_events = DataStore._enqueue_events.__func__
|
||||
_do_fetch = DataStore._do_fetch.__func__
|
||||
_fetch_events_txn = DataStore._fetch_events_txn.__func__
|
||||
@@ -144,15 +114,11 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
|
||||
)
|
||||
_get_members_rows_txn = DataStore._get_members_rows_txn.__func__
|
||||
_get_state_for_groups = DataStore._get_state_for_groups.__func__
|
||||
_get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
|
||||
_get_events_around_txn = DataStore._get_events_around_txn.__func__
|
||||
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedEventStore, self).stream_positions()
|
||||
result["events"] = self._stream_id_gen.get_current_token()
|
||||
result["backfill"] = -self._backfill_id_gen.get_current_token()
|
||||
result["backfill"] = self._backfill_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def process_replication(self, result):
|
||||
@@ -162,7 +128,7 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
|
||||
stream = result.get("events")
|
||||
if stream:
|
||||
self._stream_id_gen.advance(int(stream["position"]))
|
||||
self._stream_id_gen.advance(stream["position"])
|
||||
for row in stream["rows"]:
|
||||
self._process_replication_row(
|
||||
row, backfilled=False, state_resets=state_resets
|
||||
@@ -170,7 +136,7 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
|
||||
stream = result.get("backfill")
|
||||
if stream:
|
||||
self._backfill_id_gen.advance(-int(stream["position"]))
|
||||
self._backfill_id_gen.advance(stream["position"])
|
||||
for row in stream["rows"]:
|
||||
self._process_replication_row(
|
||||
row, backfilled=True, state_resets=state_resets
|
||||
@@ -178,14 +144,12 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
|
||||
stream = result.get("forward_ex_outliers")
|
||||
if stream:
|
||||
self._stream_id_gen.advance(int(stream["position"]))
|
||||
for row in stream["rows"]:
|
||||
event_id = row[1]
|
||||
self._invalidate_get_event_cache(event_id)
|
||||
|
||||
stream = result.get("backward_ex_outliers")
|
||||
if stream:
|
||||
self._backfill_id_gen.advance(-int(stream["position"]))
|
||||
for row in stream["rows"]:
|
||||
event_id = row[1]
|
||||
self._invalidate_get_event_cache(event_id)
|
||||
@@ -233,9 +197,9 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
self.get_rooms_for_user.invalidate((event.state_key,))
|
||||
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
|
||||
self.get_users_in_room.invalidate((event.room_id,))
|
||||
self._membership_stream_cache.entity_has_changed(
|
||||
event.state_key, event.internal_metadata.stream_ordering
|
||||
)
|
||||
# self._membership_stream_cache.entity_has_changed(
|
||||
# event.state_key, event.internal_metadata.stream_ordering
|
||||
# )
|
||||
self.get_invited_rooms_for_user.invalidate((event.state_key,))
|
||||
|
||||
if not event.is_state():
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from synapse.storage.filtering import FilteringStore
|
||||
|
||||
|
||||
class SlavedFilteringStore(BaseSlavedStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedFilteringStore, self).__init__(db_conn, hs)
|
||||
|
||||
get_user_filter = FilteringStore.__dict__["get_user_filter"]
|
||||
@@ -1,59 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
from synapse.storage import DataStore
|
||||
|
||||
|
||||
class SlavedPresenceStore(BaseSlavedStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedPresenceStore, self).__init__(db_conn, hs)
|
||||
self._presence_id_gen = SlavedIdTracker(
|
||||
db_conn, "presence_stream", "stream_id",
|
||||
)
|
||||
|
||||
self._presence_on_startup = self._get_active_presence(db_conn)
|
||||
|
||||
self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache(
|
||||
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
|
||||
)
|
||||
|
||||
_get_active_presence = DataStore._get_active_presence.__func__
|
||||
take_presence_startup_info = DataStore.take_presence_startup_info.__func__
|
||||
get_presence_for_users = DataStore.get_presence_for_users.__func__
|
||||
|
||||
def get_current_presence_token(self):
|
||||
return self._presence_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedPresenceStore, self).stream_positions()
|
||||
position = self._presence_id_gen.get_current_token()
|
||||
result["presence"] = position
|
||||
return result
|
||||
|
||||
def process_replication(self, result):
|
||||
stream = result.get("presence")
|
||||
if stream:
|
||||
self._presence_id_gen.advance(int(stream["position"]))
|
||||
for row in stream["rows"]:
|
||||
position, user_id = row[:2]
|
||||
self.presence_stream_cache.entity_has_changed(
|
||||
user_id, position
|
||||
)
|
||||
|
||||
return super(SlavedPresenceStore, self).process_replication(result)
|
||||
@@ -1,67 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .events import SlavedEventStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.push_rule import PushRuleStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
|
||||
class SlavedPushRuleStore(SlavedEventStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
|
||||
self._push_rules_stream_id_gen = SlavedIdTracker(
|
||||
db_conn, "push_rules_stream", "stream_id",
|
||||
)
|
||||
self.push_rules_stream_cache = StreamChangeCache(
|
||||
"PushRulesStreamChangeCache",
|
||||
self._push_rules_stream_id_gen.get_current_token(),
|
||||
)
|
||||
|
||||
get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
|
||||
get_push_rules_enabled_for_user = (
|
||||
PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
|
||||
)
|
||||
have_push_rules_changed_for_user = (
|
||||
DataStore.have_push_rules_changed_for_user.__func__
|
||||
)
|
||||
|
||||
def get_push_rules_stream_token(self):
|
||||
return (
|
||||
self._push_rules_stream_id_gen.get_current_token(),
|
||||
self._stream_id_gen.get_current_token(),
|
||||
)
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedPushRuleStore, self).stream_positions()
|
||||
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def process_replication(self, result):
|
||||
stream = result.get("push_rules")
|
||||
if stream:
|
||||
for row in stream["rows"]:
|
||||
position = row[0]
|
||||
user_id = row[2]
|
||||
self.get_push_rules_for_user.invalidate((user_id,))
|
||||
self.get_push_rules_enabled_for_user.invalidate((user_id,))
|
||||
self.push_rules_stream_cache.entity_has_changed(
|
||||
user_id, position
|
||||
)
|
||||
|
||||
self._push_rules_stream_id_gen.advance(int(stream["position"]))
|
||||
|
||||
return super(SlavedPushRuleStore, self).process_replication(result)
|
||||
@@ -43,10 +43,10 @@ class SlavedPusherStore(BaseSlavedStore):
|
||||
def process_replication(self, result):
|
||||
stream = result.get("pushers")
|
||||
if stream:
|
||||
self._pushers_id_gen.advance(int(stream["position"]))
|
||||
self._pushers_id_gen.advance(stream["position"])
|
||||
|
||||
stream = result.get("deleted_pushers")
|
||||
if stream:
|
||||
self._pushers_id_gen.advance(int(stream["position"]))
|
||||
self._pushers_id_gen.advance(stream["position"])
|
||||
|
||||
return super(SlavedPusherStore, self).process_replication(result)
|
||||
|
||||
@@ -18,7 +18,6 @@ from ._slaved_id_tracker import SlavedIdTracker
|
||||
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.receipts import ReceiptsStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
# So, um, we want to borrow a load of functions intended for reading from
|
||||
# a DataStore, but we don't want to take functions that either write to the
|
||||
@@ -38,28 +37,11 @@ class SlavedReceiptsStore(BaseSlavedStore):
|
||||
db_conn, "receipts_linearized", "stream_id"
|
||||
)
|
||||
|
||||
self._receipts_stream_cache = StreamChangeCache(
|
||||
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
|
||||
)
|
||||
|
||||
get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
|
||||
get_linearized_receipts_for_room = (
|
||||
ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
|
||||
)
|
||||
_get_linearized_receipts_for_rooms = (
|
||||
ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
|
||||
)
|
||||
get_last_receipt_event_id_for_user = (
|
||||
ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
|
||||
)
|
||||
|
||||
get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
|
||||
get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
|
||||
|
||||
get_linearized_receipts_for_rooms = (
|
||||
DataStore.get_linearized_receipts_for_rooms.__func__
|
||||
)
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedReceiptsStore, self).stream_positions()
|
||||
result["receipts"] = self._receipts_id_gen.get_current_token()
|
||||
@@ -68,17 +50,12 @@ class SlavedReceiptsStore(BaseSlavedStore):
|
||||
def process_replication(self, result):
|
||||
stream = result.get("receipts")
|
||||
if stream:
|
||||
self._receipts_id_gen.advance(int(stream["position"]))
|
||||
self._receipts_id_gen.advance(stream["position"])
|
||||
for row in stream["rows"]:
|
||||
position, room_id, receipt_type, user_id = row[:4]
|
||||
room_id, receipt_type, user_id = row[1:4]
|
||||
self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
|
||||
self._receipts_stream_cache.entity_has_changed(room_id, position)
|
||||
|
||||
return super(SlavedReceiptsStore, self).process_replication(result)
|
||||
|
||||
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
|
||||
self.get_receipts_for_user.invalidate((user_id, receipt_type))
|
||||
self.get_linearized_receipts_for_room.invalidate_many((room_id,))
|
||||
self.get_last_receipt_event_id_for_user.invalidate(
|
||||
(user_id, room_id, receipt_type)
|
||||
)
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.registration import RegistrationStore
|
||||
|
||||
|
||||
class SlavedRegistrationStore(BaseSlavedStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedRegistrationStore, self).__init__(db_conn, hs)
|
||||
|
||||
# TODO: use the cached version and invalidate deleted tokens
|
||||
get_user_by_access_token = RegistrationStore.__dict__[
|
||||
"get_user_by_access_token"
|
||||
].orig
|
||||
|
||||
_query_for_auth = DataStore._query_for_auth.__func__
|
||||
@@ -44,8 +44,6 @@ from synapse.rest.client.v2_alpha import (
|
||||
tokenrefresh,
|
||||
tags,
|
||||
account_data,
|
||||
report_event,
|
||||
openid,
|
||||
)
|
||||
|
||||
from synapse.http.server import JsonResource
|
||||
@@ -88,5 +86,3 @@ class ClientRestResource(JsonResource):
|
||||
tokenrefresh.register_servlets(hs, client_resource)
|
||||
tags.register_servlets(hs, client_resource)
|
||||
account_data.register_servlets(hs, client_resource)
|
||||
report_event.register_servlets(hs, client_resource)
|
||||
openid.register_servlets(hs, client_resource)
|
||||
|
||||
@@ -33,6 +33,9 @@ from saml2.client import Saml2Client
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
import jwt
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -58,7 +61,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||
self.servername = hs.config.server_name
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
self.auth_handler = self.hs.get_auth_handler()
|
||||
|
||||
def on_GET(self, request):
|
||||
flows = []
|
||||
@@ -144,7 +146,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||
user_id, self.hs.hostname
|
||||
).to_string()
|
||||
|
||||
auth_handler = self.auth_handler
|
||||
auth_handler = self.handlers.auth_handler
|
||||
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
|
||||
user_id=user_id,
|
||||
password=login_submission["password"])
|
||||
@@ -161,7 +163,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||
@defer.inlineCallbacks
|
||||
def do_token_login(self, login_submission):
|
||||
token = login_submission['token']
|
||||
auth_handler = self.auth_handler
|
||||
auth_handler = self.handlers.auth_handler
|
||||
user_id = (
|
||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||
)
|
||||
@@ -195,7 +197,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
auth_handler = self.auth_handler
|
||||
auth_handler = self.handlers.auth_handler
|
||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||
if user_exists:
|
||||
user_id, access_token, refresh_token = (
|
||||
@@ -222,29 +224,21 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_jwt_login(self, login_submission):
|
||||
token = login_submission.get("token", None)
|
||||
token = login_submission['token']
|
||||
if token is None:
|
||||
raise LoginError(
|
||||
401, "Token field for JWT is missing",
|
||||
errcode=Codes.UNAUTHORIZED
|
||||
)
|
||||
|
||||
import jwt
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
|
||||
except InvalidTokenError:
|
||||
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
user = payload.get("sub", None)
|
||||
user = payload['user']
|
||||
if user is None:
|
||||
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
auth_handler = self.auth_handler
|
||||
auth_handler = self.handlers.auth_handler
|
||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||
if user_exists:
|
||||
user_id, access_token, refresh_token = (
|
||||
@@ -413,7 +407,7 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
auth_handler = self.auth_handler
|
||||
auth_handler = self.handlers.auth_handler
|
||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||
if not user_exists:
|
||||
user_id, _ = (
|
||||
|
||||
@@ -30,24 +30,20 @@ logger = logging.getLogger(__name__)
|
||||
class PresenceStatusRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PresenceStatusRestServlet, self).__init__(hs)
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
if requester.user != user:
|
||||
allowed = yield self.presence_handler.is_visible(
|
||||
allowed = yield self.handlers.presence_handler.is_visible(
|
||||
observed_user=user, observer_user=requester.user,
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise AuthError(403, "You are not allowed to see their presence.")
|
||||
|
||||
state = yield self.presence_handler.get_state(target_user=user)
|
||||
state = yield self.handlers.presence_handler.get_state(target_user=user)
|
||||
|
||||
defer.returnValue((200, state))
|
||||
|
||||
@@ -78,7 +74,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
||||
except:
|
||||
raise SynapseError(400, "Unable to parse state")
|
||||
|
||||
yield self.presence_handler.set_state(user, state)
|
||||
yield self.handlers.presence_handler.set_state(user, state)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@@ -89,10 +85,6 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
||||
class PresenceListRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/presence/list/(?P<user_id>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PresenceListRestServlet, self).__init__(hs)
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
@@ -104,7 +96,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
|
||||
if requester.user != user:
|
||||
raise SynapseError(400, "Cannot get another user's presence list")
|
||||
|
||||
presence = yield self.presence_handler.get_presence_list(
|
||||
presence = yield self.handlers.presence_handler.get_presence_list(
|
||||
observer_user=user, accepted=True
|
||||
)
|
||||
|
||||
@@ -131,7 +123,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
|
||||
if len(u) == 0:
|
||||
continue
|
||||
invited_user = UserID.from_string(u)
|
||||
yield self.presence_handler.send_presence_invite(
|
||||
yield self.handlers.presence_handler.send_presence_invite(
|
||||
observer_user=user, observed_user=invited_user
|
||||
)
|
||||
|
||||
@@ -142,7 +134,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
|
||||
if len(u) == 0:
|
||||
continue
|
||||
dropped_user = UserID.from_string(u)
|
||||
yield self.presence_handler.drop(
|
||||
yield self.handlers.presence_handler.drop(
|
||||
observer_user=user, observed_user=dropped_user
|
||||
)
|
||||
|
||||
|
||||
@@ -17,11 +17,7 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
from synapse.push import PusherConfigException
|
||||
from synapse.http.servlet import (
|
||||
parse_json_object_from_request, parse_string, RestServlet
|
||||
)
|
||||
from synapse.http.server import finish_request
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
|
||||
from .base import ClientV1RestServlet, client_path_patterns
|
||||
|
||||
@@ -140,57 +136,6 @@ class PushersSetRestServlet(ClientV1RestServlet):
|
||||
return 200, {}
|
||||
|
||||
|
||||
class PushersRemoveRestServlet(RestServlet):
|
||||
"""
|
||||
To allow pusher to be delete by clicking a link (ie. GET request)
|
||||
"""
|
||||
PATTERNS = client_path_patterns("/pushers/remove$")
|
||||
SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.notifier = hs.get_notifier()
|
||||
self.auth = hs.get_v1auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
|
||||
user = requester.user
|
||||
|
||||
app_id = parse_string(request, "app_id", required=True)
|
||||
pushkey = parse_string(request, "pushkey", required=True)
|
||||
|
||||
pusher_pool = self.hs.get_pusherpool()
|
||||
|
||||
try:
|
||||
yield pusher_pool.remove_pusher(
|
||||
app_id=app_id,
|
||||
pushkey=pushkey,
|
||||
user_id=user.to_string(),
|
||||
)
|
||||
except StoreError as se:
|
||||
if se.code != 404:
|
||||
# This is fine: they're already unsubscribed
|
||||
raise
|
||||
|
||||
self.notifier.on_new_replication_data()
|
||||
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Server", self.hs.version_string)
|
||||
request.setHeader(b"Content-Length", b"%d" % (
|
||||
len(PushersRemoveRestServlet.SUCCESS_HTML),
|
||||
))
|
||||
request.write(PushersRemoveRestServlet.SUCCESS_HTML)
|
||||
finish_request(request)
|
||||
defer.returnValue(None)
|
||||
|
||||
def on_OPTIONS(self, _):
|
||||
return 200, {}
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
PushersRestServlet(hs).register(http_server)
|
||||
PushersSetRestServlet(hs).register(http_server)
|
||||
PushersRemoveRestServlet(hs).register(http_server)
|
||||
|
||||
@@ -355,76 +355,5 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||
)
|
||||
|
||||
|
||||
class CreateUserRestServlet(ClientV1RestServlet):
|
||||
"""Handles user creation via a server-to-server interface
|
||||
"""
|
||||
|
||||
PATTERNS = client_path_patterns("/createUser$", releases=())
|
||||
|
||||
def __init__(self, hs):
|
||||
super(CreateUserRestServlet, self).__init__(hs)
|
||||
self.store = hs.get_datastore()
|
||||
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
user_json = parse_json_object_from_request(request)
|
||||
|
||||
if "access_token" not in request.args:
|
||||
raise SynapseError(400, "Expected application service token.")
|
||||
|
||||
app_service = yield self.store.get_app_service_by_token(
|
||||
request.args["access_token"][0]
|
||||
)
|
||||
if not app_service:
|
||||
raise SynapseError(403, "Invalid application service token.")
|
||||
|
||||
logger.debug("creating user: %s", user_json)
|
||||
|
||||
response = yield self._do_create(user_json)
|
||||
|
||||
defer.returnValue((200, response))
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return 403, {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_create(self, user_json):
|
||||
yield run_on_reactor()
|
||||
|
||||
if "localpart" not in user_json:
|
||||
raise SynapseError(400, "Expected 'localpart' key.")
|
||||
|
||||
if "displayname" not in user_json:
|
||||
raise SynapseError(400, "Expected 'displayname' key.")
|
||||
|
||||
if "duration_seconds" not in user_json:
|
||||
raise SynapseError(400, "Expected 'duration_seconds' key.")
|
||||
|
||||
localpart = user_json["localpart"].encode("utf-8")
|
||||
displayname = user_json["displayname"].encode("utf-8")
|
||||
duration_seconds = 0
|
||||
try:
|
||||
duration_seconds = int(user_json["duration_seconds"])
|
||||
except ValueError:
|
||||
raise SynapseError(400, "Failed to parse 'duration_seconds'")
|
||||
if duration_seconds > self.direct_user_creation_max_duration:
|
||||
duration_seconds = self.direct_user_creation_max_duration
|
||||
|
||||
handler = self.handlers.registration_handler
|
||||
user_id, token = yield handler.get_or_create_user(
|
||||
localpart=localpart,
|
||||
displayname=displayname,
|
||||
duration_seconds=duration_seconds
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname,
|
||||
})
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
RegisterRestServlet(hs).register(http_server)
|
||||
CreateUserRestServlet(hs).register(http_server)
|
||||
|
||||
@@ -232,10 +232,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
|
||||
|
||||
if RoomID.is_valid(room_identifier):
|
||||
room_id = room_identifier
|
||||
try:
|
||||
remote_room_hosts = request.args["server_name"]
|
||||
except:
|
||||
remote_room_hosts = None
|
||||
remote_room_hosts = None
|
||||
elif RoomAlias.is_valid(room_identifier):
|
||||
handler = self.handlers.room_member_handler
|
||||
room_alias = RoomAlias.from_string(room_identifier)
|
||||
@@ -279,9 +276,8 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
handler = self.hs.get_room_list_handler()
|
||||
data = yield handler.get_aggregated_public_room_list()
|
||||
|
||||
handler = self.handlers.room_list_handler
|
||||
data = yield handler.get_public_room_list()
|
||||
defer.returnValue((200, data))
|
||||
|
||||
|
||||
@@ -574,8 +570,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomTypingRestServlet, self).__init__(hs)
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.typing_handler = hs.get_typing_handler()
|
||||
self.presence_handler = hs.get_handlers().presence_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, user_id):
|
||||
@@ -586,17 +581,19 @@ class RoomTypingRestServlet(ClientV1RestServlet):
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
typing_handler = self.handlers.typing_notification_handler
|
||||
|
||||
yield self.presence_handler.bump_presence_active_time(requester.user)
|
||||
|
||||
if content["typing"]:
|
||||
yield self.typing_handler.started_typing(
|
||||
yield typing_handler.started_typing(
|
||||
target_user=target_user,
|
||||
auth_user=requester.user,
|
||||
room_id=room_id,
|
||||
timeout=content.get("timeout", 30000),
|
||||
)
|
||||
else:
|
||||
yield self.typing_handler.stopped_typing(
|
||||
yield typing_handler.stopped_typing(
|
||||
target_user=target_user,
|
||||
auth_user=requester.user,
|
||||
room_id=room_id,
|
||||
|
||||
@@ -35,7 +35,7 @@ class PasswordRestServlet(RestServlet):
|
||||
super(PasswordRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.auth_handler = hs.get_handlers().auth_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
@@ -52,7 +52,6 @@ class PasswordRestServlet(RestServlet):
|
||||
defer.returnValue((401, result))
|
||||
|
||||
user_id = None
|
||||
requester = None
|
||||
|
||||
if LoginType.PASSWORD in result:
|
||||
# if using password, they should also be logged in
|
||||
@@ -97,7 +96,7 @@ class ThreepidRestServlet(RestServlet):
|
||||
self.hs = hs
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.auth_handler = hs.get_handlers().auth_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
|
||||
@@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet):
|
||||
super(AuthRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.auth_handler = hs.get_handlers().auth_handler
|
||||
self.registration_handler = hs.get_handlers().registration_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IdTokenServlet(RestServlet):
|
||||
"""
|
||||
Get a bearer token that may be passed to a third party to confirm ownership
|
||||
of a matrix user id.
|
||||
|
||||
The format of the response could be made compatible with the format given
|
||||
in http://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
|
||||
|
||||
But instead of returning a signed "id_token" the response contains the
|
||||
name of the issuing matrix homeserver. This means that for now the third
|
||||
party will need to check the validity of the "id_token" against the
|
||||
federation /openid/userinfo endpoint of the homeserver.
|
||||
|
||||
Request:
|
||||
|
||||
POST /user/{user_id}/openid/request_token?access_token=... HTTP/1.1
|
||||
|
||||
{}
|
||||
|
||||
Response:
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
{
|
||||
"access_token": "ABDEFGH",
|
||||
"token_type": "Bearer",
|
||||
"matrix_server_name": "example.com",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/user/(?P<user_id>[^/]*)/openid/request_token"
|
||||
)
|
||||
|
||||
EXPIRES_MS = 3600 * 1000
|
||||
|
||||
def __init__(self, hs):
|
||||
super(IdTokenServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
self.server_name = hs.config.server_name
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, user_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
if user_id != requester.user.to_string():
|
||||
raise AuthError(403, "Cannot request tokens for other users.")
|
||||
|
||||
# Parse the request body to make sure it's JSON, but ignore the contents
|
||||
# for now.
|
||||
parse_json_object_from_request(request)
|
||||
|
||||
token = random_string(24)
|
||||
ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS
|
||||
|
||||
yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
|
||||
|
||||
defer.returnValue((200, {
|
||||
"access_token": token,
|
||||
"token_type": "Bearer",
|
||||
"matrix_server_name": self.server_name,
|
||||
"expires_in": self.EXPIRES_MS / 1000,
|
||||
}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
IdTokenServlet(hs).register(http_server)
|
||||
@@ -37,7 +37,7 @@ class ReceiptRestServlet(RestServlet):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.receipts_handler = hs.get_handlers().receipts_handler
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.presence_handler = hs.get_handlers().presence_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, receipt_type, event_id):
|
||||
|
||||
@@ -48,8 +48,7 @@ class RegisterRestServlet(RestServlet):
|
||||
super(RegisterRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.auth_handler = hs.get_handlers().auth_handler
|
||||
self.registration_handler = hs.get_handlers().registration_handler
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
|
||||
@@ -215,34 +214,6 @@ class RegisterRestServlet(RestServlet):
|
||||
threepid['validated_at'],
|
||||
)
|
||||
|
||||
# And we add an email pusher for them by default, but only
|
||||
# if email notifications are enabled (so people don't start
|
||||
# getting mail spam where they weren't before if email
|
||||
# notifs are set up on a home server)
|
||||
if (
|
||||
self.hs.config.email_enable_notifs and
|
||||
self.hs.config.email_notif_for_new_users
|
||||
):
|
||||
# Pull the ID of the access token back out of the db
|
||||
# It would really make more sense for this to be passed
|
||||
# up when the access token is saved, but that's quite an
|
||||
# invasive change I'd rather do separately.
|
||||
user_tuple = yield self.store.get_user_by_access_token(
|
||||
token
|
||||
)
|
||||
|
||||
yield self.hs.get_pusherpool().add_pusher(
|
||||
user_id=user_id,
|
||||
access_token=user_tuple["token_id"],
|
||||
kind="email",
|
||||
app_id="m.email",
|
||||
app_display_name="Email Notifications",
|
||||
device_display_name=threepid["address"],
|
||||
pushkey=threepid["address"],
|
||||
lang=None, # We don't know a user's language here
|
||||
data={},
|
||||
)
|
||||
|
||||
if 'bind_email' in params and params['bind_email']:
|
||||
logger.info("bind_email specified: binding")
|
||||
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReportEventRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ReportEventRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, event_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
yield self.store.add_event_report(
|
||||
room_id=room_id,
|
||||
event_id=event_id,
|
||||
user_id=user_id,
|
||||
reason=body.get("reason"),
|
||||
content=body,
|
||||
received_ts=self.clock.time_msec(),
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
ReportEventRestServlet(hs).register(http_server)
|
||||
@@ -79,10 +79,11 @@ class SyncRestServlet(RestServlet):
|
||||
def __init__(self, hs):
|
||||
super(SyncRestServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.sync_handler = hs.get_sync_handler()
|
||||
self.event_stream_handler = hs.get_handlers().event_stream_handler
|
||||
self.sync_handler = hs.get_handlers().sync_handler
|
||||
self.clock = hs.get_clock()
|
||||
self.filtering = hs.get_filtering()
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.presence_handler = hs.get_handlers().presence_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
|
||||
@@ -38,7 +38,7 @@ class TokenRefreshRestServlet(RestServlet):
|
||||
body = parse_json_object_from_request(request)
|
||||
try:
|
||||
old_refresh_token = body["refresh_token"]
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
auth_handler = self.hs.get_handlers().auth_handler
|
||||
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
|
||||
old_refresh_token, auth_handler.generate_refresh_token)
|
||||
new_access_token = yield auth_handler.issue_access_token(user_id)
|
||||
|
||||
@@ -49,6 +49,7 @@ class LocalKey(Resource):
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.version_string = hs.version_string
|
||||
self.response_body = encode_canonical_json(
|
||||
self.response_json_object(hs.config)
|
||||
|
||||
@@ -97,7 +97,7 @@ class RemoteKey(Resource):
|
||||
self.async_render_GET(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@request_handler()
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def async_render_GET(self, request):
|
||||
if len(request.postpath) == 1:
|
||||
@@ -122,7 +122,7 @@ class RemoteKey(Resource):
|
||||
self.async_render_POST(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@request_handler()
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def async_render_POST(self, request):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
@@ -36,13 +36,12 @@ class DownloadResource(Resource):
|
||||
self.server_name = hs.hostname
|
||||
self.store = hs.get_datastore()
|
||||
self.version_string = hs.version_string
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
def render_GET(self, request):
|
||||
self._async_render_GET(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@request_handler()
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
server_name, media_id, name = parse_media_id(request)
|
||||
|
||||
@@ -56,7 +56,8 @@ class PreviewUrlResource(Resource):
|
||||
self.client = SpiderHttpClient(hs)
|
||||
self.media_repo = media_repo
|
||||
|
||||
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
|
||||
if hasattr(hs.config, "url_preview_url_blacklist"):
|
||||
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
|
||||
|
||||
# simple memory cache mapping urls to OG metadata
|
||||
self.cache = ExpiringCache(
|
||||
@@ -73,7 +74,7 @@ class PreviewUrlResource(Resource):
|
||||
self._async_render_GET(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@request_handler()
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
|
||||
@@ -85,37 +86,39 @@ class PreviewUrlResource(Resource):
|
||||
else:
|
||||
ts = self.clock.time_msec()
|
||||
|
||||
url_tuple = urlparse.urlsplit(url)
|
||||
for entry in self.url_preview_url_blacklist:
|
||||
match = True
|
||||
for attrib in entry:
|
||||
pattern = entry[attrib]
|
||||
value = getattr(url_tuple, attrib)
|
||||
logger.debug((
|
||||
"Matching attrib '%s' with value '%s' against"
|
||||
" pattern '%s'"
|
||||
) % (attrib, value, pattern))
|
||||
# impose the URL pattern blacklist
|
||||
if hasattr(self, "url_preview_url_blacklist"):
|
||||
url_tuple = urlparse.urlsplit(url)
|
||||
for entry in self.url_preview_url_blacklist:
|
||||
match = True
|
||||
for attrib in entry:
|
||||
pattern = entry[attrib]
|
||||
value = getattr(url_tuple, attrib)
|
||||
logger.debug((
|
||||
"Matching attrib '%s' with value '%s' against"
|
||||
" pattern '%s'"
|
||||
) % (attrib, value, pattern))
|
||||
|
||||
if value is None:
|
||||
match = False
|
||||
continue
|
||||
|
||||
if pattern.startswith('^'):
|
||||
if not re.match(pattern, getattr(url_tuple, attrib)):
|
||||
if value is None:
|
||||
match = False
|
||||
continue
|
||||
else:
|
||||
if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern):
|
||||
match = False
|
||||
continue
|
||||
if match:
|
||||
logger.warn(
|
||||
"URL %s blocked by url_blacklist entry %s", url, entry
|
||||
)
|
||||
raise SynapseError(
|
||||
403, "URL blocked by url pattern blacklist entry",
|
||||
Codes.UNKNOWN
|
||||
)
|
||||
|
||||
if pattern.startswith('^'):
|
||||
if not re.match(pattern, getattr(url_tuple, attrib)):
|
||||
match = False
|
||||
continue
|
||||
else:
|
||||
if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern):
|
||||
match = False
|
||||
continue
|
||||
if match:
|
||||
logger.warn(
|
||||
"URL %s blocked by url_blacklist entry %s", url, entry
|
||||
)
|
||||
raise SynapseError(
|
||||
403, "URL blocked by url pattern blacklist entry",
|
||||
Codes.UNKNOWN
|
||||
)
|
||||
|
||||
# first check the memory cache - good to handle all the clients on this
|
||||
# HS thundering away to preview the same URL at the same time.
|
||||
|
||||
@@ -39,13 +39,12 @@ class ThumbnailResource(Resource):
|
||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||
self.server_name = hs.hostname
|
||||
self.version_string = hs.version_string
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
def render_GET(self, request):
|
||||
self._async_render_GET(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@request_handler()
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
server_name, media_id, _ = parse_media_id(request)
|
||||
|
||||
@@ -41,7 +41,6 @@ class UploadResource(Resource):
|
||||
self.auth = hs.get_auth()
|
||||
self.max_upload_size = hs.config.max_upload_size
|
||||
self.version_string = hs.version_string
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
def render_POST(self, request):
|
||||
self._async_render_POST(request)
|
||||
@@ -51,7 +50,7 @@ class UploadResource(Resource):
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@request_handler()
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_POST(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
@@ -22,19 +22,11 @@
|
||||
from twisted.web.client import BrowserLikePolicyForHTTPS
|
||||
from twisted.enterprise import adbapi
|
||||
|
||||
from synapse.appservice.scheduler import ApplicationServiceScheduler
|
||||
from synapse.appservice.api import ApplicationServiceApi
|
||||
from synapse.federation import initialize_http_replication
|
||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||
from synapse.notifier import Notifier
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.handlers import Handlers
|
||||
from synapse.handlers.presence import PresenceHandler
|
||||
from synapse.handlers.sync import SyncHandler
|
||||
from synapse.handlers.typing import TypingHandler
|
||||
from synapse.handlers.room import RoomListHandler
|
||||
from synapse.handlers.auth import AuthHandler
|
||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||
from synapse.state import StateHandler
|
||||
from synapse.storage import DataStore
|
||||
from synapse.util import Clock
|
||||
@@ -86,14 +78,6 @@ class HomeServer(object):
|
||||
'auth',
|
||||
'rest_servlet_factory',
|
||||
'state_handler',
|
||||
'presence_handler',
|
||||
'sync_handler',
|
||||
'typing_handler',
|
||||
'room_list_handler',
|
||||
'auth_handler',
|
||||
'application_service_api',
|
||||
'application_service_scheduler',
|
||||
'application_service_handler',
|
||||
'notifier',
|
||||
'distributor',
|
||||
'client_resource',
|
||||
@@ -180,30 +164,6 @@ class HomeServer(object):
|
||||
def build_state_handler(self):
|
||||
return StateHandler(self)
|
||||
|
||||
def build_presence_handler(self):
|
||||
return PresenceHandler(self)
|
||||
|
||||
def build_typing_handler(self):
|
||||
return TypingHandler(self)
|
||||
|
||||
def build_sync_handler(self):
|
||||
return SyncHandler(self)
|
||||
|
||||
def build_room_list_handler(self):
|
||||
return RoomListHandler(self)
|
||||
|
||||
def build_auth_handler(self):
|
||||
return AuthHandler(self)
|
||||
|
||||
def build_application_service_api(self):
|
||||
return ApplicationServiceApi(self)
|
||||
|
||||
def build_application_service_scheduler(self):
|
||||
return ApplicationServiceScheduler(self)
|
||||
|
||||
def build_application_service_handler(self):
|
||||
return ApplicationServicesHandler(self)
|
||||
|
||||
def build_event_sources(self):
|
||||
return EventSources(self)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from twisted.internet import defer
|
||||
from .appservice import (
|
||||
ApplicationServiceStore, ApplicationServiceTransactionStore
|
||||
)
|
||||
from ._base import Cache, LoggingTransaction
|
||||
from ._base import Cache
|
||||
from .directory import DirectoryStore
|
||||
from .events import EventsStore
|
||||
from .presence import PresenceStore, UserPresenceState
|
||||
@@ -44,7 +44,6 @@ from .receipts import ReceiptsStore
|
||||
from .search import SearchStore
|
||||
from .tags import TagsStore
|
||||
from .account_data import AccountDataStore
|
||||
from .openid import OpenIdStore
|
||||
|
||||
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
|
||||
|
||||
@@ -82,13 +81,11 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
SearchStore,
|
||||
TagsStore,
|
||||
AccountDataStore,
|
||||
EventPushActionsStore,
|
||||
OpenIdStore,
|
||||
EventPushActionsStore
|
||||
):
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
self.hs = hs
|
||||
self._clock = hs.get_clock()
|
||||
self.database_engine = hs.database_engine
|
||||
|
||||
self.client_ip_last_seen = Cache(
|
||||
@@ -117,7 +114,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id")
|
||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
|
||||
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
||||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
||||
self._push_rules_stream_id_gen = ChainedIdGenerator(
|
||||
@@ -149,7 +145,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
"AccountDataAndTagsChangeCache", account_max,
|
||||
)
|
||||
|
||||
self._presence_on_startup = self._get_active_presence(db_conn)
|
||||
self.__presence_on_startup = self._get_active_presence(db_conn)
|
||||
|
||||
presence_cache_prefill, min_presence_val = self._get_cache_dict(
|
||||
db_conn, "presence_stream",
|
||||
@@ -174,24 +170,11 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
prefilled_cache=push_rules_prefill,
|
||||
)
|
||||
|
||||
cur = LoggingTransaction(
|
||||
db_conn.cursor(),
|
||||
name="_find_stream_orderings_for_times_txn",
|
||||
database_engine=self.database_engine,
|
||||
after_callbacks=[]
|
||||
)
|
||||
self._find_stream_orderings_for_times_txn(cur)
|
||||
cur.close()
|
||||
|
||||
self.find_stream_orderings_looping_call = self._clock.looping_call(
|
||||
self._find_stream_orderings_for_times, 60 * 60 * 1000
|
||||
)
|
||||
|
||||
super(DataStore, self).__init__(hs)
|
||||
|
||||
def take_presence_startup_info(self):
|
||||
active_on_startup = self._presence_on_startup
|
||||
self._presence_on_startup = None
|
||||
active_on_startup = self.__presence_on_startup
|
||||
self.__presence_on_startup = None
|
||||
return active_on_startup
|
||||
|
||||
def _get_active_presence(self, db_conn):
|
||||
|
||||
@@ -152,8 +152,8 @@ class SQLBaseStore(object):
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self._clock = hs.get_clock()
|
||||
self._db_pool = hs.get_db_pool()
|
||||
self._clock = hs.get_clock()
|
||||
|
||||
self._previous_txn_total_time = 0
|
||||
self._current_txn_total_time = 0
|
||||
@@ -453,9 +453,7 @@ class SQLBaseStore(object):
|
||||
keyvalues (dict): The unique key tables and their new values
|
||||
values (dict): The nonunique columns and their new values
|
||||
insertion_values (dict): key/values to use when inserting
|
||||
Returns:
|
||||
Deferred(bool): True if a new entry was created, False if an
|
||||
existing one was updated.
|
||||
Returns: A deferred
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
@@ -500,10 +498,6 @@ class SQLBaseStore(object):
|
||||
)
|
||||
txn.execute(sql, allvalues.values())
|
||||
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def _simple_select_one(self, table, keyvalues, retcols,
|
||||
allow_none=False, desc="_simple_select_one"):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
|
||||
@@ -16,8 +16,6 @@
|
||||
from ._base import SQLBaseStore
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
|
||||
|
||||
import ujson as json
|
||||
import logging
|
||||
|
||||
@@ -26,7 +24,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class AccountDataStore(SQLBaseStore):
|
||||
|
||||
@cached()
|
||||
def get_account_data_for_user(self, user_id):
|
||||
"""Get all the client account_data for a user.
|
||||
|
||||
@@ -63,47 +60,6 @@ class AccountDataStore(SQLBaseStore):
|
||||
"get_account_data_for_user", get_account_data_for_user_txn
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(num_args=2)
|
||||
def get_global_account_data_by_type_for_user(self, data_type, user_id):
|
||||
"""
|
||||
Returns:
|
||||
Deferred: A dict
|
||||
"""
|
||||
result = yield self._simple_select_one_onecol(
|
||||
table="account_data",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"account_data_type": data_type,
|
||||
},
|
||||
retcol="content",
|
||||
desc="get_global_account_data_by_type_for_user",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if result:
|
||||
defer.returnValue(json.loads(result))
|
||||
else:
|
||||
defer.returnValue(None)
|
||||
|
||||
@cachedList(cached_method_name="get_global_account_data_by_type_for_user",
|
||||
num_args=2, list_name="user_ids", inlineCallbacks=True)
|
||||
def get_global_account_data_by_type_for_users(self, data_type, user_ids):
|
||||
rows = yield self._simple_select_many_batch(
|
||||
table="account_data",
|
||||
column="user_id",
|
||||
iterable=user_ids,
|
||||
keyvalues={
|
||||
"account_data_type": data_type,
|
||||
},
|
||||
retcols=("user_id", "content",),
|
||||
desc="get_global_account_data_by_type_for_users",
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
row["user_id"]: json.loads(row["content"]) if row["content"] else None
|
||||
for row in rows
|
||||
})
|
||||
|
||||
def get_account_data_for_room(self, user_id, room_id):
|
||||
"""Get all the client account_data for a user for a room.
|
||||
|
||||
@@ -237,7 +193,6 @@ class AccountDataStore(SQLBaseStore):
|
||||
self._account_data_stream_cache.entity_has_changed,
|
||||
user_id, next_id,
|
||||
)
|
||||
txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
|
||||
self._update_max_stream_id(txn, next_id)
|
||||
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
@@ -277,11 +232,6 @@ class AccountDataStore(SQLBaseStore):
|
||||
self._account_data_stream_cache.entity_has_changed,
|
||||
user_id, next_id,
|
||||
)
|
||||
txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
|
||||
txn.call_after(
|
||||
self.get_global_account_data_by_type_for_user.invalidate,
|
||||
(account_data_type, user_id,)
|
||||
)
|
||||
self._update_max_stream_id(txn, next_id)
|
||||
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
|
||||
@@ -13,13 +13,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import urllib
|
||||
import yaml
|
||||
import simplejson as json
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.appservice import AppServiceTransaction
|
||||
from synapse.config.appservice import load_appservices
|
||||
from synapse.appservice import ApplicationService, AppServiceTransaction
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.storage.roommember import RoomsForUser
|
||||
from synapse.types import UserID
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
||||
@@ -31,7 +34,7 @@ class ApplicationServiceStore(SQLBaseStore):
|
||||
def __init__(self, hs):
|
||||
super(ApplicationServiceStore, self).__init__(hs)
|
||||
self.hostname = hs.hostname
|
||||
self.services_cache = load_appservices(
|
||||
self.services_cache = ApplicationServiceStore.load_appservices(
|
||||
hs.hostname,
|
||||
hs.config.app_service_config_files
|
||||
)
|
||||
@@ -141,6 +144,102 @@ class ApplicationServiceStore(SQLBaseStore):
|
||||
|
||||
return rooms_for_user_matching_user_id
|
||||
|
||||
@classmethod
|
||||
def _load_appservice(cls, hostname, as_info, config_filename):
|
||||
required_string_fields = [
|
||||
"id", "url", "as_token", "hs_token", "sender_localpart"
|
||||
]
|
||||
for field in required_string_fields:
|
||||
if not isinstance(as_info.get(field), basestring):
|
||||
raise KeyError("Required string field: '%s' (%s)" % (
|
||||
field, config_filename,
|
||||
))
|
||||
|
||||
localpart = as_info["sender_localpart"]
|
||||
if urllib.quote(localpart) != localpart:
|
||||
raise ValueError(
|
||||
"sender_localpart needs characters which are not URL encoded."
|
||||
)
|
||||
user = UserID(localpart, hostname)
|
||||
user_id = user.to_string()
|
||||
|
||||
# namespace checks
|
||||
if not isinstance(as_info.get("namespaces"), dict):
|
||||
raise KeyError("Requires 'namespaces' object.")
|
||||
for ns in ApplicationService.NS_LIST:
|
||||
# specific namespaces are optional
|
||||
if ns in as_info["namespaces"]:
|
||||
# expect a list of dicts with exclusive and regex keys
|
||||
for regex_obj in as_info["namespaces"][ns]:
|
||||
if not isinstance(regex_obj, dict):
|
||||
raise ValueError(
|
||||
"Expected namespace entry in %s to be an object,"
|
||||
" but got %s", ns, regex_obj
|
||||
)
|
||||
if not isinstance(regex_obj.get("regex"), basestring):
|
||||
raise ValueError(
|
||||
"Missing/bad type 'regex' key in %s", regex_obj
|
||||
)
|
||||
if not isinstance(regex_obj.get("exclusive"), bool):
|
||||
raise ValueError(
|
||||
"Missing/bad type 'exclusive' key in %s", regex_obj
|
||||
)
|
||||
return ApplicationService(
|
||||
token=as_info["as_token"],
|
||||
url=as_info["url"],
|
||||
namespaces=as_info["namespaces"],
|
||||
hs_token=as_info["hs_token"],
|
||||
sender=user_id,
|
||||
id=as_info["id"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_appservices(cls, hostname, config_files):
|
||||
"""Returns a list of Application Services from the config files."""
|
||||
if not isinstance(config_files, list):
|
||||
logger.warning(
|
||||
"Expected %s to be a list of AS config files.", config_files
|
||||
)
|
||||
return []
|
||||
|
||||
# Dicts of value -> filename
|
||||
seen_as_tokens = {}
|
||||
seen_ids = {}
|
||||
|
||||
appservices = []
|
||||
|
||||
for config_file in config_files:
|
||||
try:
|
||||
with open(config_file, 'r') as f:
|
||||
appservice = ApplicationServiceStore._load_appservice(
|
||||
hostname, yaml.load(f), config_file
|
||||
)
|
||||
if appservice.id in seen_ids:
|
||||
raise ConfigError(
|
||||
"Cannot reuse ID across application services: "
|
||||
"%s (files: %s, %s)" % (
|
||||
appservice.id, config_file, seen_ids[appservice.id],
|
||||
)
|
||||
)
|
||||
seen_ids[appservice.id] = config_file
|
||||
if appservice.token in seen_as_tokens:
|
||||
raise ConfigError(
|
||||
"Cannot reuse as_token across application services: "
|
||||
"%s (files: %s, %s)" % (
|
||||
appservice.token,
|
||||
config_file,
|
||||
seen_as_tokens[appservice.token],
|
||||
)
|
||||
)
|
||||
seen_as_tokens[appservice.token] = config_file
|
||||
logger.info("Loaded application service: %s", appservice)
|
||||
appservices.append(appservice)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load appservice from '%s'", config_file)
|
||||
logger.exception(e)
|
||||
raise
|
||||
return appservices
|
||||
|
||||
|
||||
class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||
|
||||
|
||||
@@ -173,12 +173,11 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
|
||||
logger.info(
|
||||
"Updating %r. Updated %r items in %rms."
|
||||
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
|
||||
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r)",
|
||||
update_name, items_updated, duration_ms,
|
||||
performance.total_items_per_ms(),
|
||||
performance.average_items_per_ms(),
|
||||
performance.total_item_count,
|
||||
batch_size,
|
||||
)
|
||||
|
||||
performance.update(items_updated, duration_ms)
|
||||
|
||||
@@ -24,10 +24,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventPushActionsStore(SQLBaseStore):
|
||||
def __init__(self, hs):
|
||||
self.stream_ordering_month_ago = None
|
||||
super(EventPushActionsStore, self).__init__(hs)
|
||||
|
||||
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
|
||||
"""
|
||||
Args:
|
||||
@@ -119,23 +115,19 @@ class EventPushActionsStore(SQLBaseStore):
|
||||
@defer.inlineCallbacks
|
||||
def get_unread_push_actions_for_user_in_range(self, user_id,
|
||||
min_stream_ordering,
|
||||
max_stream_ordering=None,
|
||||
limit=20):
|
||||
max_stream_ordering=None):
|
||||
def get_after_receipt(txn):
|
||||
sql = (
|
||||
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, "
|
||||
"e.received_ts "
|
||||
"FROM ("
|
||||
" SELECT room_id, user_id, "
|
||||
" max(topological_ordering) as topological_ordering, "
|
||||
" max(stream_ordering) as stream_ordering "
|
||||
"SELECT ep.event_id, ep.stream_ordering, ep.actions "
|
||||
"FROM event_push_actions AS ep, ("
|
||||
" SELECT room_id, user_id,"
|
||||
" max(topological_ordering) as topological_ordering,"
|
||||
" max(stream_ordering) as stream_ordering"
|
||||
" FROM events"
|
||||
" NATURAL JOIN receipts_linearized WHERE receipt_type = 'm.read'"
|
||||
" GROUP BY room_id, user_id"
|
||||
") AS rl,"
|
||||
" event_push_actions AS ep"
|
||||
" INNER JOIN events AS e USING (room_id, event_id)"
|
||||
" WHERE"
|
||||
") AS rl "
|
||||
"WHERE"
|
||||
" ep.room_id = rl.room_id"
|
||||
" AND ("
|
||||
" ep.topological_ordering > rl.topological_ordering"
|
||||
@@ -152,8 +144,7 @@ class EventPushActionsStore(SQLBaseStore):
|
||||
if max_stream_ordering is not None:
|
||||
sql += " AND ep.stream_ordering <= ?"
|
||||
args.append(max_stream_ordering)
|
||||
sql += " ORDER BY ep.stream_ordering ASC LIMIT ?"
|
||||
args.append(limit)
|
||||
sql += " ORDER BY ep.stream_ordering ASC"
|
||||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
after_read_receipt = yield self.runInteraction(
|
||||
@@ -162,13 +153,11 @@ class EventPushActionsStore(SQLBaseStore):
|
||||
|
||||
def get_no_receipt(txn):
|
||||
sql = (
|
||||
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
|
||||
" e.received_ts"
|
||||
" FROM event_push_actions AS ep"
|
||||
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
|
||||
" WHERE ep.room_id not in ("
|
||||
"SELECT ep.event_id, ep.stream_ordering, ep.actions "
|
||||
"FROM event_push_actions AS ep "
|
||||
"WHERE ep.room_id not in ("
|
||||
" SELECT room_id FROM events NATURAL JOIN receipts_linearized"
|
||||
" WHERE receipt_type = 'm.read' AND user_id = ?"
|
||||
" WHERE receipt_type = 'm.read' AND user_id = ? "
|
||||
" GROUP BY room_id"
|
||||
") AND ep.user_id = ? AND ep.stream_ordering > ?"
|
||||
)
|
||||
@@ -186,29 +175,11 @@ class EventPushActionsStore(SQLBaseStore):
|
||||
defer.returnValue([
|
||||
{
|
||||
"event_id": row[0],
|
||||
"room_id": row[1],
|
||||
"stream_ordering": row[2],
|
||||
"actions": json.loads(row[3]),
|
||||
"received_ts": row[4],
|
||||
"stream_ordering": row[1],
|
||||
"actions": json.loads(row[2]),
|
||||
} for row in after_read_receipt + no_read_receipt
|
||||
])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_time_of_last_push_action_before(self, stream_ordering):
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT e.received_ts"
|
||||
" FROM event_push_actions AS ep"
|
||||
" JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
|
||||
" WHERE ep.stream_ordering > ?"
|
||||
" ORDER BY ep.stream_ordering ASC"
|
||||
" LIMIT 1"
|
||||
)
|
||||
txn.execute(sql, (stream_ordering,))
|
||||
return txn.fetchone()
|
||||
result = yield self.runInteraction("get_time_of_last_push_action_before", f)
|
||||
defer.returnValue(result[0] if result else None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_latest_push_action_stream_ordering(self):
|
||||
def f(txn):
|
||||
@@ -230,93 +201,6 @@ class EventPushActionsStore(SQLBaseStore):
|
||||
(room_id, event_id)
|
||||
)
|
||||
|
||||
def _remove_old_push_actions_before_txn(self, txn, room_id, user_id,
|
||||
topological_ordering):
|
||||
"""
|
||||
Purges old, stale push actions for a user and room before a given
|
||||
topological_ordering
|
||||
Args:
|
||||
txn: The transcation
|
||||
room_id: Room ID to delete from
|
||||
user_id: user ID to delete for
|
||||
topological_ordering: The lowest topological ordering which will
|
||||
not be deleted.
|
||||
"""
|
||||
txn.call_after(
|
||||
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
|
||||
(room_id, user_id, )
|
||||
)
|
||||
|
||||
# We need to join on the events table to get the received_ts for
|
||||
# event_push_actions and sqlite won't let us use a join in a delete so
|
||||
# we can't just delete where received_ts < x. Furthermore we can
|
||||
# only identify event_push_actions by a tuple of room_id, event_id
|
||||
# we we can't use a subquery.
|
||||
# Instead, we look up the stream ordering for the last event in that
|
||||
# room received before the threshold time and delete event_push_actions
|
||||
# in the room with a stream_odering before that.
|
||||
txn.execute(
|
||||
"DELETE FROM event_push_actions "
|
||||
" WHERE user_id = ? AND room_id = ? AND "
|
||||
" topological_ordering < ? AND stream_ordering < ?",
|
||||
(user_id, room_id, topological_ordering, self.stream_ordering_month_ago)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _find_stream_orderings_for_times(self):
|
||||
yield self.runInteraction(
|
||||
"_find_stream_orderings_for_times",
|
||||
self._find_stream_orderings_for_times_txn
|
||||
)
|
||||
|
||||
def _find_stream_orderings_for_times_txn(self, txn):
|
||||
logger.info("Searching for stream ordering 1 month ago")
|
||||
self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
|
||||
txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
|
||||
)
|
||||
logger.info(
|
||||
"Found stream ordering 1 month ago: it's %d",
|
||||
self.stream_ordering_month_ago
|
||||
)
|
||||
|
||||
def _find_first_stream_ordering_after_ts_txn(self, txn, ts):
|
||||
"""
|
||||
Find the stream_ordering of the first event that was received after
|
||||
a given timestamp. This is relatively slow as there is no index on
|
||||
received_ts but we can then use this to delete push actions before
|
||||
this.
|
||||
|
||||
received_ts must necessarily be in the same order as stream_ordering
|
||||
and stream_ordering is indexed, so we manually binary search using
|
||||
stream_ordering
|
||||
"""
|
||||
txn.execute("SELECT MAX(stream_ordering) FROM events")
|
||||
max_stream_ordering = txn.fetchone()[0]
|
||||
|
||||
if max_stream_ordering is None:
|
||||
return 0
|
||||
|
||||
range_start = 0
|
||||
range_end = max_stream_ordering
|
||||
|
||||
sql = (
|
||||
"SELECT received_ts FROM events"
|
||||
" WHERE stream_ordering > ?"
|
||||
" ORDER BY stream_ordering"
|
||||
" LIMIT 1"
|
||||
)
|
||||
|
||||
while range_end - range_start > 1:
|
||||
middle = int((range_end + range_start) / 2)
|
||||
txn.execute(sql, (middle,))
|
||||
middle_ts = txn.fetchone()[0]
|
||||
if ts > middle_ts:
|
||||
range_start = middle
|
||||
else:
|
||||
range_end = middle
|
||||
|
||||
return range_end
|
||||
|
||||
|
||||
def _action_has_highlight(actions):
|
||||
for action in actions:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user