Compare commits
99 Commits
matthew/de
...
hhs-2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
897d976c1e | ||
|
|
45fc0ead10 | ||
|
|
cdd24449ee | ||
|
|
14d49c51db | ||
|
|
84b4e76fed | ||
|
|
c780d84d66 | ||
|
|
1d67b13674 | ||
|
|
92d50e3c2a | ||
|
|
e94cdbaecf | ||
|
|
9b1bc593c5 | ||
|
|
7f147d623b | ||
|
|
15e8dd2ccc | ||
|
|
05fe8a6c1c | ||
|
|
f584d6108f | ||
|
|
28d7b546cb | ||
|
|
f2cbbda956 | ||
|
|
cd77270a66 | ||
|
|
242a0483eb | ||
|
|
60cf1c6e16 | ||
|
|
7e6e588e60 | ||
|
|
1ca2744621 | ||
|
|
dddb5aa7bb | ||
|
|
4eb8408ed2 | ||
|
|
c5842dff1a | ||
|
|
7a0da69eee | ||
|
|
a5806aba27 | ||
|
|
fd2dbf1836 | ||
|
|
9643a6f7f2 | ||
|
|
c7181dcc6c | ||
|
|
db10f553ba | ||
|
|
764030cf63 | ||
|
|
8432e2ebd7 | ||
|
|
a81f140880 | ||
|
|
47b25ba5f3 | ||
|
|
3bf8bab8f9 | ||
|
|
a4cf660a32 | ||
|
|
0d568ff403 | ||
|
|
365f588d98 | ||
|
|
eb7be75a10 | ||
|
|
dd0ac1614c | ||
|
|
bb81e78ec6 | ||
|
|
a52f276990 | ||
|
|
46c832eaac | ||
|
|
2b1a4b2596 | ||
|
|
cd6937fb26 | ||
|
|
c2c153dd3b | ||
|
|
79d3b4689e | ||
|
|
808d8e06aa | ||
|
|
3f6762f0bb | ||
|
|
8bd585b09b | ||
|
|
f89f6b7c09 | ||
|
|
e2c0aa2c26 | ||
|
|
b01a755498 | ||
|
|
1058d14127 | ||
|
|
4d664278af | ||
|
|
8dee601054 | ||
|
|
e21c368b8b | ||
|
|
e07970165f | ||
|
|
c5171bf171 | ||
|
|
ba1fbf7d5b | ||
|
|
ab822a2d1f | ||
|
|
91cdb6de08 | ||
|
|
d49b77404b | ||
|
|
3ee57bdcbb | ||
|
|
782689bd40 | ||
|
|
ca87ad1def | ||
|
|
b5f638f1f4 | ||
|
|
38f708a2bb | ||
|
|
51b17ec566 | ||
|
|
3c1080b6e4 | ||
|
|
eff3ae3b9a | ||
|
|
bfb6c58624 | ||
|
|
a675f9c556 | ||
|
|
25d2b5d55f | ||
|
|
df1e4f259f | ||
|
|
c055c91655 | ||
|
|
8cfad2e686 | ||
|
|
fc5d937550 | ||
|
|
eabc5f8271 | ||
|
|
c24fc9797b | ||
|
|
e2c9fe0a6a | ||
|
|
9b75c78b4d | ||
|
|
63417c31e9 | ||
|
|
c79c4c0a7d | ||
|
|
6c6aba76e1 | ||
|
|
01021c812f | ||
|
|
5075e444f4 | ||
|
|
3e19beb941 | ||
|
|
bb99b1f550 | ||
|
|
ce6db0e547 | ||
|
|
152c0aa58e | ||
|
|
119451dcd1 | ||
|
|
a6c813761a | ||
|
|
54a9bea88c | ||
|
|
484a0ebdfc | ||
|
|
f81f421086 | ||
|
|
cd9765805e | ||
|
|
495cb100d1 | ||
|
|
37be52ac34 |
1
changelog.d/3659.feature
Normal file
1
changelog.d/3659.feature
Normal file
@@ -0,0 +1 @@
|
||||
Support profile API endpoints on workers
|
||||
1
changelog.d/3673.misc
Normal file
1
changelog.d/3673.misc
Normal file
@@ -0,0 +1 @@
|
||||
Refactor state module to support multiple room versions
|
||||
1
changelog.d/3680.feature
Normal file
1
changelog.d/3680.feature
Normal file
@@ -0,0 +1 @@
|
||||
Server notices for resource limit blocking
|
||||
1
changelog.d/3722.bugfix
Normal file
1
changelog.d/3722.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix error collecting prometheus metrics when run on dedicated thread due to threading concurrency issues
|
||||
1
changelog.d/3724.feature
Normal file
1
changelog.d/3724.feature
Normal file
@@ -0,0 +1 @@
|
||||
Allow guests to use /rooms/:roomId/event/:eventId
|
||||
1
changelog.d/3726.misc
Normal file
1
changelog.d/3726.misc
Normal file
@@ -0,0 +1 @@
|
||||
Split the state_group_cache into member and non-member state events (and so speed up LL /sync)
|
||||
1
changelog.d/3727.misc
Normal file
1
changelog.d/3727.misc
Normal file
@@ -0,0 +1 @@
|
||||
Log failure to authenticate remote servers as warnings (without stack traces)
|
||||
1
changelog.d/3734.misc
Normal file
1
changelog.d/3734.misc
Normal file
@@ -0,0 +1 @@
|
||||
Reference the need for an HTTP replication port when using the federation_reader worker
|
||||
1
changelog.d/3735.misc
Normal file
1
changelog.d/3735.misc
Normal file
@@ -0,0 +1 @@
|
||||
Fix minor spelling error in federation client documentation.
|
||||
1
changelog.d/3746.misc
Normal file
1
changelog.d/3746.misc
Normal file
@@ -0,0 +1 @@
|
||||
Fix MAU cache invalidation due to missing yield
|
||||
1
changelog.d/3747.bugfix
Normal file
1
changelog.d/3747.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix bug where we resent "limit exceeded" server notices repeatedly
|
||||
1
changelog.d/3749.feature
Normal file
1
changelog.d/3749.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add mau_trial_days config param, so that users only get counted as MAU after N days.
|
||||
1
changelog.d/3751.feature
Normal file
1
changelog.d/3751.feature
Normal file
@@ -0,0 +1 @@
|
||||
Require twisted 17.1 or later (fixes [#3741](https://github.com/matrix-org/synapse/issues/3741)).
|
||||
1
changelog.d/3753.bugfix
Normal file
1
changelog.d/3753.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix bug where we broke sync when using limit_usage_by_mau but hadn't configured server notices
|
||||
1
changelog.d/3754.bugfix
Normal file
1
changelog.d/3754.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix 'federation_domain_whitelist' such that an empty list correctly blocks all outbound federation traffic
|
||||
1
changelog.d/3755.bugfix
Normal file
1
changelog.d/3755.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix tagging of server notice rooms
|
||||
@@ -74,7 +74,7 @@ replication endpoints that it's talking to on the main synapse process.
|
||||
``worker_replication_port`` should point to the TCP replication listener port and
|
||||
``worker_replication_http_port`` should point to the HTTP replication port.
|
||||
|
||||
Currently, only the ``event_creator`` worker requires specifying
|
||||
Currently, the ``event_creator`` and ``federation_reader`` workers require specifying
|
||||
``worker_replication_http_port``.
|
||||
|
||||
For instance::
|
||||
@@ -265,6 +265,7 @@ Handles some event creation. It can handle REST endpoints matching::
|
||||
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
|
||||
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$
|
||||
^/_matrix/client/(api/v1|r0|unstable)/join/
|
||||
^/_matrix/client/(api/v1|r0|unstable)/profile/
|
||||
|
||||
It will create events locally and then send them on to the main synapse
|
||||
instance to be persisted and handled.
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
.header {
|
||||
border-bottom: 4px solid #e4f7ed ! important;
|
||||
}
|
||||
|
||||
.notif_link a, .footer a {
|
||||
color: #76CFA6 ! important;
|
||||
}
|
||||
@@ -1,156 +0,0 @@
|
||||
body {
|
||||
margin: 0px;
|
||||
}
|
||||
|
||||
pre, code {
|
||||
word-break: break-word;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
#page {
|
||||
font-family: 'Open Sans', Helvetica, Arial, Sans-Serif;
|
||||
font-color: #454545;
|
||||
font-size: 12pt;
|
||||
width: 100%;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
#inner {
|
||||
width: 640px;
|
||||
}
|
||||
|
||||
.header {
|
||||
width: 100%;
|
||||
height: 87px;
|
||||
color: #454545;
|
||||
border-bottom: 4px solid #e5e5e5;
|
||||
}
|
||||
|
||||
.logo {
|
||||
text-align: right;
|
||||
margin-left: 20px;
|
||||
}
|
||||
|
||||
.salutation {
|
||||
padding-top: 10px;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.summarytext {
|
||||
}
|
||||
|
||||
.room {
|
||||
width: 100%;
|
||||
color: #454545;
|
||||
border-bottom: 1px solid #e5e5e5;
|
||||
}
|
||||
|
||||
.room_header td {
|
||||
padding-top: 38px;
|
||||
padding-bottom: 10px;
|
||||
border-bottom: 1px solid #e5e5e5;
|
||||
}
|
||||
|
||||
.room_name {
|
||||
vertical-align: middle;
|
||||
font-size: 18px;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.room_header h2 {
|
||||
margin-top: 0px;
|
||||
margin-left: 75px;
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
.room_avatar {
|
||||
width: 56px;
|
||||
line-height: 0px;
|
||||
text-align: center;
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
.room_avatar img {
|
||||
width: 48px;
|
||||
height: 48px;
|
||||
object-fit: cover;
|
||||
border-radius: 24px;
|
||||
}
|
||||
|
||||
.notif {
|
||||
border-bottom: 1px solid #e5e5e5;
|
||||
margin-top: 16px;
|
||||
padding-bottom: 16px;
|
||||
}
|
||||
|
||||
.historical_message .sender_avatar {
|
||||
opacity: 0.3;
|
||||
}
|
||||
|
||||
/* spell out opacity and historical_message class names for Outlook aka Word */
|
||||
.historical_message .sender_name {
|
||||
color: #e3e3e3;
|
||||
}
|
||||
|
||||
.historical_message .message_time {
|
||||
color: #e3e3e3;
|
||||
}
|
||||
|
||||
.historical_message .message_body {
|
||||
color: #c7c7c7;
|
||||
}
|
||||
|
||||
.historical_message td,
|
||||
.message td {
|
||||
padding-top: 10px;
|
||||
}
|
||||
|
||||
.sender_avatar {
|
||||
width: 56px;
|
||||
text-align: center;
|
||||
vertical-align: top;
|
||||
}
|
||||
|
||||
.sender_avatar img {
|
||||
margin-top: -2px;
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
border-radius: 16px;
|
||||
}
|
||||
|
||||
.sender_name {
|
||||
display: inline;
|
||||
font-size: 13px;
|
||||
color: #a2a2a2;
|
||||
}
|
||||
|
||||
.message_time {
|
||||
text-align: right;
|
||||
width: 100px;
|
||||
font-size: 11px;
|
||||
color: #a2a2a2;
|
||||
}
|
||||
|
||||
.message_body {
|
||||
}
|
||||
|
||||
.notif_link td {
|
||||
padding-top: 10px;
|
||||
padding-bottom: 10px;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.notif_link a, .footer a {
|
||||
color: #454545;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
.debug {
|
||||
font-size: 10px;
|
||||
color: #888;
|
||||
}
|
||||
|
||||
.footer {
|
||||
margin-top: 20px;
|
||||
text-align: center;
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
{% for message in notif.messages %}
|
||||
<tr class="{{ "historical_message" if message.is_historical else "message" }}">
|
||||
<td class="sender_avatar">
|
||||
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
|
||||
{% if message.sender_avatar_url %}
|
||||
<img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" />
|
||||
{% else %}
|
||||
{% if message.sender_hash % 3 == 0 %}
|
||||
<img class="sender_avatar" src="https://vector.im/beta/img/76cfa6.png" />
|
||||
{% elif message.sender_hash % 3 == 1 %}
|
||||
<img class="sender_avatar" src="https://vector.im/beta/img/50e2c2.png" />
|
||||
{% else %}
|
||||
<img class="sender_avatar" src="https://vector.im/beta/img/f4c371.png" />
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
</td>
|
||||
<td class="message_contents">
|
||||
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
|
||||
<div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
|
||||
{% endif %}
|
||||
<div class="message_body">
|
||||
{% if message.msgtype == "m.text" %}
|
||||
{{ message.body_text_html }}
|
||||
{% elif message.msgtype == "m.emote" %}
|
||||
{{ message.body_text_html }}
|
||||
{% elif message.msgtype == "m.notice" %}
|
||||
{{ message.body_text_html }}
|
||||
{% elif message.msgtype == "m.image" %}
|
||||
<img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
|
||||
{% elif message.msgtype == "m.file" %}
|
||||
<span class="filename">{{ message.body_text_plain }}</span>
|
||||
{% endif %}
|
||||
</div>
|
||||
</td>
|
||||
<td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
<tr class="notif_link">
|
||||
<td></td>
|
||||
<td>
|
||||
<a href="{{ notif.link }}">Voir {{ room.title }}</a>
|
||||
</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
@@ -1,16 +0,0 @@
|
||||
{% for message in notif.messages %}
|
||||
{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
|
||||
{% if message.msgtype == "m.text" %}
|
||||
{{ message.body_text_plain }}
|
||||
{% elif message.msgtype == "m.emote" %}
|
||||
{{ message.body_text_plain }}
|
||||
{% elif message.msgtype == "m.notice" %}
|
||||
{{ message.body_text_plain }}
|
||||
{% elif message.msgtype == "m.image" %}
|
||||
{{ message.body_text_plain }}
|
||||
{% elif message.msgtype == "m.file" %}
|
||||
{{ message.body_text_plain }}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
Voir {{ room.title }} à {{ notif.link }}
|
||||
@@ -1,55 +0,0 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<style type="text/css">
|
||||
{% include 'mail.css' without context %}
|
||||
{% include "mail-%s.css" % app_name ignore missing without context %}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<table id="page">
|
||||
<tr>
|
||||
<td> </td>
|
||||
<td id="inner">
|
||||
<table class="header">
|
||||
<tr>
|
||||
<td>
|
||||
<div class="salutation">Bonjour {{ user_display_name }},</div>
|
||||
<div class="summarytext">{{ summary_text }}</div>
|
||||
</td>
|
||||
<td class="logo">
|
||||
{% if app_name == "Riot" %}
|
||||
<img src="http://matrix.org/img/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
|
||||
{% elif app_name == "Vector" %}
|
||||
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
|
||||
{% else %}
|
||||
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
|
||||
{% endif %}
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
{% for room in rooms %}
|
||||
{% include 'room.html' with context %}
|
||||
{% endfor %}
|
||||
<div class="footer">
|
||||
<a href="{{ unsubscribe_link }}">Se désinscrire</a>
|
||||
<br/>
|
||||
<br/>
|
||||
<div class="debug">
|
||||
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
|
||||
an event was received at {{ reason.received_at|format_ts("%c") }}
|
||||
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
|
||||
{% if reason.last_sent_ts %}
|
||||
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
|
||||
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
|
||||
{% else %}
|
||||
and we don't have a last time we sent a mail for this room.
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
</td>
|
||||
<td> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,10 +0,0 @@
|
||||
Bonjour {{ user_display_name }},
|
||||
|
||||
{{ summary_text }}
|
||||
|
||||
{% for room in rooms %}
|
||||
{% include 'room.txt' with context %}
|
||||
{% endfor %}
|
||||
|
||||
Vous pouvez désactiver ces notifications en cliquant ici {{ unsubscribe_link }}
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
<table class="room">
|
||||
<tr class="room_header">
|
||||
<td class="room_avatar">
|
||||
{% if room.avatar_url %}
|
||||
<img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" />
|
||||
{% else %}
|
||||
{% if room.hash % 3 == 0 %}
|
||||
<img alt="" src="https://vector.im/beta/img/76cfa6.png" />
|
||||
{% elif room.hash % 3 == 1 %}
|
||||
<img alt="" src="https://vector.im/beta/img/50e2c2.png" />
|
||||
{% else %}
|
||||
<img alt="" src="https://vector.im/beta/img/f4c371.png" />
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
</td>
|
||||
<td class="room_name" colspan="2">
|
||||
{{ room.title }}
|
||||
</td>
|
||||
</tr>
|
||||
{% if room.invite %}
|
||||
<tr>
|
||||
<td></td>
|
||||
<td>
|
||||
<a href="{{ room.link }}">Rejoindre la conversation.</a>
|
||||
</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
{% else %}
|
||||
{% for notif in room.notifs %}
|
||||
{% include 'notif.html' with context %}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
</table>
|
||||
@@ -1,9 +0,0 @@
|
||||
{{ room.title }}
|
||||
|
||||
{% if room.invite %}
|
||||
Vous avez été invité, rejoignez la conversation en cliquant sur le lien suivant {{ room.link }}
|
||||
{% else %}
|
||||
{% for notif in room.notifs %}
|
||||
{% include 'notif.txt' with context %}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
@@ -783,19 +783,29 @@ class Auth(object):
|
||||
user_id(str|None): If present, checks for presence against existing
|
||||
MAU cohort
|
||||
"""
|
||||
|
||||
# Never fail an auth check for the server notices users
|
||||
# This can be a problem where event creation is prohibited due to blocking
|
||||
if user_id == self.hs.config.server_notices_mxid:
|
||||
return
|
||||
|
||||
if self.hs.config.hs_disabled:
|
||||
raise ResourceLimitError(
|
||||
403, self.hs.config.hs_disabled_message,
|
||||
errcode=Codes.RESOURCE_LIMIT_EXCEED,
|
||||
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
|
||||
admin_uri=self.hs.config.admin_uri,
|
||||
limit_type=self.hs.config.hs_disabled_limit_type
|
||||
)
|
||||
if self.hs.config.limit_usage_by_mau is True:
|
||||
# If the user is already part of the MAU cohort
|
||||
# If the user is already part of the MAU cohort or a trial user
|
||||
if user_id:
|
||||
timestamp = yield self.store.user_last_seen_monthly_active(user_id)
|
||||
if timestamp:
|
||||
return
|
||||
|
||||
is_trial = yield self.store.is_trial_user(user_id)
|
||||
if is_trial:
|
||||
return
|
||||
# Else if there is no room in the MAU bucket, bail
|
||||
current_mau = yield self.store.get_monthly_active_count()
|
||||
if current_mau >= self.hs.config.max_mau_value:
|
||||
@@ -803,6 +813,6 @@ class Auth(object):
|
||||
403, "Monthly Active User Limit Exceeded",
|
||||
|
||||
admin_uri=self.hs.config.admin_uri,
|
||||
errcode=Codes.RESOURCE_LIMIT_EXCEED,
|
||||
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
|
||||
limit_type="monthly_active_user"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
# Copyright 2018 New Vector Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -71,7 +71,6 @@ class EventTypes(object):
|
||||
CanonicalAlias = "m.room.canonical_alias"
|
||||
RoomAvatar = "m.room.avatar"
|
||||
GuestAccess = "m.room.guest_access"
|
||||
Encryption = "m.room.encryption"
|
||||
|
||||
# These are used for validation
|
||||
Message = "m.room.message"
|
||||
@@ -79,6 +78,7 @@ class EventTypes(object):
|
||||
Name = "m.room.name"
|
||||
|
||||
ServerACL = "m.room.server_acl"
|
||||
Pinned = "m.room.pinned_events"
|
||||
|
||||
|
||||
class RejectedReason(object):
|
||||
@@ -98,9 +98,17 @@ class ThirdPartyEntityKind(object):
|
||||
LOCATION = "location"
|
||||
|
||||
|
||||
class RoomVersions(object):
|
||||
V1 = "1"
|
||||
VDH_TEST = "vdh-test-version"
|
||||
|
||||
|
||||
# the version we will give rooms which are created on this server
|
||||
DEFAULT_ROOM_VERSION = "1"
|
||||
DEFAULT_ROOM_VERSION = RoomVersions.V1
|
||||
|
||||
# vdh-test-version is a placeholder to get room versioning support working and tested
|
||||
# until we have a working v2.
|
||||
KNOWN_ROOM_VERSIONS = {"1", "vdh-test-version"}
|
||||
KNOWN_ROOM_VERSIONS = {RoomVersions.V1, RoomVersions.VDH_TEST}
|
||||
|
||||
ServerNoticeMsgType = "m.server_notice"
|
||||
ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"
|
||||
|
||||
@@ -56,7 +56,7 @@ class Codes(object):
|
||||
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
||||
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
|
||||
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
|
||||
RESOURCE_LIMIT_EXCEED = "M_RESOURCE_LIMIT_EXCEED"
|
||||
RESOURCE_LIMIT_EXCEEDED = "M_RESOURCE_LIMIT_EXCEEDED"
|
||||
UNSUPPORTED_ROOM_VERSION = "M_UNSUPPORTED_ROOM_VERSION"
|
||||
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
|
||||
|
||||
@@ -238,7 +238,7 @@ class ResourceLimitError(SynapseError):
|
||||
"""
|
||||
def __init__(
|
||||
self, code, msg,
|
||||
errcode=Codes.RESOURCE_LIMIT_EXCEED,
|
||||
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
|
||||
admin_uri=None,
|
||||
limit_type=None,
|
||||
):
|
||||
|
||||
@@ -45,6 +45,11 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
|
||||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.rest.client.v1.profile import (
|
||||
ProfileAvatarURLRestServlet,
|
||||
ProfileDisplaynameRestServlet,
|
||||
ProfileRestServlet,
|
||||
)
|
||||
from synapse.rest.client.v1.room import (
|
||||
JoinRoomAliasServlet,
|
||||
RoomMembershipRestServlet,
|
||||
@@ -53,6 +58,7 @@ from synapse.rest.client.v1.room import (
|
||||
)
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.storage.user_directory import UserDirectoryStore
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
@@ -62,6 +68,9 @@ logger = logging.getLogger("synapse.app.event_creator")
|
||||
|
||||
|
||||
class EventCreatorSlavedStore(
|
||||
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
|
||||
# rather than going via the correct worker.
|
||||
UserDirectoryStore,
|
||||
DirectoryStore,
|
||||
SlavedTransactionStore,
|
||||
SlavedProfileStore,
|
||||
@@ -101,6 +110,9 @@ class EventCreatorServer(HomeServer):
|
||||
RoomMembershipRestServlet(self).register(resource)
|
||||
RoomStateEventRestServlet(self).register(resource)
|
||||
JoinRoomAliasServlet(self).register(resource)
|
||||
ProfileAvatarURLRestServlet(self).register(resource)
|
||||
ProfileDisplaynameRestServlet(self).register(resource)
|
||||
ProfileRestServlet(self).register(resource)
|
||||
resources.update({
|
||||
"/_matrix/client/r0": resource,
|
||||
"/_matrix/client/unstable": resource,
|
||||
|
||||
@@ -33,15 +33,7 @@ class RegistrationConfig(Config):
|
||||
|
||||
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
|
||||
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
|
||||
self.check_is_for_allowed_local_3pids = config.get(
|
||||
"check_is_for_allowed_local_3pids", None
|
||||
)
|
||||
self.allow_invited_3pids = config.get("allow_invited_3pids", False)
|
||||
|
||||
self.disable_3pid_changes = config.get("disable_3pid_changes", False)
|
||||
|
||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||
self.register_mxid_from_3pid = config.get("register_mxid_from_3pid")
|
||||
|
||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
|
||||
@@ -53,15 +45,6 @@ class RegistrationConfig(Config):
|
||||
|
||||
self.auto_join_rooms = config.get("auto_join_rooms", [])
|
||||
|
||||
self.disable_set_displayname = config.get("disable_set_displayname", False)
|
||||
self.disable_set_avatar_url = config.get("disable_set_avatar_url", False)
|
||||
|
||||
self.replicate_user_profiles_to = config.get("replicate_user_profiles_to", [])
|
||||
if not isinstance(self.replicate_user_profiles_to, list):
|
||||
self.replicate_user_profiles_to = [self.replicate_user_profiles_to, ]
|
||||
|
||||
self.chain_register = config.get("chain_register", None)
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
registration_shared_secret = random_string_with_symbols(50)
|
||||
|
||||
@@ -77,26 +60,9 @@ class RegistrationConfig(Config):
|
||||
# - email
|
||||
# - msisdn
|
||||
|
||||
# Derive the user's matrix ID from a type of 3PID used when registering.
|
||||
# This overrides any matrix ID the user proposes when calling /register
|
||||
# The 3PID type should be present in registrations_require_3pid to avoid
|
||||
# users failing to register if they don't specify the right kind of 3pid.
|
||||
#
|
||||
# register_mxid_from_3pid: email
|
||||
|
||||
# Mandate that users are only allowed to associate certain formats of
|
||||
# 3PIDs with accounts on this server.
|
||||
#
|
||||
# Use an Identity Server to establish which 3PIDs are allowed to register?
|
||||
# Overrides allowed_local_3pids below.
|
||||
# check_is_for_allowed_local_3pids: matrix.org
|
||||
#
|
||||
# If you are using an IS you can also check whether that IS registers
|
||||
# pending invites for the given 3PID (and then allow it to sign up on
|
||||
# the platform):
|
||||
#
|
||||
# allow_invited_3pids: False
|
||||
#
|
||||
# allowed_local_3pids:
|
||||
# - medium: email
|
||||
# pattern: ".*@matrix\\.org"
|
||||
@@ -105,11 +71,6 @@ class RegistrationConfig(Config):
|
||||
# - medium: msisdn
|
||||
# pattern: "\\+44"
|
||||
|
||||
# If true, stop users from trying to change the 3PIDs associated with
|
||||
# their accounts.
|
||||
#
|
||||
# disable_3pid_changes: False
|
||||
|
||||
# If set, allows registration by anyone who also has the shared
|
||||
# secret, even if registration is otherwise disabled.
|
||||
registration_shared_secret: "%(registration_shared_secret)s"
|
||||
@@ -133,32 +94,10 @@ class RegistrationConfig(Config):
|
||||
- vector.im
|
||||
- riot.im
|
||||
|
||||
# If enabled, user IDs, display names and avatar URLs will be replicated
|
||||
# to this server whenever they change.
|
||||
# This is an experimental API currently implemented by sydent to support
|
||||
# cross-homeserver user directories.
|
||||
# replicate_user_profiles_to: example.com
|
||||
|
||||
# If specified, attempt to replay registrations on the given target
|
||||
# homeserver and identity server. The HS is authed via a given shared secret
|
||||
# chain_register:
|
||||
# hs: https://shadow.example.com
|
||||
# hs_shared_secret: 12u394refgbdhivsia
|
||||
# is: https://shadow-is.example.com
|
||||
|
||||
# If enabled, don't let users set their own display names/avatars
|
||||
# other than for the very first time (unless they are a server admin).
|
||||
# Useful when provisioning users based on the contents of a 3rd party
|
||||
# directory and to avoid ambiguities.
|
||||
#
|
||||
# disable_set_displayname: False
|
||||
# disable_set_avatar_url: False
|
||||
|
||||
# Users who register on this homeserver will automatically be joined
|
||||
# to these rooms
|
||||
#auto_join_rooms:
|
||||
# - "#example:example.com"
|
||||
|
||||
""" % locals()
|
||||
|
||||
def add_arguments(self, parser):
|
||||
|
||||
@@ -77,10 +77,15 @@ class ServerConfig(Config):
|
||||
self.max_mau_value = config.get(
|
||||
"max_mau_value", 0,
|
||||
)
|
||||
|
||||
self.mau_limits_reserved_threepids = config.get(
|
||||
"mau_limit_reserved_threepids", []
|
||||
)
|
||||
|
||||
self.mau_trial_days = config.get(
|
||||
"mau_trial_days", 0,
|
||||
)
|
||||
|
||||
# Options to disable HS
|
||||
self.hs_disabled = config.get("hs_disabled", False)
|
||||
self.hs_disabled_message = config.get("hs_disabled_message", "")
|
||||
@@ -365,6 +370,7 @@ class ServerConfig(Config):
|
||||
# Enables monthly active user checking
|
||||
# limit_usage_by_mau: False
|
||||
# max_mau_value: 50
|
||||
# mau_trial_days: 2
|
||||
#
|
||||
# Sometimes the server admin will want to ensure certain accounts are
|
||||
# never blocked by mau checking. These accounts are specified here.
|
||||
|
||||
@@ -23,15 +23,11 @@ class UserDirectoryConfig(Config):
|
||||
|
||||
def read_config(self, config):
|
||||
self.user_directory_search_all_users = False
|
||||
self.user_directory_defer_to_id_server = None
|
||||
user_directory_config = config.get("user_directory", None)
|
||||
if user_directory_config:
|
||||
self.user_directory_search_all_users = (
|
||||
user_directory_config.get("search_all_users", False)
|
||||
)
|
||||
self.user_directory_defer_to_id_server = (
|
||||
user_directory_config.get("defer_to_id_server", None)
|
||||
)
|
||||
|
||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||
return """
|
||||
@@ -45,9 +41,4 @@ class UserDirectoryConfig(Config):
|
||||
#
|
||||
#user_directory:
|
||||
# search_all_users: false
|
||||
#
|
||||
# If this is set, user search will be delegated to this ID server instead
|
||||
# of synapse performing the search itself.
|
||||
# This is an experimental API.
|
||||
# defer_to_id_server: id.example.com
|
||||
"""
|
||||
|
||||
@@ -18,7 +18,9 @@ import logging
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.error import ConnectError
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.names.error import DomainError
|
||||
from twisted.web.http import HTTPClient
|
||||
|
||||
from synapse.http.endpoint import matrix_federation_endpoint
|
||||
@@ -47,12 +49,14 @@ def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
|
||||
server_response, server_certificate = yield protocol.remote_key
|
||||
defer.returnValue((server_response, server_certificate))
|
||||
except SynapseKeyClientError as e:
|
||||
logger.exception("Error getting key for %r" % (server_name,))
|
||||
logger.warn("Error getting key for %r: %s", server_name, e)
|
||||
if e.status.startswith("4"):
|
||||
# Don't retry for 4xx responses.
|
||||
raise IOError("Cannot get key for %r" % server_name)
|
||||
except (ConnectError, DomainError) as e:
|
||||
logger.warn("Error getting key for %r: %s", server_name, e)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
logger.exception("Error getting key for %r", server_name)
|
||||
raise IOError("Cannot get key for %r" % server_name)
|
||||
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ class TransportLayerClient(object):
|
||||
dest (str)
|
||||
room_id (str)
|
||||
event_tuples (list)
|
||||
limt (int)
|
||||
limit (int)
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a dict received from the remote homeserver.
|
||||
|
||||
@@ -261,10 +261,10 @@ class BaseFederationServlet(object):
|
||||
except NoAuthenticationError:
|
||||
origin = None
|
||||
if self.REQUIRE_AUTH:
|
||||
logger.exception("authenticate_request failed")
|
||||
logger.warn("authenticate_request failed: missing authentication")
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("authenticate_request failed")
|
||||
except Exception as e:
|
||||
logger.warn("authenticate_request failed: %s", e)
|
||||
raise
|
||||
|
||||
if origin:
|
||||
|
||||
@@ -33,7 +33,6 @@ class DeactivateAccountHandler(BaseHandler):
|
||||
self._device_handler = hs.get_device_handler()
|
||||
self._room_member_handler = hs.get_room_member_handler()
|
||||
self._identity_handler = hs.get_handlers().identity_handler
|
||||
self._profile_handler = hs.get_profile_handler()
|
||||
self.user_directory_handler = hs.get_user_directory_handler()
|
||||
|
||||
# Flag that indicates whether the process to part users from rooms is running
|
||||
@@ -95,9 +94,6 @@ class DeactivateAccountHandler(BaseHandler):
|
||||
|
||||
yield self.store.user_set_password_hash(user_id, None)
|
||||
|
||||
user = UserID.from_string(user_id)
|
||||
yield self._profile_handler.set_active(user, False)
|
||||
|
||||
# Add the user to a table of users pending deactivation (ie.
|
||||
# removal from all the rooms they're a member of)
|
||||
yield self.store.add_user_pending_deactivation(user_id)
|
||||
|
||||
@@ -291,8 +291,9 @@ class FederationHandler(BaseHandler):
|
||||
ev_ids, get_prev_content=False, check_redacted=False
|
||||
)
|
||||
|
||||
room_version = yield self.store.get_room_version(pdu.room_id)
|
||||
state_map = yield resolve_events_with_factory(
|
||||
state_groups, {pdu.event_id: pdu}, fetch
|
||||
room_version, state_groups, {pdu.event_id: pdu}, fetch
|
||||
)
|
||||
|
||||
state = (yield self.store.get_events(state_map.values())).values()
|
||||
@@ -1828,7 +1829,10 @@ class FederationHandler(BaseHandler):
|
||||
(d.type, d.state_key): d for d in different_events if d
|
||||
})
|
||||
|
||||
room_version = yield self.store.get_room_version(event.room_id)
|
||||
|
||||
new_state = self.state_handler.resolve_events(
|
||||
room_version,
|
||||
[list(local_view.values()), list(remote_view.values())],
|
||||
event
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -16,9 +15,7 @@
|
||||
|
||||
import logging
|
||||
|
||||
from signedjson.sign import sign_json
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
@@ -29,21 +26,22 @@ from synapse.api.errors import (
|
||||
)
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from synapse.util.logcontext import run_in_background
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProfileHandler(BaseHandler):
|
||||
PROFILE_UPDATE_MS = 60 * 1000
|
||||
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
|
||||
class BaseProfileHandler(BaseHandler):
|
||||
"""Handles fetching and updating user profile information.
|
||||
|
||||
PROFILE_REPLICATE_INTERVAL = 2 * 60 * 1000
|
||||
BaseProfileHandler can be instantiated directly on workers and will
|
||||
delegate to master when necessary. The master process should use the
|
||||
subclass MasterProfileHandler
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ProfileHandler, self).__init__(hs)
|
||||
super(BaseProfileHandler, self).__init__(hs)
|
||||
|
||||
self.federation = hs.get_federation_client()
|
||||
hs.get_federation_registry().register_query_handler(
|
||||
@@ -52,84 +50,6 @@ class ProfileHandler(BaseHandler):
|
||||
|
||||
self.user_directory_handler = hs.get_user_directory_handler()
|
||||
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
|
||||
if hs.config.worker_app is None:
|
||||
self.clock.looping_call(
|
||||
self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS,
|
||||
)
|
||||
|
||||
if len(self.hs.config.replicate_user_profiles_to) > 0:
|
||||
reactor.callWhenRunning(self._assign_profile_replication_batches)
|
||||
reactor.callWhenRunning(self._replicate_profiles)
|
||||
# Add a looping call to replicate_profiles: this handles retries
|
||||
# if the replication is unsuccessful when the user updated their
|
||||
# profile.
|
||||
self.clock.looping_call(
|
||||
self._replicate_profiles, self.PROFILE_REPLICATE_INTERVAL
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _assign_profile_replication_batches(self):
|
||||
"""If no profile replication has been done yet, allocate replication batch
|
||||
numbers to each profile to start the replication process.
|
||||
"""
|
||||
logger.info("Assigning profile batch numbers...")
|
||||
total = 0
|
||||
while True:
|
||||
assigned = yield self.store.assign_profile_batch()
|
||||
total += assigned
|
||||
if assigned == 0:
|
||||
break
|
||||
logger.info("Assigned %d profile batch numbers", total)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _replicate_profiles(self):
|
||||
"""If any profile data has been updated and not pushed to the replication targets,
|
||||
replicate it.
|
||||
"""
|
||||
host_batches = yield self.store.get_replication_hosts()
|
||||
latest_batch = yield self.store.get_latest_profile_replication_batch_number()
|
||||
if latest_batch is None:
|
||||
latest_batch = -1
|
||||
for repl_host in self.hs.config.replicate_user_profiles_to:
|
||||
if repl_host not in host_batches:
|
||||
host_batches[repl_host] = -1
|
||||
try:
|
||||
for i in xrange(host_batches[repl_host] + 1, latest_batch + 1):
|
||||
yield self._replicate_host_profile_batch(repl_host, i)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Exception while replicating to %s: aborting for now", repl_host,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _replicate_host_profile_batch(self, host, batchnum):
|
||||
logger.info("Replicating profile batch %d to %s", batchnum, host)
|
||||
batch_rows = yield self.store.get_profile_batch(batchnum)
|
||||
batch = {
|
||||
UserID(r["user_id"], self.hs.hostname).to_string(): ({
|
||||
"display_name": r["displayname"],
|
||||
"avatar_url": r["avatar_url"],
|
||||
} if r["active"] else None) for r in batch_rows
|
||||
}
|
||||
|
||||
url = "https://%s/_matrix/identity/api/v1/replicate_profiles" % (host,)
|
||||
body = {
|
||||
"batchnum": batchnum,
|
||||
"batch": batch,
|
||||
"origin_server": self.hs.hostname,
|
||||
}
|
||||
signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0])
|
||||
try:
|
||||
yield self.http_client.post_json_get_json(url, signed_body)
|
||||
yield self.store.update_replication_batch_for_host(host, batchnum)
|
||||
logger.info("Sucessfully replicated profile batch %d to %s", batchnum, host)
|
||||
except Exception:
|
||||
# This will get retried when the looping call next comes around
|
||||
logger.exception("Failed to replicate profile batch %d to %s", batchnum, host)
|
||||
raise
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_profile(self, user_id):
|
||||
target_user = UserID.from_string(user_id)
|
||||
@@ -229,30 +149,19 @@ class ProfileHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
|
||||
"""target_user is the UserID whose displayname is to be changed;
|
||||
requester is the authenticated user attempting to make this change."""
|
||||
"""target_user is the user whose displayname is to be changed;
|
||||
auth_user is the user attempting to make this change."""
|
||||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
if not by_admin and requester and target_user != requester.user:
|
||||
if not by_admin and target_user != requester.user:
|
||||
raise AuthError(400, "Cannot set another user's displayname")
|
||||
|
||||
if not by_admin and self.hs.config.disable_set_displayname:
|
||||
profile = yield self.store.get_profileinfo(target_user.localpart)
|
||||
if profile.display_name:
|
||||
raise SynapseError(400, "Changing displayname is disabled on this server")
|
||||
|
||||
if new_displayname == '':
|
||||
new_displayname = None
|
||||
|
||||
if len(self.hs.config.replicate_user_profiles_to) > 0:
|
||||
cur_batchnum = yield self.store.get_latest_profile_replication_batch_number()
|
||||
new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
|
||||
else:
|
||||
new_batchnum = None
|
||||
|
||||
yield self.store.set_profile_displayname(
|
||||
target_user.localpart, new_displayname, new_batchnum
|
||||
target_user.localpart, new_displayname
|
||||
)
|
||||
|
||||
if self.hs.config.user_directory_search_all_users:
|
||||
@@ -261,32 +170,7 @@ class ProfileHandler(BaseHandler):
|
||||
target_user.to_string(), profile
|
||||
)
|
||||
|
||||
if requester:
|
||||
yield self._update_join_states(requester, target_user)
|
||||
|
||||
# start a profile replication push
|
||||
run_in_background(self._replicate_profiles)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_active(self, target_user, active):
|
||||
"""
|
||||
Sets the 'active' flag on a user profile. If set to false, the user account is
|
||||
considered deactivated.
|
||||
Note that unlike set_displayname and set_avatar_url, this does *not* perform
|
||||
authorization checks! This is because the only place it's used currently is
|
||||
in account deactivation where we've already done these checks anyway.
|
||||
"""
|
||||
if len(self.hs.config.replicate_user_profiles_to) > 0:
|
||||
cur_batchnum = yield self.store.get_latest_profile_replication_batch_number()
|
||||
new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
|
||||
else:
|
||||
new_batchnum = None
|
||||
yield self.store.set_profile_active(
|
||||
target_user.localpart, active, new_batchnum
|
||||
)
|
||||
|
||||
# start a profile replication push
|
||||
run_in_background(self._replicate_profiles)
|
||||
yield self._update_join_states(requester, target_user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_avatar_url(self, target_user):
|
||||
@@ -330,19 +214,8 @@ class ProfileHandler(BaseHandler):
|
||||
if not by_admin and target_user != requester.user:
|
||||
raise AuthError(400, "Cannot set another user's avatar_url")
|
||||
|
||||
if not by_admin and self.hs.config.disable_set_avatar_url:
|
||||
profile = yield self.store.get_profileinfo(target_user.localpart)
|
||||
if profile.avatar_url:
|
||||
raise SynapseError(400, "Changing avatar url is disabled on this server")
|
||||
|
||||
if len(self.hs.config.replicate_user_profiles_to) > 0:
|
||||
cur_batchnum = yield self.store.get_latest_profile_replication_batch_number()
|
||||
new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
|
||||
else:
|
||||
new_batchnum = None
|
||||
|
||||
yield self.store.set_profile_avatar_url(
|
||||
target_user.localpart, new_avatar_url, new_batchnum,
|
||||
target_user.localpart, new_avatar_url
|
||||
)
|
||||
|
||||
if self.hs.config.user_directory_search_all_users:
|
||||
@@ -353,9 +226,6 @@ class ProfileHandler(BaseHandler):
|
||||
|
||||
yield self._update_join_states(requester, target_user)
|
||||
|
||||
# start a profile replication push
|
||||
run_in_background(self._replicate_profiles)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_profile_query(self, args):
|
||||
user = UserID.from_string(args["user_id"])
|
||||
@@ -411,6 +281,20 @@ class ProfileHandler(BaseHandler):
|
||||
room_id, str(e.message)
|
||||
)
|
||||
|
||||
|
||||
class MasterProfileHandler(BaseProfileHandler):
|
||||
PROFILE_UPDATE_MS = 60 * 1000
|
||||
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
|
||||
|
||||
def __init__(self, hs):
|
||||
super(MasterProfileHandler, self).__init__(hs)
|
||||
|
||||
assert hs.config.worker_app is None
|
||||
|
||||
self.clock.looping_call(
|
||||
self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS,
|
||||
)
|
||||
|
||||
def _start_update_remote_profile_cache(self):
|
||||
return run_as_background_process(
|
||||
"Update remote profile", self._update_remote_profile_cache,
|
||||
|
||||
@@ -51,7 +51,6 @@ class RegistrationHandler(BaseHandler):
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.user_directory_handler = hs.get_user_directory_handler()
|
||||
self.captcha_client = CaptchaServerHttpClient(hs)
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
|
||||
self._next_generated_user_id = None
|
||||
|
||||
@@ -125,7 +124,6 @@ class RegistrationHandler(BaseHandler):
|
||||
generate_token=True,
|
||||
guest_access_token=None,
|
||||
make_guest=False,
|
||||
display_name=None,
|
||||
admin=False,
|
||||
):
|
||||
"""Registers a new client on the server.
|
||||
@@ -141,7 +139,6 @@ class RegistrationHandler(BaseHandler):
|
||||
since it offers no means of associating a device_id with the
|
||||
access_token. Instead you should call auth_handler.issue_access_token
|
||||
after registration.
|
||||
display_name (str): The displayname to set for this user, if any
|
||||
Returns:
|
||||
A tuple of (user_id, access_token).
|
||||
Raises:
|
||||
@@ -180,19 +177,12 @@ class RegistrationHandler(BaseHandler):
|
||||
password_hash=password_hash,
|
||||
was_guest=was_guest,
|
||||
make_guest=make_guest,
|
||||
admin=admin,
|
||||
)
|
||||
|
||||
if display_name is None:
|
||||
display_name = (
|
||||
create_profile_with_localpart=(
|
||||
# If the user was a guest then they already have a profile
|
||||
None if was_guest else user.localpart
|
||||
)
|
||||
|
||||
if display_name:
|
||||
yield self.profile_handler.set_displayname(
|
||||
user, None, display_name, by_admin=True,
|
||||
)
|
||||
),
|
||||
admin=admin,
|
||||
)
|
||||
|
||||
if self.hs.config.user_directory_search_all_users:
|
||||
profile = yield self.store.get_profileinfo(localpart)
|
||||
@@ -218,12 +208,8 @@ class RegistrationHandler(BaseHandler):
|
||||
token=token,
|
||||
password_hash=password_hash,
|
||||
make_guest=make_guest,
|
||||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
|
||||
yield self.profile_handler.set_displayname(
|
||||
user, None, user.localpart, by_admin=True,
|
||||
)
|
||||
|
||||
except SynapseError:
|
||||
# if user id is taken, just generate another
|
||||
user = None
|
||||
@@ -267,12 +253,8 @@ class RegistrationHandler(BaseHandler):
|
||||
user_id=user_id,
|
||||
password_hash="",
|
||||
appservice_id=service_id,
|
||||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
|
||||
yield self.profile_handler.set_displayname(
|
||||
user, None, user.localpart, by_admin=True,
|
||||
)
|
||||
|
||||
defer.returnValue(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -319,10 +301,7 @@ class RegistrationHandler(BaseHandler):
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=None,
|
||||
)
|
||||
|
||||
yield self.profile_handler.set_displayname(
|
||||
user, None, user.localpart, by_admin=True,
|
||||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
except Exception as e:
|
||||
yield self.store.add_access_token_to_user(user_id, token)
|
||||
@@ -353,9 +332,7 @@ class RegistrationHandler(BaseHandler):
|
||||
logger.info("got threepid with medium '%s' and address '%s'",
|
||||
threepid['medium'], threepid['address'])
|
||||
|
||||
if not (
|
||||
yield check_3pid_allowed(self.hs, threepid['medium'], threepid['address'])
|
||||
):
|
||||
if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
|
||||
raise RegistrationError(
|
||||
403, "Third party identifier is not allowed"
|
||||
)
|
||||
@@ -397,43 +374,6 @@ class RegistrationHandler(BaseHandler):
|
||||
errcode=Codes.EXCLUSIVE
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def chain_register(self, localpart, auth_result, params):
|
||||
"""Invokes the current registration on another server, using
|
||||
shared secret registration, passing in any auth_results from
|
||||
other registration UI auth flows (e.g. validated 3pids)
|
||||
Useful for setting up shadow/backup accounts on a parallel deployment.
|
||||
"""
|
||||
|
||||
# TODO: retries
|
||||
|
||||
chained_hs = self.hs.config.chain_register.get("hs")
|
||||
|
||||
user = localpart.encode("utf-8")
|
||||
mac = hmac.new(
|
||||
key=self.hs.config.chain_register.get("hs_shared_secret").encode(),
|
||||
msg=user,
|
||||
digestmod=sha1,
|
||||
).hexdigest()
|
||||
|
||||
data = yield self.http_client.post_urlencoded_get_json(
|
||||
"https://%s%s" % (
|
||||
chained_hs, "/_matrix/client/r0/register"
|
||||
),
|
||||
{
|
||||
# XXX: auth_result is an unspecified extension for chained registration
|
||||
'auth_result': auth_result,
|
||||
'username': localpart,
|
||||
'password': params.get("password"),
|
||||
'bind_email': params.get("bind_email"),
|
||||
'bind_msisdn': params.get("bind_msisdn"),
|
||||
'device_id': params.get("device_id"),
|
||||
'initial_device_display_name': params.get("initial_device_display_name"),
|
||||
'inhibit_login': True,
|
||||
'mac': mac,
|
||||
}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_user_id(self, reseed=False):
|
||||
if reseed or self._next_generated_user_id is None:
|
||||
@@ -520,15 +460,18 @@ class RegistrationHandler(BaseHandler):
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=password_hash,
|
||||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
if displayname is not None:
|
||||
yield self.profile_handler.set_displayname(
|
||||
user, None, displayname, by_admin=True,
|
||||
)
|
||||
else:
|
||||
yield self._auth_handler.delete_access_tokens_for_user(user_id)
|
||||
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
|
||||
|
||||
if displayname is not None:
|
||||
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||
yield self.profile_handler.set_displayname(
|
||||
user, requester, displayname, by_admin=True,
|
||||
)
|
||||
|
||||
defer.returnValue((user_id, token))
|
||||
|
||||
def auth_handler(self):
|
||||
|
||||
@@ -52,14 +52,12 @@ class RoomCreationHandler(BaseHandler):
|
||||
"history_visibility": "shared",
|
||||
"original_invitees_have_ops": False,
|
||||
"guest_can_join": True,
|
||||
"encryption_alg": "m.megolm.v1.aes-sha2",
|
||||
},
|
||||
RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
|
||||
"join_rules": JoinRules.INVITE,
|
||||
"history_visibility": "shared",
|
||||
"original_invitees_have_ops": True,
|
||||
"guest_can_join": True,
|
||||
"encryption_alg": "m.megolm.v1.aes-sha2",
|
||||
},
|
||||
RoomCreationPreset.PUBLIC_CHAT: {
|
||||
"join_rules": JoinRules.PUBLIC,
|
||||
@@ -427,15 +425,6 @@ class RoomCreationHandler(BaseHandler):
|
||||
content=content,
|
||||
)
|
||||
|
||||
if "encryption_alg" in config:
|
||||
send(
|
||||
etype=EventTypes.Encryption,
|
||||
state_key="",
|
||||
content={
|
||||
'algorithm': config["encryption_alg"],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class RoomContextHandler(object):
|
||||
def __init__(self, hs):
|
||||
|
||||
@@ -344,6 +344,7 @@ class RoomMemberHandler(object):
|
||||
latest_event_ids = (
|
||||
event_id for (event_id, _, _) in prev_events_and_hashes
|
||||
)
|
||||
|
||||
current_state_ids = yield self.state_handler.get_current_state_ids(
|
||||
room_id, latest_event_ids=latest_event_ids,
|
||||
)
|
||||
|
||||
@@ -854,7 +854,7 @@ class SyncHandler(object):
|
||||
res = yield self._generate_sync_entry_for_rooms(
|
||||
sync_result_builder, account_data_by_room
|
||||
)
|
||||
newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
|
||||
newly_joined_rooms, newly_joined_users, _, _ = res
|
||||
_, _, newly_left_rooms, newly_left_users = res
|
||||
|
||||
block_all_presence_data = (
|
||||
@@ -863,7 +863,7 @@ class SyncHandler(object):
|
||||
)
|
||||
if self.hs_config.use_presence and not block_all_presence_data:
|
||||
yield self._generate_sync_entry_for_presence(
|
||||
sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
|
||||
sync_result_builder, newly_joined_rooms, newly_joined_users
|
||||
)
|
||||
|
||||
yield self._generate_sync_entry_for_to_device(sync_result_builder)
|
||||
@@ -871,7 +871,7 @@ class SyncHandler(object):
|
||||
device_lists = yield self._generate_sync_entry_for_device_list(
|
||||
sync_result_builder,
|
||||
newly_joined_rooms=newly_joined_rooms,
|
||||
newly_joined_or_invited_users=newly_joined_or_invited_users,
|
||||
newly_joined_users=newly_joined_users,
|
||||
newly_left_rooms=newly_left_rooms,
|
||||
newly_left_users=newly_left_users,
|
||||
)
|
||||
@@ -947,8 +947,7 @@ class SyncHandler(object):
|
||||
@measure_func("_generate_sync_entry_for_device_list")
|
||||
@defer.inlineCallbacks
|
||||
def _generate_sync_entry_for_device_list(self, sync_result_builder,
|
||||
newly_joined_rooms,
|
||||
newly_joined_or_invited_users,
|
||||
newly_joined_rooms, newly_joined_users,
|
||||
newly_left_rooms, newly_left_users):
|
||||
user_id = sync_result_builder.sync_config.user.to_string()
|
||||
since_token = sync_result_builder.since_token
|
||||
@@ -962,7 +961,7 @@ class SyncHandler(object):
|
||||
# share a room with?
|
||||
for room_id in newly_joined_rooms:
|
||||
joined_users = yield self.state.get_current_user_in_room(room_id)
|
||||
newly_joined_or_invited_users.update(joined_users)
|
||||
newly_joined_users.update(joined_users)
|
||||
|
||||
for room_id in newly_left_rooms:
|
||||
left_users = yield self.state.get_current_user_in_room(room_id)
|
||||
@@ -970,7 +969,7 @@ class SyncHandler(object):
|
||||
|
||||
# TODO: Check that these users are actually new, i.e. either they
|
||||
# weren't in the previous sync *or* they left and rejoined.
|
||||
changed.update(newly_joined_or_invited_users)
|
||||
changed.update(newly_joined_users)
|
||||
|
||||
if not changed and not newly_left_users:
|
||||
defer.returnValue(DeviceLists(
|
||||
@@ -1088,7 +1087,7 @@ class SyncHandler(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms,
|
||||
newly_joined_or_invited_users):
|
||||
newly_joined_users):
|
||||
"""Generates the presence portion of the sync response. Populates the
|
||||
`sync_result_builder` with the result.
|
||||
|
||||
@@ -1096,9 +1095,8 @@ class SyncHandler(object):
|
||||
sync_result_builder(SyncResultBuilder)
|
||||
newly_joined_rooms(list): List of rooms that the user has joined
|
||||
since the last sync (or empty if an initial sync)
|
||||
newly_joined_or_invited_users(list): List of users that have joined
|
||||
or been invited to rooms since the last sync (or empty if an initial
|
||||
sync)
|
||||
newly_joined_users(list): List of users that have joined rooms
|
||||
since the last sync (or empty if an initial sync)
|
||||
"""
|
||||
now_token = sync_result_builder.now_token
|
||||
sync_config = sync_result_builder.sync_config
|
||||
@@ -1124,7 +1122,7 @@ class SyncHandler(object):
|
||||
"presence_key", presence_key
|
||||
)
|
||||
|
||||
extra_users_ids = set(newly_joined_or_invited_users)
|
||||
extra_users_ids = set(newly_joined_users)
|
||||
for room_id in newly_joined_rooms:
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
extra_users_ids.update(users)
|
||||
@@ -1156,8 +1154,7 @@ class SyncHandler(object):
|
||||
|
||||
Returns:
|
||||
Deferred(tuple): Returns a 4-tuple of
|
||||
`(newly_joined_rooms, newly_joined_or_invited_users,
|
||||
newly_left_rooms, newly_left_users)`
|
||||
`(newly_joined_rooms, newly_joined_users, newly_left_rooms, newly_left_users)`
|
||||
"""
|
||||
user_id = sync_result_builder.sync_config.user.to_string()
|
||||
block_all_room_ephemeral = (
|
||||
@@ -1228,8 +1225,8 @@ class SyncHandler(object):
|
||||
|
||||
sync_result_builder.invited.extend(invited)
|
||||
|
||||
# Now we want to get any newly joined or invited users
|
||||
newly_joined_or_invited_users = set()
|
||||
# Now we want to get any newly joined users
|
||||
newly_joined_users = set()
|
||||
newly_left_users = set()
|
||||
if since_token:
|
||||
for joined_sync in sync_result_builder.joined:
|
||||
@@ -1238,22 +1235,19 @@ class SyncHandler(object):
|
||||
)
|
||||
for event in it:
|
||||
if event.type == EventTypes.Member:
|
||||
if (
|
||||
event.membership == Membership.JOIN or
|
||||
event.membership == Membership.INVITE
|
||||
):
|
||||
newly_joined_or_invited_users.add(event.state_key)
|
||||
if event.membership == Membership.JOIN:
|
||||
newly_joined_users.add(event.state_key)
|
||||
else:
|
||||
prev_content = event.unsigned.get("prev_content", {})
|
||||
prev_membership = prev_content.get("membership", None)
|
||||
if prev_membership == Membership.JOIN:
|
||||
newly_left_users.add(event.state_key)
|
||||
|
||||
newly_left_users -= newly_joined_or_invited_users
|
||||
newly_left_users -= newly_joined_users
|
||||
|
||||
defer.returnValue((
|
||||
newly_joined_rooms,
|
||||
newly_joined_or_invited_users,
|
||||
newly_joined_users,
|
||||
newly_left_rooms,
|
||||
newly_left_users,
|
||||
))
|
||||
@@ -1298,7 +1292,7 @@ class SyncHandler(object):
|
||||
where:
|
||||
room_entries is a list [RoomSyncResultBuilder]
|
||||
invited_rooms is a list [InvitedSyncResult]
|
||||
newly_joined_rooms is a list[str] of room ids
|
||||
newly_joined rooms is a list[str] of room ids
|
||||
newly_left_rooms is a list[str] of room ids
|
||||
"""
|
||||
user_id = sync_result_builder.sync_config.user.to_string()
|
||||
@@ -1333,7 +1327,7 @@ class SyncHandler(object):
|
||||
if room_id in sync_result_builder.joined_room_ids and non_joins:
|
||||
# Always include if the user (re)joined the room, especially
|
||||
# important so that device list changes are calculated correctly.
|
||||
# If there are non-join member events, but we are still in the room,
|
||||
# If there are non join member events, but we are still in the room,
|
||||
# then the user must have left and joined
|
||||
newly_joined_rooms.append(room_id)
|
||||
|
||||
|
||||
@@ -119,6 +119,8 @@ class UserDirectoryHandler(object):
|
||||
"""Called to update index of our local user profiles when they change
|
||||
irrespective of any rooms the user may be in.
|
||||
"""
|
||||
# FIXME(#3714): We should probably do this in the same worker as all
|
||||
# the other changes.
|
||||
yield self.store.update_profile_in_user_dir(
|
||||
user_id, profile.display_name, profile.avatar_url, None,
|
||||
)
|
||||
@@ -127,6 +129,8 @@ class UserDirectoryHandler(object):
|
||||
def handle_user_deactivated(self, user_id):
|
||||
"""Called when a user ID is deactivated
|
||||
"""
|
||||
# FIXME(#3714): We should probably do this in the same worker as all
|
||||
# the other changes.
|
||||
yield self.store.remove_from_user_dir(user_id)
|
||||
yield self.store.remove_from_user_in_public_room(user_id)
|
||||
|
||||
|
||||
@@ -133,7 +133,7 @@ class MatrixFederationHttpClient(object):
|
||||
failures, connection failures, SSL failures.)
|
||||
"""
|
||||
if (
|
||||
self.hs.config.federation_domain_whitelist and
|
||||
self.hs.config.federation_domain_whitelist is not None and
|
||||
destination not in self.hs.config.federation_domain_whitelist
|
||||
):
|
||||
raise FederationDeniedError(destination)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import threading
|
||||
|
||||
from prometheus_client.core import Counter, Histogram
|
||||
|
||||
@@ -111,6 +112,9 @@ in_flight_requests_db_sched_duration = Counter(
|
||||
# The set of all in flight requests, set[RequestMetrics]
|
||||
_in_flight_requests = set()
|
||||
|
||||
# Protects the _in_flight_requests set from concurrent accesss
|
||||
_in_flight_requests_lock = threading.Lock()
|
||||
|
||||
|
||||
def _get_in_flight_counts():
|
||||
"""Returns a count of all in flight requests by (method, server_name)
|
||||
@@ -120,7 +124,8 @@ def _get_in_flight_counts():
|
||||
"""
|
||||
# Cast to a list to prevent it changing while the Prometheus
|
||||
# thread is collecting metrics
|
||||
reqs = list(_in_flight_requests)
|
||||
with _in_flight_requests_lock:
|
||||
reqs = list(_in_flight_requests)
|
||||
|
||||
for rm in reqs:
|
||||
rm.update_metrics()
|
||||
@@ -154,10 +159,12 @@ class RequestMetrics(object):
|
||||
# to the "in flight" metrics.
|
||||
self._request_stats = self.start_context.get_resource_usage()
|
||||
|
||||
_in_flight_requests.add(self)
|
||||
with _in_flight_requests_lock:
|
||||
_in_flight_requests.add(self)
|
||||
|
||||
def stop(self, time_sec, request):
|
||||
_in_flight_requests.discard(self)
|
||||
with _in_flight_requests_lock:
|
||||
_in_flight_requests.discard(self)
|
||||
|
||||
context = LoggingContext.current_context()
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import threading
|
||||
|
||||
import six
|
||||
|
||||
from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily
|
||||
@@ -78,6 +80,9 @@ _background_process_counts = dict() # type: dict[str, int]
|
||||
# of process descriptions that no longer have any active processes.
|
||||
_background_processes = dict() # type: dict[str, set[_BackgroundProcess]]
|
||||
|
||||
# A lock that covers the above dicts
|
||||
_bg_metrics_lock = threading.Lock()
|
||||
|
||||
|
||||
class _Collector(object):
|
||||
"""A custom metrics collector for the background process metrics.
|
||||
@@ -92,7 +97,11 @@ class _Collector(object):
|
||||
labels=["name"],
|
||||
)
|
||||
|
||||
for desc, processes in six.iteritems(_background_processes):
|
||||
# We copy the dict so that it doesn't change from underneath us
|
||||
with _bg_metrics_lock:
|
||||
_background_processes_copy = dict(_background_processes)
|
||||
|
||||
for desc, processes in six.iteritems(_background_processes_copy):
|
||||
background_process_in_flight_count.add_metric(
|
||||
(desc,), len(processes),
|
||||
)
|
||||
@@ -167,19 +176,26 @@ def run_as_background_process(desc, func, *args, **kwargs):
|
||||
"""
|
||||
@defer.inlineCallbacks
|
||||
def run():
|
||||
count = _background_process_counts.get(desc, 0)
|
||||
_background_process_counts[desc] = count + 1
|
||||
with _bg_metrics_lock:
|
||||
count = _background_process_counts.get(desc, 0)
|
||||
_background_process_counts[desc] = count + 1
|
||||
|
||||
_background_process_start_count.labels(desc).inc()
|
||||
|
||||
with LoggingContext(desc) as context:
|
||||
context.request = "%s-%i" % (desc, count)
|
||||
proc = _BackgroundProcess(desc, context)
|
||||
_background_processes.setdefault(desc, set()).add(proc)
|
||||
|
||||
with _bg_metrics_lock:
|
||||
_background_processes.setdefault(desc, set()).add(proc)
|
||||
|
||||
try:
|
||||
yield func(*args, **kwargs)
|
||||
finally:
|
||||
proc.update_metrics()
|
||||
_background_processes[desc].remove(proc)
|
||||
|
||||
with _bg_metrics_lock:
|
||||
_background_processes[desc].remove(proc)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
return run()
|
||||
|
||||
@@ -39,7 +39,7 @@ REQUIREMENTS = {
|
||||
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
|
||||
"pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
|
||||
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
|
||||
"Twisted>=16.0.0": ["twisted>=16.0.0"],
|
||||
"Twisted>=17.1.0": ["twisted>=17.1.0"],
|
||||
|
||||
# We use crypto.get_elliptic_curve which is only supported in >=0.15
|
||||
"pyopenssl>=0.15": ["OpenSSL>=0.15"],
|
||||
|
||||
@@ -531,7 +531,7 @@ class RoomEventServlet(ClientV1RestServlet):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, event_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
event = yield self.event_handler.get_event(requester.user, room_id, event_id)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
@@ -51,7 +51,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
||||
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||
])
|
||||
|
||||
if not (yield check_3pid_allowed(self.hs, "email", body['email'])):
|
||||
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
@@ -87,7 +87,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
|
||||
|
||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||
|
||||
if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)):
|
||||
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
@@ -239,7 +239,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||
['id_server', 'client_secret', 'email', 'send_attempt'],
|
||||
)
|
||||
|
||||
if not (yield check_3pid_allowed(self.hs, "email", body['email'])):
|
||||
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
@@ -274,7 +274,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
||||
|
||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||
|
||||
if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)):
|
||||
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
@@ -313,9 +313,6 @@ class ThreepidRestServlet(RestServlet):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
if self.hs.config.disable_3pid_changes:
|
||||
raise SynapseError(400, "3PID changes disabled on this server")
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
threePidCreds = body.get('threePidCreds')
|
||||
@@ -362,15 +359,11 @@ class ThreepidDeleteRestServlet(RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ThreepidDeleteRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
if self.hs.config.disable_3pid_changes:
|
||||
raise SynapseError(400, "3PID changes disabled on this server")
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(body, ['medium', 'address'])
|
||||
|
||||
|
||||
@@ -16,9 +16,7 @@
|
||||
|
||||
import hmac
|
||||
import logging
|
||||
import re
|
||||
from hashlib import sha1
|
||||
from string import capwords
|
||||
|
||||
from six import string_types
|
||||
|
||||
@@ -74,7 +72,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
||||
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||
])
|
||||
|
||||
if not (yield check_3pid_allowed(self.hs, "email", body['email'])):
|
||||
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
@@ -114,7 +112,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
||||
|
||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||
|
||||
if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)):
|
||||
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
@@ -224,8 +222,6 @@ class RegisterRestServlet(RestServlet):
|
||||
raise SynapseError(400, "Invalid username")
|
||||
desired_username = body['username']
|
||||
|
||||
desired_display_name = None
|
||||
|
||||
appservice = None
|
||||
if self.auth.has_access_token(request):
|
||||
appservice = yield self.auth.get_appservice_by_req(request)
|
||||
@@ -301,6 +297,13 @@ class RegisterRestServlet(RestServlet):
|
||||
session_id, "registered_user_id", None
|
||||
)
|
||||
|
||||
if desired_username is not None:
|
||||
yield self.registration_handler.check_username(
|
||||
desired_username,
|
||||
guest_access_token=guest_access_token,
|
||||
assigned_user_id=registered_user_id,
|
||||
)
|
||||
|
||||
# Only give msisdn flows if the x_show_msisdn flag is given:
|
||||
# this is a hack to work around the fact that clients were shipped
|
||||
# that use fallback registration if they see any flows that they don't
|
||||
@@ -367,87 +370,12 @@ class RegisterRestServlet(RestServlet):
|
||||
medium = auth_result[login_type]['medium']
|
||||
address = auth_result[login_type]['address']
|
||||
|
||||
if not (yield check_3pid_allowed(self.hs, medium, address)):
|
||||
if not check_3pid_allowed(self.hs, medium, address):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed",
|
||||
Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
if self.hs.config.register_mxid_from_3pid:
|
||||
# override the desired_username based on the 3PID if any.
|
||||
# reset it first to avoid folks picking their own username.
|
||||
desired_username = None
|
||||
|
||||
# we should have an auth_result at this point if we're going to progress
|
||||
# to register the user (i.e. we haven't picked up a registered_user_id
|
||||
# from our session store), in which case get ready and gen the
|
||||
# desired_username
|
||||
if auth_result:
|
||||
if (
|
||||
self.hs.config.register_mxid_from_3pid == 'email' and
|
||||
LoginType.EMAIL_IDENTITY in auth_result
|
||||
):
|
||||
address = auth_result[LoginType.EMAIL_IDENTITY]['address']
|
||||
desired_username = synapse.types.strip_invalid_mxid_characters(
|
||||
address.replace('@', '-').lower()
|
||||
)
|
||||
|
||||
# find a unique mxid for the account, suffixing numbers
|
||||
# if needed
|
||||
while True:
|
||||
try:
|
||||
yield self.registration_handler.check_username(
|
||||
desired_username,
|
||||
guest_access_token=guest_access_token,
|
||||
assigned_user_id=registered_user_id,
|
||||
)
|
||||
# if we got this far we passed the check.
|
||||
break
|
||||
except SynapseError as e:
|
||||
if e.errcode == Codes.USER_IN_USE:
|
||||
m = re.match(r'^(.*?)(\d+)$', desired_username)
|
||||
if m:
|
||||
desired_username = m.group(1) + str(
|
||||
int(m.group(2)) + 1
|
||||
)
|
||||
else:
|
||||
desired_username += "1"
|
||||
else:
|
||||
# something else went wrong.
|
||||
break
|
||||
|
||||
# XXX: a nasty heuristic to turn an email address into
|
||||
# a displayname, as part of register_mxid_from_3pid
|
||||
parts = address.replace('.', ' ').split('@')
|
||||
org_parts = parts[1].split(' ')
|
||||
|
||||
if org_parts[-2] == "matrix" and org_parts[-1] == "org":
|
||||
org = "Tchap Admin"
|
||||
elif org_parts[-2] == "gouv" and org_parts[-1] == "fr":
|
||||
org = org_parts[-3] if len(org_parts) > 2 else org_parts[-2]
|
||||
else:
|
||||
org = org_parts[-2]
|
||||
|
||||
desired_display_name = (
|
||||
capwords(parts[0]) + " [" + capwords(org) + "]"
|
||||
)
|
||||
elif (
|
||||
self.hs.config.register_mxid_from_3pid == 'msisdn' and
|
||||
LoginType.MSISDN in auth_result
|
||||
):
|
||||
desired_username = auth_result[LoginType.MSISDN]['address']
|
||||
else:
|
||||
raise SynapseError(
|
||||
400, "Cannot derive mxid from 3pid; no recognised 3pid"
|
||||
)
|
||||
|
||||
if desired_username is not None:
|
||||
yield self.registration_handler.check_username(
|
||||
desired_username,
|
||||
guest_access_token=guest_access_token,
|
||||
assigned_user_id=registered_user_id,
|
||||
)
|
||||
|
||||
if registered_user_id is not None:
|
||||
logger.info(
|
||||
"Already registered user ID %r for this session",
|
||||
@@ -460,35 +388,20 @@ class RegisterRestServlet(RestServlet):
|
||||
# NB: This may be from the auth handler and NOT from the POST
|
||||
assert_params_in_dict(params, ["password"])
|
||||
|
||||
if not self.hs.config.register_mxid_from_3pid:
|
||||
desired_username = params.get("username", None)
|
||||
else:
|
||||
# we keep the original desired_username derived from the 3pid above
|
||||
pass
|
||||
|
||||
desired_username = params.get("username", None)
|
||||
guest_access_token = params.get("guest_access_token", None)
|
||||
|
||||
# XXX: don't we need to validate these for length etc like we did on
|
||||
# the ones from the JSON body earlier on in the method?
|
||||
new_password = params.get("password", None)
|
||||
|
||||
if desired_username is not None:
|
||||
desired_username = desired_username.lower()
|
||||
|
||||
(registered_user_id, _) = yield self.registration_handler.register(
|
||||
localpart=desired_username,
|
||||
password=params.get("password", None),
|
||||
password=new_password,
|
||||
guest_access_token=guest_access_token,
|
||||
generate_token=False,
|
||||
display_name=desired_display_name,
|
||||
)
|
||||
|
||||
if self.hs.config.chain_register:
|
||||
yield self.registration_handler.chain_register(
|
||||
localpart=desired_username,
|
||||
auth_result=auth_result,
|
||||
params=params,
|
||||
)
|
||||
|
||||
# remember that we've now registered that user account, and with
|
||||
# what user ID (since the user may not have specified)
|
||||
self.auth_handler.set_session_data(
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
|
||||
import logging
|
||||
|
||||
from signedjson.sign import sign_json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
@@ -39,7 +37,6 @@ class UserDirectorySearchRestServlet(RestServlet):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.user_directory_handler = hs.get_user_directory_handler()
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
@@ -64,14 +61,6 @@ class UserDirectorySearchRestServlet(RestServlet):
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
if self.hs.config.user_directory_defer_to_id_server:
|
||||
signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0])
|
||||
url = "http://%s/_matrix/identity/api/v1/user_directory/search" % (
|
||||
self.hs.config.user_directory_defer_to_id_server,
|
||||
)
|
||||
resp = yield self.http_client.post_json_get_json(url, signed_body)
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
limit = body.get("limit", 10)
|
||||
limit = min(limit, 50)
|
||||
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from synapse.config._base import ConfigError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DomainRuleChecker(object):
|
||||
"""
|
||||
A re-implementation of the SpamChecker that prevents users in one domain from
|
||||
inviting users in other domains to rooms, based on a configuration.
|
||||
|
||||
Takes a config in the format:
|
||||
|
||||
spam_checker:
|
||||
module: "rulecheck.DomainRuleChecker"
|
||||
config:
|
||||
domain_mapping:
|
||||
"inviter_domain": [ "invitee_domain_permitted", "other_domain_permitted" ]
|
||||
"other_inviter_domain": [ "invitee_domain_permitted" ]
|
||||
default: False
|
||||
}
|
||||
|
||||
Don't forget to consider if you can invite users from your own domain.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.domain_mapping = config["domain_mapping"] or {}
|
||||
self.default = config["default"]
|
||||
|
||||
def check_event_for_spam(self, event):
|
||||
"""Implements synapse.events.SpamChecker.check_event_for_spam
|
||||
"""
|
||||
return False
|
||||
|
||||
def user_may_invite(self, inviter_userid, invitee_userid, room_id):
|
||||
"""Implements synapse.events.SpamChecker.user_may_invite
|
||||
"""
|
||||
inviter_domain = self._get_domain_from_id(inviter_userid)
|
||||
invitee_domain = self._get_domain_from_id(invitee_userid)
|
||||
|
||||
if inviter_domain not in self.domain_mapping:
|
||||
return self.default
|
||||
|
||||
return invitee_domain in self.domain_mapping[inviter_domain]
|
||||
|
||||
def user_may_create_room(self, userid):
|
||||
"""Implements synapse.events.SpamChecker.user_may_create_room
|
||||
"""
|
||||
return True
|
||||
|
||||
def user_may_create_room_alias(self, userid, room_alias):
|
||||
"""Implements synapse.events.SpamChecker.user_may_create_room_alias
|
||||
"""
|
||||
return True
|
||||
|
||||
def user_may_publish_room(self, userid, room_id):
|
||||
"""Implements synapse.events.SpamChecker.user_may_publish_room
|
||||
"""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def parse_config(config):
|
||||
"""Implements synapse.events.SpamChecker.parse_config
|
||||
"""
|
||||
if "default" in config:
|
||||
return config
|
||||
else:
|
||||
raise ConfigError("No default set for spam_config DomainRuleChecker")
|
||||
|
||||
@staticmethod
|
||||
def _get_domain_from_id(mxid):
|
||||
"""Parses a string and returns the domain part of the mxid.
|
||||
|
||||
Args:
|
||||
mxid (str): a valid mxid
|
||||
|
||||
Returns:
|
||||
str: the domain part of the mxid
|
||||
|
||||
"""
|
||||
idx = mxid.find(":")
|
||||
if idx == -1:
|
||||
raise Exception("Invalid ID: %r" % (mxid,))
|
||||
return mxid[idx + 1:]
|
||||
@@ -56,7 +56,7 @@ from synapse.handlers.initial_sync import InitialSyncHandler
|
||||
from synapse.handlers.message import EventCreationHandler, MessageHandler
|
||||
from synapse.handlers.pagination import PaginationHandler
|
||||
from synapse.handlers.presence import PresenceHandler
|
||||
from synapse.handlers.profile import ProfileHandler
|
||||
from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler
|
||||
from synapse.handlers.read_marker import ReadMarkerHandler
|
||||
from synapse.handlers.receipts import ReceiptsHandler
|
||||
from synapse.handlers.room import RoomContextHandler, RoomCreationHandler
|
||||
@@ -308,7 +308,10 @@ class HomeServer(object):
|
||||
return InitialSyncHandler(self)
|
||||
|
||||
def build_profile_handler(self):
|
||||
return ProfileHandler(self)
|
||||
if self.config.worker_app:
|
||||
return BaseProfileHandler(self)
|
||||
else:
|
||||
return MasterProfileHandler(self)
|
||||
|
||||
def build_event_creation_handler(self):
|
||||
return EventCreationHandler(self)
|
||||
|
||||
204
synapse/server_notices/resource_limits_server_notices.py
Normal file
204
synapse/server_notices/resource_limits_server_notices.py
Normal file
@@ -0,0 +1,204 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from six import iteritems
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import (
|
||||
EventTypes,
|
||||
ServerNoticeLimitReached,
|
||||
ServerNoticeMsgType,
|
||||
)
|
||||
from synapse.api.errors import AuthError, ResourceLimitError, SynapseError
|
||||
from synapse.server_notices.server_notices_manager import SERVER_NOTICE_ROOM_TAG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResourceLimitsServerNotices(object):
|
||||
""" Keeps track of whether the server has reached it's resource limit and
|
||||
ensures that the client is kept up to date.
|
||||
"""
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
self._server_notices_manager = hs.get_server_notices_manager()
|
||||
self._store = hs.get_datastore()
|
||||
self._auth = hs.get_auth()
|
||||
self._config = hs.config
|
||||
self._resouce_limited = False
|
||||
self._message_handler = hs.get_message_handler()
|
||||
self._state = hs.get_state_handler()
|
||||
|
||||
self._notifier = hs.get_notifier()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def maybe_send_server_notice_to_user(self, user_id):
|
||||
"""Check if we need to send a notice to this user, this will be true in
|
||||
two cases.
|
||||
1. The server has reached its limit does not reflect this
|
||||
2. The room state indicates that the server has reached its limit when
|
||||
actually the server is fine
|
||||
|
||||
Args:
|
||||
user_id (str): user to check
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
if self._config.hs_disabled is True:
|
||||
return
|
||||
|
||||
if self._config.limit_usage_by_mau is False:
|
||||
return
|
||||
|
||||
if not self._server_notices_manager.is_enabled():
|
||||
# Don't try and send server notices unles they've been enabled
|
||||
return
|
||||
|
||||
timestamp = yield self._store.user_last_seen_monthly_active(user_id)
|
||||
if timestamp is None:
|
||||
# This user will be blocked from receiving the notice anyway.
|
||||
# In practice, not sure we can ever get here
|
||||
return
|
||||
|
||||
# Determine current state of room
|
||||
|
||||
room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id)
|
||||
|
||||
if not room_id:
|
||||
logger.warn("Failed to get server notices room")
|
||||
return
|
||||
|
||||
yield self._check_and_set_tags(user_id, room_id)
|
||||
currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id)
|
||||
|
||||
try:
|
||||
# Normally should always pass in user_id if you have it, but in
|
||||
# this case are checking what would happen to other users if they
|
||||
# were to arrive.
|
||||
try:
|
||||
yield self._auth.check_auth_blocking()
|
||||
is_auth_blocking = False
|
||||
except ResourceLimitError as e:
|
||||
is_auth_blocking = True
|
||||
event_content = e.msg
|
||||
event_limit_type = e.limit_type
|
||||
|
||||
if currently_blocked and not is_auth_blocking:
|
||||
# Room is notifying of a block, when it ought not to be.
|
||||
# Remove block notification
|
||||
content = {
|
||||
"pinned": ref_events
|
||||
}
|
||||
yield self._server_notices_manager.send_notice(
|
||||
user_id, content, EventTypes.Pinned, '',
|
||||
)
|
||||
|
||||
elif not currently_blocked and is_auth_blocking:
|
||||
# Room is not notifying of a block, when it ought to be.
|
||||
# Add block notification
|
||||
content = {
|
||||
'body': event_content,
|
||||
'msgtype': ServerNoticeMsgType,
|
||||
'server_notice_type': ServerNoticeLimitReached,
|
||||
'admin_uri': self._config.admin_uri,
|
||||
'limit_type': event_limit_type
|
||||
}
|
||||
event = yield self._server_notices_manager.send_notice(
|
||||
user_id, content, EventTypes.Message,
|
||||
)
|
||||
|
||||
content = {
|
||||
"pinned": [
|
||||
event.event_id,
|
||||
]
|
||||
}
|
||||
yield self._server_notices_manager.send_notice(
|
||||
user_id, content, EventTypes.Pinned, '',
|
||||
)
|
||||
|
||||
except SynapseError as e:
|
||||
logger.error("Error sending resource limits server notice: %s", e)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_and_set_tags(self, user_id, room_id):
|
||||
"""
|
||||
Since server notices rooms were originally not with tags,
|
||||
important to check that tags have been set correctly
|
||||
Args:
|
||||
user_id(str): the user in question
|
||||
room_id(str): the server notices room for that user
|
||||
"""
|
||||
tags = yield self._store.get_tags_for_user(user_id)
|
||||
server_notices_tags = tags.get(room_id)
|
||||
need_to_set_tag = True
|
||||
if server_notices_tags:
|
||||
if server_notices_tags.get(SERVER_NOTICE_ROOM_TAG):
|
||||
# tag already present, nothing to do here
|
||||
need_to_set_tag = False
|
||||
if need_to_set_tag:
|
||||
max_id = yield self._store.add_tag_to_room(
|
||||
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
|
||||
)
|
||||
self._notifier.on_new_event(
|
||||
"account_data_key", max_id, users=[user_id]
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _is_room_currently_blocked(self, room_id):
|
||||
"""
|
||||
Determines if the room is currently blocked
|
||||
|
||||
Args:
|
||||
room_id(str): The room id of the server notices room
|
||||
|
||||
Returns:
|
||||
|
||||
bool: Is the room currently blocked
|
||||
list: The list of pinned events that are unrelated to limit blocking
|
||||
This list can be used as a convenience in the case where the block
|
||||
is to be lifted and the remaining pinned event references need to be
|
||||
preserved
|
||||
"""
|
||||
currently_blocked = False
|
||||
pinned_state_event = None
|
||||
try:
|
||||
pinned_state_event = yield self._state.get_current_state(
|
||||
room_id, event_type=EventTypes.Pinned
|
||||
)
|
||||
except AuthError:
|
||||
# The user has yet to join the server notices room
|
||||
pass
|
||||
|
||||
referenced_events = []
|
||||
if pinned_state_event is not None:
|
||||
referenced_events = list(pinned_state_event.content.get('pinned', []))
|
||||
|
||||
events = yield self._store.get_events(referenced_events)
|
||||
for event_id, event in iteritems(events):
|
||||
if event.type != EventTypes.Message:
|
||||
continue
|
||||
if event.content.get("msgtype") == ServerNoticeMsgType:
|
||||
currently_blocked = True
|
||||
# remove event in case we need to disable blocking later on.
|
||||
if event_id in referenced_events:
|
||||
referenced_events.remove(event.event_id)
|
||||
|
||||
defer.returnValue((currently_blocked, referenced_events))
|
||||
@@ -22,6 +22,8 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SERVER_NOTICE_ROOM_TAG = "m.server_notice"
|
||||
|
||||
|
||||
class ServerNoticesManager(object):
|
||||
def __init__(self, hs):
|
||||
@@ -37,6 +39,8 @@ class ServerNoticesManager(object):
|
||||
self._event_creation_handler = hs.get_event_creation_handler()
|
||||
self._is_mine_id = hs.is_mine_id
|
||||
|
||||
self._notifier = hs.get_notifier()
|
||||
|
||||
def is_enabled(self):
|
||||
"""Checks if server notices are enabled on this server.
|
||||
|
||||
@@ -46,7 +50,10 @@ class ServerNoticesManager(object):
|
||||
return self._config.server_notices_mxid is not None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_notice(self, user_id, event_content):
|
||||
def send_notice(
|
||||
self, user_id, event_content,
|
||||
type=EventTypes.Message, state_key=None
|
||||
):
|
||||
"""Send a notice to the given user
|
||||
|
||||
Creates the server notices room, if none exists.
|
||||
@@ -54,9 +61,11 @@ class ServerNoticesManager(object):
|
||||
Args:
|
||||
user_id (str): mxid of user to send event to.
|
||||
event_content (dict): content of event to send
|
||||
type(EventTypes): type of event
|
||||
is_state_event(bool): Is the event a state event
|
||||
|
||||
Returns:
|
||||
Deferred[None]
|
||||
Deferred[FrozenEvent]
|
||||
"""
|
||||
room_id = yield self.get_notice_room_for_user(user_id)
|
||||
|
||||
@@ -65,15 +74,20 @@ class ServerNoticesManager(object):
|
||||
|
||||
logger.info("Sending server notice to %s", user_id)
|
||||
|
||||
yield self._event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, {
|
||||
"type": EventTypes.Message,
|
||||
"room_id": room_id,
|
||||
"sender": system_mxid,
|
||||
"content": event_content,
|
||||
},
|
||||
ratelimit=False,
|
||||
event_dict = {
|
||||
"type": type,
|
||||
"room_id": room_id,
|
||||
"sender": system_mxid,
|
||||
"content": event_content,
|
||||
}
|
||||
|
||||
if state_key is not None:
|
||||
event_dict['state_key'] = state_key
|
||||
|
||||
res = yield self._event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict, ratelimit=False,
|
||||
)
|
||||
defer.returnValue(res)
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def get_notice_room_for_user(self, user_id):
|
||||
@@ -142,5 +156,12 @@ class ServerNoticesManager(object):
|
||||
)
|
||||
room_id = info['room_id']
|
||||
|
||||
max_id = yield self._store.add_tag_to_room(
|
||||
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {},
|
||||
)
|
||||
self._notifier.on_new_event(
|
||||
"account_data_key", max_id, users=[user_id]
|
||||
)
|
||||
|
||||
logger.info("Created server notices room %s for %s", room_id, user_id)
|
||||
defer.returnValue(room_id)
|
||||
|
||||
@@ -12,7 +12,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.server_notices.consent_server_notices import ConsentServerNotices
|
||||
from synapse.server_notices.resource_limits_server_notices import (
|
||||
ResourceLimitsServerNotices,
|
||||
)
|
||||
|
||||
|
||||
class ServerNoticesSender(object):
|
||||
@@ -25,34 +30,34 @@ class ServerNoticesSender(object):
|
||||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
# todo: it would be nice to make this more dynamic
|
||||
self._consent_server_notices = ConsentServerNotices(hs)
|
||||
self._server_notices = (
|
||||
ConsentServerNotices(hs),
|
||||
ResourceLimitsServerNotices(hs)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_user_syncing(self, user_id):
|
||||
"""Called when the user performs a sync operation.
|
||||
|
||||
Args:
|
||||
user_id (str): mxid of user who synced
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
return self._consent_server_notices.maybe_send_server_notice_to_user(
|
||||
user_id,
|
||||
)
|
||||
for sn in self._server_notices:
|
||||
yield sn.maybe_send_server_notice_to_user(
|
||||
user_id,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_user_ip(self, user_id):
|
||||
"""Called on the master when a worker process saw a client request.
|
||||
|
||||
Args:
|
||||
user_id (str): mxid
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
# The synchrotrons use a stubbed version of ServerNoticesSender, so
|
||||
# we check for notices to send to the user in on_user_ip as well as
|
||||
# in on_user_syncing
|
||||
return self._consent_server_notices.maybe_send_server_notice_to_user(
|
||||
user_id,
|
||||
)
|
||||
for sn in self._server_notices:
|
||||
yield sn.maybe_send_server_notice_to_user(
|
||||
user_id,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -13,21 +14,18 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
from six import iteritems, iterkeys, itervalues
|
||||
from six import iteritems, itervalues
|
||||
|
||||
from frozendict import frozendict
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.api.constants import EventTypes, RoomVersions
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.state import v1
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches import get_cache_factor_for
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
@@ -264,6 +262,7 @@ class StateHandler(object):
|
||||
defer.returnValue(context)
|
||||
|
||||
logger.debug("calling resolve_state_groups from compute_event_context")
|
||||
|
||||
entry = yield self.resolve_state_groups_for_events(
|
||||
event.room_id, [e for e, _ in event.prev_events],
|
||||
)
|
||||
@@ -338,8 +337,11 @@ class StateHandler(object):
|
||||
event, resolves conflicts between them and returns them.
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
event_ids (list[str]):
|
||||
room_id (str)
|
||||
event_ids (list[str])
|
||||
explicit_room_version (str|None): If set uses the the given room
|
||||
version to choose the resolution algorithm. If None, then
|
||||
checks the database for room version.
|
||||
|
||||
Returns:
|
||||
Deferred[_StateCacheEntry]: resolved state
|
||||
@@ -353,7 +355,12 @@ class StateHandler(object):
|
||||
room_id, event_ids
|
||||
)
|
||||
|
||||
if len(state_groups_ids) == 1:
|
||||
if len(state_groups_ids) == 0:
|
||||
defer.returnValue(_StateCacheEntry(
|
||||
state={},
|
||||
state_group=None,
|
||||
))
|
||||
elif len(state_groups_ids) == 1:
|
||||
name, state_list = list(state_groups_ids.items()).pop()
|
||||
|
||||
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
|
||||
@@ -365,8 +372,11 @@ class StateHandler(object):
|
||||
delta_ids=delta_ids,
|
||||
))
|
||||
|
||||
room_version = yield self.store.get_room_version(room_id)
|
||||
|
||||
result = yield self._state_resolution_handler.resolve_state_groups(
|
||||
room_id, state_groups_ids, None, self._state_map_factory,
|
||||
room_id, room_version, state_groups_ids, None,
|
||||
self._state_map_factory,
|
||||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
@@ -375,7 +385,7 @@ class StateHandler(object):
|
||||
ev_ids, get_prev_content=False, check_redacted=False,
|
||||
)
|
||||
|
||||
def resolve_events(self, state_sets, event):
|
||||
def resolve_events(self, room_version, state_sets, event):
|
||||
logger.info(
|
||||
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
|
||||
)
|
||||
@@ -391,7 +401,9 @@ class StateHandler(object):
|
||||
}
|
||||
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = resolve_events_with_state_map(state_set_ids, state_map)
|
||||
new_state = resolve_events_with_state_map(
|
||||
room_version, state_set_ids, state_map,
|
||||
)
|
||||
|
||||
new_state = {
|
||||
key: state_map[ev_id] for key, ev_id in iteritems(new_state)
|
||||
@@ -430,7 +442,7 @@ class StateResolutionHandler(object):
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def resolve_state_groups(
|
||||
self, room_id, state_groups_ids, event_map, state_map_factory,
|
||||
self, room_id, room_version, state_groups_ids, event_map, state_map_factory,
|
||||
):
|
||||
"""Resolves conflicts between a set of state groups
|
||||
|
||||
@@ -439,6 +451,7 @@ class StateResolutionHandler(object):
|
||||
|
||||
Args:
|
||||
room_id (str): room we are resolving for (used for logging)
|
||||
room_version (str): version of the room
|
||||
state_groups_ids (dict[int, dict[(str, str), str]]):
|
||||
map from state group id to the state in that state group
|
||||
(where 'state' is a map from state key to event id)
|
||||
@@ -492,6 +505,7 @@ class StateResolutionHandler(object):
|
||||
logger.info("Resolving conflicted state for %r", room_id)
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = yield resolve_events_with_factory(
|
||||
room_version,
|
||||
list(itervalues(state_groups_ids)),
|
||||
event_map=event_map,
|
||||
state_map_factory=state_map_factory,
|
||||
@@ -575,16 +589,10 @@ def _make_state_cache_entry(
|
||||
)
|
||||
|
||||
|
||||
def _ordered_events(events):
|
||||
def key_func(e):
|
||||
return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
|
||||
|
||||
return sorted(events, key=key_func)
|
||||
|
||||
|
||||
def resolve_events_with_state_map(state_sets, state_map):
|
||||
def resolve_events_with_state_map(room_version, state_sets, state_map):
|
||||
"""
|
||||
Args:
|
||||
room_version(str): Version of the room
|
||||
state_sets(list): List of dicts of (type, state_key) -> event_id,
|
||||
which are the different state groups to resolve.
|
||||
state_map(dict): a dict from event_id to event, for all events in
|
||||
@@ -594,75 +602,23 @@ def resolve_events_with_state_map(state_sets, state_map):
|
||||
dict[(str, str), str]:
|
||||
a map from (type, state_key) to event_id.
|
||||
"""
|
||||
if len(state_sets) == 1:
|
||||
return state_sets[0]
|
||||
|
||||
unconflicted_state, conflicted_state = _seperate(
|
||||
state_sets,
|
||||
)
|
||||
|
||||
auth_events = _create_auth_events_from_maps(
|
||||
unconflicted_state, conflicted_state, state_map
|
||||
)
|
||||
|
||||
return _resolve_with_state(
|
||||
unconflicted_state, conflicted_state, auth_events, state_map
|
||||
)
|
||||
if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
|
||||
return v1.resolve_events_with_state_map(
|
||||
state_sets, state_map,
|
||||
)
|
||||
else:
|
||||
# This should only happen if we added a version but forgot to add it to
|
||||
# the list above.
|
||||
raise Exception(
|
||||
"No state resolution algorithm defined for version %r" % (room_version,)
|
||||
)
|
||||
|
||||
|
||||
def _seperate(state_sets):
|
||||
"""Takes the state_sets and figures out which keys are conflicted and
|
||||
which aren't. i.e., which have multiple different event_ids associated
|
||||
with them in different state sets.
|
||||
|
||||
Args:
|
||||
state_sets(iterable[dict[(str, str), str]]):
|
||||
List of dicts of (type, state_key) -> event_id, which are the
|
||||
different state groups to resolve.
|
||||
|
||||
Returns:
|
||||
(dict[(str, str), str], dict[(str, str), set[str]]):
|
||||
A tuple of (unconflicted_state, conflicted_state), where:
|
||||
|
||||
unconflicted_state is a dict mapping (type, state_key)->event_id
|
||||
for unconflicted state keys.
|
||||
|
||||
conflicted_state is a dict mapping (type, state_key) to a set of
|
||||
event ids for conflicted state keys.
|
||||
"""
|
||||
state_set_iterator = iter(state_sets)
|
||||
unconflicted_state = dict(next(state_set_iterator))
|
||||
conflicted_state = {}
|
||||
|
||||
for state_set in state_set_iterator:
|
||||
for key, value in iteritems(state_set):
|
||||
# Check if there is an unconflicted entry for the state key.
|
||||
unconflicted_value = unconflicted_state.get(key)
|
||||
if unconflicted_value is None:
|
||||
# There isn't an unconflicted entry so check if there is a
|
||||
# conflicted entry.
|
||||
ls = conflicted_state.get(key)
|
||||
if ls is None:
|
||||
# There wasn't a conflicted entry so haven't seen this key before.
|
||||
# Therefore it isn't conflicted yet.
|
||||
unconflicted_state[key] = value
|
||||
else:
|
||||
# This key is already conflicted, add our value to the conflict set.
|
||||
ls.add(value)
|
||||
elif unconflicted_value != value:
|
||||
# If the unconflicted value is not the same as our value then we
|
||||
# have a new conflict. So move the key from the unconflicted_state
|
||||
# to the conflicted state.
|
||||
conflicted_state[key] = {value, unconflicted_value}
|
||||
unconflicted_state.pop(key, None)
|
||||
|
||||
return unconflicted_state, conflicted_state
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_events_with_factory(state_sets, event_map, state_map_factory):
|
||||
def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory):
|
||||
"""
|
||||
Args:
|
||||
room_version(str): Version of the room
|
||||
|
||||
state_sets(list): List of dicts of (type, state_key) -> event_id,
|
||||
which are the different state groups to resolve.
|
||||
|
||||
@@ -682,185 +638,13 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
|
||||
Deferred[dict[(str, str), str]]:
|
||||
a map from (type, state_key) to event_id.
|
||||
"""
|
||||
if len(state_sets) == 1:
|
||||
defer.returnValue(state_sets[0])
|
||||
|
||||
unconflicted_state, conflicted_state = _seperate(
|
||||
state_sets,
|
||||
)
|
||||
|
||||
needed_events = set(
|
||||
event_id
|
||||
for event_ids in itervalues(conflicted_state)
|
||||
for event_id in event_ids
|
||||
)
|
||||
if event_map is not None:
|
||||
needed_events -= set(iterkeys(event_map))
|
||||
|
||||
logger.info("Asking for %d conflicted events", len(needed_events))
|
||||
|
||||
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
|
||||
# the state events which are in conflict (and those in event_map)
|
||||
state_map = yield state_map_factory(needed_events)
|
||||
if event_map is not None:
|
||||
state_map.update(event_map)
|
||||
|
||||
# get the ids of the auth events which allow us to authenticate the
|
||||
# conflicted state, picking only from the unconflicting state.
|
||||
#
|
||||
# dict[(str, str), str]: a map from state key to event id
|
||||
auth_events = _create_auth_events_from_maps(
|
||||
unconflicted_state, conflicted_state, state_map
|
||||
)
|
||||
|
||||
new_needed_events = set(itervalues(auth_events))
|
||||
new_needed_events -= needed_events
|
||||
if event_map is not None:
|
||||
new_needed_events -= set(iterkeys(event_map))
|
||||
|
||||
logger.info("Asking for %d auth events", len(new_needed_events))
|
||||
|
||||
state_map_new = yield state_map_factory(new_needed_events)
|
||||
state_map.update(state_map_new)
|
||||
|
||||
defer.returnValue(_resolve_with_state(
|
||||
unconflicted_state, conflicted_state, auth_events, state_map
|
||||
))
|
||||
|
||||
|
||||
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
|
||||
auth_events = {}
|
||||
for event_ids in itervalues(conflicted_state):
|
||||
for event_id in event_ids:
|
||||
if event_id in state_map:
|
||||
keys = event_auth.auth_types_for_event(state_map[event_id])
|
||||
for key in keys:
|
||||
if key not in auth_events:
|
||||
event_id = unconflicted_state.get(key, None)
|
||||
if event_id:
|
||||
auth_events[key] = event_id
|
||||
return auth_events
|
||||
|
||||
|
||||
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids,
|
||||
state_map):
|
||||
conflicted_state = {}
|
||||
for key, event_ids in iteritems(conflicted_state_ids):
|
||||
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
|
||||
if len(events) > 1:
|
||||
conflicted_state[key] = events
|
||||
elif len(events) == 1:
|
||||
unconflicted_state_ids[key] = events[0].event_id
|
||||
|
||||
auth_events = {
|
||||
key: state_map[ev_id]
|
||||
for key, ev_id in iteritems(auth_event_ids)
|
||||
if ev_id in state_map
|
||||
}
|
||||
|
||||
try:
|
||||
resolved_state = _resolve_state_events(
|
||||
conflicted_state, auth_events
|
||||
if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
|
||||
return v1.resolve_events_with_factory(
|
||||
state_sets, event_map, state_map_factory,
|
||||
)
|
||||
else:
|
||||
# This should only happen if we added a version but forgot to add it to
|
||||
# the list above.
|
||||
raise Exception(
|
||||
"No state resolution algorithm defined for version %r" % (room_version,)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to resolve state")
|
||||
raise
|
||||
|
||||
new_state = unconflicted_state_ids
|
||||
for key, event in iteritems(resolved_state):
|
||||
new_state[key] = event.event_id
|
||||
|
||||
return new_state
|
||||
|
||||
|
||||
def _resolve_state_events(conflicted_state, auth_events):
|
||||
""" This is where we actually decide which of the conflicted state to
|
||||
use.
|
||||
|
||||
We resolve conflicts in the following order:
|
||||
1. power levels
|
||||
2. join rules
|
||||
3. memberships
|
||||
4. other events.
|
||||
"""
|
||||
resolved_state = {}
|
||||
if POWER_KEY in conflicted_state:
|
||||
events = conflicted_state[POWER_KEY]
|
||||
logger.debug("Resolving conflicted power levels %r", events)
|
||||
resolved_state[POWER_KEY] = _resolve_auth_events(
|
||||
events, auth_events)
|
||||
|
||||
auth_events.update(resolved_state)
|
||||
|
||||
for key, events in iteritems(conflicted_state):
|
||||
if key[0] == EventTypes.JoinRules:
|
||||
logger.debug("Resolving conflicted join rules %r", events)
|
||||
resolved_state[key] = _resolve_auth_events(
|
||||
events,
|
||||
auth_events
|
||||
)
|
||||
|
||||
auth_events.update(resolved_state)
|
||||
|
||||
for key, events in iteritems(conflicted_state):
|
||||
if key[0] == EventTypes.Member:
|
||||
logger.debug("Resolving conflicted member lists %r", events)
|
||||
resolved_state[key] = _resolve_auth_events(
|
||||
events,
|
||||
auth_events
|
||||
)
|
||||
|
||||
auth_events.update(resolved_state)
|
||||
|
||||
for key, events in iteritems(conflicted_state):
|
||||
if key not in resolved_state:
|
||||
logger.debug("Resolving conflicted state %r:%r", key, events)
|
||||
resolved_state[key] = _resolve_normal_events(
|
||||
events, auth_events
|
||||
)
|
||||
|
||||
return resolved_state
|
||||
|
||||
|
||||
def _resolve_auth_events(events, auth_events):
|
||||
reverse = [i for i in reversed(_ordered_events(events))]
|
||||
|
||||
auth_keys = set(
|
||||
key
|
||||
for event in events
|
||||
for key in event_auth.auth_types_for_event(event)
|
||||
)
|
||||
|
||||
new_auth_events = {}
|
||||
for key in auth_keys:
|
||||
auth_event = auth_events.get(key, None)
|
||||
if auth_event:
|
||||
new_auth_events[key] = auth_event
|
||||
|
||||
auth_events = new_auth_events
|
||||
|
||||
prev_event = reverse[0]
|
||||
for event in reverse[1:]:
|
||||
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
|
||||
try:
|
||||
# The signatures have already been checked at this point
|
||||
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
|
||||
prev_event = event
|
||||
except AuthError:
|
||||
return prev_event
|
||||
|
||||
return event
|
||||
|
||||
|
||||
def _resolve_normal_events(events, auth_events):
|
||||
for event in _ordered_events(events):
|
||||
try:
|
||||
# The signatures have already been checked at this point
|
||||
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
|
||||
return event
|
||||
except AuthError:
|
||||
pass
|
||||
|
||||
# Use the last event (the one with the least depth) if they all fail
|
||||
# the auth check.
|
||||
return event
|
||||
321
synapse/state/v1.py
Normal file
321
synapse/state/v1.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from six import iteritems, iterkeys, itervalues
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import AuthError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
POWER_KEY = (EventTypes.PowerLevels, "")
|
||||
|
||||
|
||||
def resolve_events_with_state_map(state_sets, state_map):
|
||||
"""
|
||||
Args:
|
||||
state_sets(list): List of dicts of (type, state_key) -> event_id,
|
||||
which are the different state groups to resolve.
|
||||
state_map(dict): a dict from event_id to event, for all events in
|
||||
state_sets.
|
||||
|
||||
Returns
|
||||
dict[(str, str), str]:
|
||||
a map from (type, state_key) to event_id.
|
||||
"""
|
||||
if len(state_sets) == 1:
|
||||
return state_sets[0]
|
||||
|
||||
unconflicted_state, conflicted_state = _seperate(
|
||||
state_sets,
|
||||
)
|
||||
|
||||
auth_events = _create_auth_events_from_maps(
|
||||
unconflicted_state, conflicted_state, state_map
|
||||
)
|
||||
|
||||
return _resolve_with_state(
|
||||
unconflicted_state, conflicted_state, auth_events, state_map
|
||||
)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_events_with_factory(state_sets, event_map, state_map_factory):
|
||||
"""
|
||||
Args:
|
||||
state_sets(list): List of dicts of (type, state_key) -> event_id,
|
||||
which are the different state groups to resolve.
|
||||
|
||||
event_map(dict[str,FrozenEvent]|None):
|
||||
a dict from event_id to event, for any events that we happen to
|
||||
have in flight (eg, those currently being persisted). This will be
|
||||
used as a starting point fof finding the state we need; any missing
|
||||
events will be requested via state_map_factory.
|
||||
|
||||
If None, all events will be fetched via state_map_factory.
|
||||
|
||||
state_map_factory(func): will be called
|
||||
with a list of event_ids that are needed, and should return with
|
||||
a Deferred of dict of event_id to event.
|
||||
|
||||
Returns
|
||||
Deferred[dict[(str, str), str]]:
|
||||
a map from (type, state_key) to event_id.
|
||||
"""
|
||||
if len(state_sets) == 1:
|
||||
defer.returnValue(state_sets[0])
|
||||
|
||||
unconflicted_state, conflicted_state = _seperate(
|
||||
state_sets,
|
||||
)
|
||||
|
||||
needed_events = set(
|
||||
event_id
|
||||
for event_ids in itervalues(conflicted_state)
|
||||
for event_id in event_ids
|
||||
)
|
||||
if event_map is not None:
|
||||
needed_events -= set(iterkeys(event_map))
|
||||
|
||||
logger.info("Asking for %d conflicted events", len(needed_events))
|
||||
|
||||
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
|
||||
# the state events which are in conflict (and those in event_map)
|
||||
state_map = yield state_map_factory(needed_events)
|
||||
if event_map is not None:
|
||||
state_map.update(event_map)
|
||||
|
||||
# get the ids of the auth events which allow us to authenticate the
|
||||
# conflicted state, picking only from the unconflicting state.
|
||||
#
|
||||
# dict[(str, str), str]: a map from state key to event id
|
||||
auth_events = _create_auth_events_from_maps(
|
||||
unconflicted_state, conflicted_state, state_map
|
||||
)
|
||||
|
||||
new_needed_events = set(itervalues(auth_events))
|
||||
new_needed_events -= needed_events
|
||||
if event_map is not None:
|
||||
new_needed_events -= set(iterkeys(event_map))
|
||||
|
||||
logger.info("Asking for %d auth events", len(new_needed_events))
|
||||
|
||||
state_map_new = yield state_map_factory(new_needed_events)
|
||||
state_map.update(state_map_new)
|
||||
|
||||
defer.returnValue(_resolve_with_state(
|
||||
unconflicted_state, conflicted_state, auth_events, state_map
|
||||
))
|
||||
|
||||
|
||||
def _seperate(state_sets):
|
||||
"""Takes the state_sets and figures out which keys are conflicted and
|
||||
which aren't. i.e., which have multiple different event_ids associated
|
||||
with them in different state sets.
|
||||
|
||||
Args:
|
||||
state_sets(iterable[dict[(str, str), str]]):
|
||||
List of dicts of (type, state_key) -> event_id, which are the
|
||||
different state groups to resolve.
|
||||
|
||||
Returns:
|
||||
(dict[(str, str), str], dict[(str, str), set[str]]):
|
||||
A tuple of (unconflicted_state, conflicted_state), where:
|
||||
|
||||
unconflicted_state is a dict mapping (type, state_key)->event_id
|
||||
for unconflicted state keys.
|
||||
|
||||
conflicted_state is a dict mapping (type, state_key) to a set of
|
||||
event ids for conflicted state keys.
|
||||
"""
|
||||
state_set_iterator = iter(state_sets)
|
||||
unconflicted_state = dict(next(state_set_iterator))
|
||||
conflicted_state = {}
|
||||
|
||||
for state_set in state_set_iterator:
|
||||
for key, value in iteritems(state_set):
|
||||
# Check if there is an unconflicted entry for the state key.
|
||||
unconflicted_value = unconflicted_state.get(key)
|
||||
if unconflicted_value is None:
|
||||
# There isn't an unconflicted entry so check if there is a
|
||||
# conflicted entry.
|
||||
ls = conflicted_state.get(key)
|
||||
if ls is None:
|
||||
# There wasn't a conflicted entry so haven't seen this key before.
|
||||
# Therefore it isn't conflicted yet.
|
||||
unconflicted_state[key] = value
|
||||
else:
|
||||
# This key is already conflicted, add our value to the conflict set.
|
||||
ls.add(value)
|
||||
elif unconflicted_value != value:
|
||||
# If the unconflicted value is not the same as our value then we
|
||||
# have a new conflict. So move the key from the unconflicted_state
|
||||
# to the conflicted state.
|
||||
conflicted_state[key] = {value, unconflicted_value}
|
||||
unconflicted_state.pop(key, None)
|
||||
|
||||
return unconflicted_state, conflicted_state
|
||||
|
||||
|
||||
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
|
||||
auth_events = {}
|
||||
for event_ids in itervalues(conflicted_state):
|
||||
for event_id in event_ids:
|
||||
if event_id in state_map:
|
||||
keys = event_auth.auth_types_for_event(state_map[event_id])
|
||||
for key in keys:
|
||||
if key not in auth_events:
|
||||
event_id = unconflicted_state.get(key, None)
|
||||
if event_id:
|
||||
auth_events[key] = event_id
|
||||
return auth_events
|
||||
|
||||
|
||||
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids,
|
||||
state_map):
|
||||
conflicted_state = {}
|
||||
for key, event_ids in iteritems(conflicted_state_ids):
|
||||
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
|
||||
if len(events) > 1:
|
||||
conflicted_state[key] = events
|
||||
elif len(events) == 1:
|
||||
unconflicted_state_ids[key] = events[0].event_id
|
||||
|
||||
auth_events = {
|
||||
key: state_map[ev_id]
|
||||
for key, ev_id in iteritems(auth_event_ids)
|
||||
if ev_id in state_map
|
||||
}
|
||||
|
||||
try:
|
||||
resolved_state = _resolve_state_events(
|
||||
conflicted_state, auth_events
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to resolve state")
|
||||
raise
|
||||
|
||||
new_state = unconflicted_state_ids
|
||||
for key, event in iteritems(resolved_state):
|
||||
new_state[key] = event.event_id
|
||||
|
||||
return new_state
|
||||
|
||||
|
||||
def _resolve_state_events(conflicted_state, auth_events):
|
||||
""" This is where we actually decide which of the conflicted state to
|
||||
use.
|
||||
|
||||
We resolve conflicts in the following order:
|
||||
1. power levels
|
||||
2. join rules
|
||||
3. memberships
|
||||
4. other events.
|
||||
"""
|
||||
resolved_state = {}
|
||||
if POWER_KEY in conflicted_state:
|
||||
events = conflicted_state[POWER_KEY]
|
||||
logger.debug("Resolving conflicted power levels %r", events)
|
||||
resolved_state[POWER_KEY] = _resolve_auth_events(
|
||||
events, auth_events)
|
||||
|
||||
auth_events.update(resolved_state)
|
||||
|
||||
for key, events in iteritems(conflicted_state):
|
||||
if key[0] == EventTypes.JoinRules:
|
||||
logger.debug("Resolving conflicted join rules %r", events)
|
||||
resolved_state[key] = _resolve_auth_events(
|
||||
events,
|
||||
auth_events
|
||||
)
|
||||
|
||||
auth_events.update(resolved_state)
|
||||
|
||||
for key, events in iteritems(conflicted_state):
|
||||
if key[0] == EventTypes.Member:
|
||||
logger.debug("Resolving conflicted member lists %r", events)
|
||||
resolved_state[key] = _resolve_auth_events(
|
||||
events,
|
||||
auth_events
|
||||
)
|
||||
|
||||
auth_events.update(resolved_state)
|
||||
|
||||
for key, events in iteritems(conflicted_state):
|
||||
if key not in resolved_state:
|
||||
logger.debug("Resolving conflicted state %r:%r", key, events)
|
||||
resolved_state[key] = _resolve_normal_events(
|
||||
events, auth_events
|
||||
)
|
||||
|
||||
return resolved_state
|
||||
|
||||
|
||||
def _resolve_auth_events(events, auth_events):
|
||||
reverse = [i for i in reversed(_ordered_events(events))]
|
||||
|
||||
auth_keys = set(
|
||||
key
|
||||
for event in events
|
||||
for key in event_auth.auth_types_for_event(event)
|
||||
)
|
||||
|
||||
new_auth_events = {}
|
||||
for key in auth_keys:
|
||||
auth_event = auth_events.get(key, None)
|
||||
if auth_event:
|
||||
new_auth_events[key] = auth_event
|
||||
|
||||
auth_events = new_auth_events
|
||||
|
||||
prev_event = reverse[0]
|
||||
for event in reverse[1:]:
|
||||
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
|
||||
try:
|
||||
# The signatures have already been checked at this point
|
||||
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
|
||||
prev_event = event
|
||||
except AuthError:
|
||||
return prev_event
|
||||
|
||||
return event
|
||||
|
||||
|
||||
def _resolve_normal_events(events, auth_events):
|
||||
for event in _ordered_events(events):
|
||||
try:
|
||||
# The signatures have already been checked at this point
|
||||
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
|
||||
return event
|
||||
except AuthError:
|
||||
pass
|
||||
|
||||
# Use the last event (the one with the least depth) if they all fail
|
||||
# the auth check.
|
||||
return event
|
||||
|
||||
|
||||
def _ordered_events(events):
|
||||
def key_func(e):
|
||||
return -int(e.depth), hashlib.sha1(e.event_id.encode('ascii')).hexdigest()
|
||||
|
||||
return sorted(events, key=key_func)
|
||||
@@ -705,9 +705,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
|
||||
}
|
||||
|
||||
events_map = {ev.event_id: ev for ev, _ in events_context}
|
||||
room_version = yield self.get_room_version(room_id)
|
||||
|
||||
logger.debug("calling resolve_state_groups from preserve_events")
|
||||
res = yield self._state_resolution_handler.resolve_state_groups(
|
||||
room_id, state_groups, events_map, get_events
|
||||
room_id, room_version, state_groups, events_map, get_events
|
||||
)
|
||||
|
||||
defer.returnValue((res.state, None))
|
||||
|
||||
@@ -147,6 +147,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
||||
return count
|
||||
return self.runInteraction("count_users", _count_users)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def upsert_monthly_active_user(self, user_id):
|
||||
"""
|
||||
Updates or inserts monthly active user member
|
||||
@@ -155,7 +156,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
||||
Deferred[bool]: True if a new entry was created, False if an
|
||||
existing one was updated.
|
||||
"""
|
||||
is_insert = self._simple_upsert(
|
||||
is_insert = yield self._simple_upsert(
|
||||
desc="upsert_monthly_active_user",
|
||||
table="monthly_active_users",
|
||||
keyvalues={
|
||||
@@ -200,6 +201,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
|
||||
user_id(str): the user_id to query
|
||||
"""
|
||||
if self.hs.config.limit_usage_by_mau:
|
||||
is_trial = yield self.is_trial_user(user_id)
|
||||
if is_trial:
|
||||
# we don't track trial users in the MAU table.
|
||||
return
|
||||
|
||||
last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
|
||||
now = self.hs.get_clock().time_msec()
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -21,8 +20,6 @@ from synapse.storage.roommember import ProfileInfo
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
BATCH_SIZE = 100
|
||||
|
||||
|
||||
class ProfileWorkerStore(SQLBaseStore):
|
||||
@defer.inlineCallbacks
|
||||
@@ -65,55 +62,6 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
desc="get_profile_avatar_url",
|
||||
)
|
||||
|
||||
def get_latest_profile_replication_batch_number(self):
|
||||
def f(txn):
|
||||
txn.execute("SELECT MAX(batch) as maxbatch FROM profiles")
|
||||
rows = self.cursor_to_dict(txn)
|
||||
return rows[0]['maxbatch']
|
||||
return self.runInteraction(
|
||||
"get_latest_profile_replication_batch_number", f,
|
||||
)
|
||||
|
||||
def get_profile_batch(self, batchnum):
|
||||
return self._simple_select_list(
|
||||
table="profiles",
|
||||
keyvalues={
|
||||
"batch": batchnum,
|
||||
},
|
||||
retcols=("user_id", "displayname", "avatar_url", "active"),
|
||||
desc="get_profile_batch",
|
||||
)
|
||||
|
||||
def assign_profile_batch(self):
|
||||
def f(txn):
|
||||
sql = (
|
||||
"UPDATE profiles SET batch = "
|
||||
"(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) "
|
||||
"WHERE user_id in ("
|
||||
" SELECT user_id FROM profiles WHERE batch is NULL limit ?"
|
||||
")"
|
||||
)
|
||||
txn.execute(sql, (BATCH_SIZE,))
|
||||
return txn.rowcount
|
||||
return self.runInteraction("assign_profile_batch", f)
|
||||
|
||||
def get_replication_hosts(self):
|
||||
def f(txn):
|
||||
txn.execute("SELECT host, last_synced_batch FROM profile_replication_status")
|
||||
rows = self.cursor_to_dict(txn)
|
||||
return {r['host']: r['last_synced_batch'] for r in rows}
|
||||
return self.runInteraction("get_replication_hosts", f)
|
||||
|
||||
def update_replication_batch_for_host(self, host, last_synced_batch):
|
||||
return self._simple_upsert(
|
||||
table="profile_replication_status",
|
||||
keyvalues={"host": host},
|
||||
values={
|
||||
"last_synced_batch": last_synced_batch,
|
||||
},
|
||||
desc="update_replication_batch_for_host",
|
||||
)
|
||||
|
||||
def get_from_remote_profile_cache(self, user_id):
|
||||
return self._simple_select_one(
|
||||
table="remote_profile_cache",
|
||||
@@ -123,48 +71,31 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
desc="get_from_remote_profile_cache",
|
||||
)
|
||||
|
||||
def create_profile(self, user_localpart):
|
||||
return self._simple_insert(
|
||||
table="profiles",
|
||||
values={"user_id": user_localpart},
|
||||
desc="create_profile",
|
||||
)
|
||||
|
||||
def set_profile_displayname(self, user_localpart, new_displayname):
|
||||
return self._simple_update_one(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
updatevalues={"displayname": new_displayname},
|
||||
desc="set_profile_displayname",
|
||||
)
|
||||
|
||||
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
|
||||
return self._simple_update_one(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
updatevalues={"avatar_url": new_avatar_url},
|
||||
desc="set_profile_avatar_url",
|
||||
)
|
||||
|
||||
|
||||
class ProfileStore(ProfileWorkerStore):
|
||||
def set_profile_displayname(self, user_localpart, new_displayname, batchnum):
|
||||
return self._simple_upsert(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
values={
|
||||
"displayname": new_displayname,
|
||||
"batch": batchnum,
|
||||
},
|
||||
desc="set_profile_displayname",
|
||||
lock=False # we can do this because user_id has a unique index
|
||||
)
|
||||
|
||||
def set_profile_avatar_url(self, user_localpart, new_avatar_url, batchnum):
|
||||
return self._simple_upsert(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
values={
|
||||
"avatar_url": new_avatar_url,
|
||||
"batch": batchnum,
|
||||
},
|
||||
desc="set_profile_avatar_url",
|
||||
lock=False # we can do this because user_id has a unique index
|
||||
)
|
||||
|
||||
def set_profile_active(self, user_localpart, active, batchnum):
|
||||
values = {
|
||||
"active": int(active),
|
||||
"batch": batchnum,
|
||||
}
|
||||
if not active:
|
||||
values["avatar_url"] = None
|
||||
values["displayname"] = None
|
||||
return self._simple_upsert(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
values=values,
|
||||
desc="set_profile_active",
|
||||
lock=False # we can do this because user_id has a unique index
|
||||
)
|
||||
|
||||
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
|
||||
"""Ensure we are caching the remote user's profiles.
|
||||
|
||||
|
||||
@@ -26,6 +26,11 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||
|
||||
|
||||
class RegistrationWorkerStore(SQLBaseStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(RegistrationWorkerStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.config = hs.config
|
||||
|
||||
@cached()
|
||||
def get_user_by_id(self, user_id):
|
||||
return self._simple_select_one(
|
||||
@@ -36,12 +41,33 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
retcols=[
|
||||
"name", "password_hash", "is_guest",
|
||||
"consent_version", "consent_server_notice_sent",
|
||||
"appservice_id",
|
||||
"appservice_id", "creation_ts",
|
||||
],
|
||||
allow_none=True,
|
||||
desc="get_user_by_id",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_trial_user(self, user_id):
|
||||
"""Checks if user is in the "trial" period, i.e. within the first
|
||||
N days of registration defined by `mau_trial_days` config
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
|
||||
Returns:
|
||||
Deferred[bool]
|
||||
"""
|
||||
|
||||
info = yield self.get_user_by_id(user_id)
|
||||
if not info:
|
||||
defer.returnValue(False)
|
||||
|
||||
now = self.clock.time_msec()
|
||||
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
|
||||
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
|
||||
defer.returnValue(is_trial)
|
||||
|
||||
@cached()
|
||||
def get_user_by_access_token(self, token):
|
||||
"""Get a user from the given access token.
|
||||
@@ -141,7 +167,7 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
|
||||
def register(self, user_id, token=None, password_hash=None,
|
||||
was_guest=False, make_guest=False, appservice_id=None,
|
||||
admin=False):
|
||||
create_profile_with_localpart=None, admin=False):
|
||||
"""Attempts to register an account.
|
||||
|
||||
Args:
|
||||
@@ -155,6 +181,8 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
make_guest (boolean): True if the the new user should be guest,
|
||||
false to add a regular user account.
|
||||
appservice_id (str): The ID of the appservice registering the user.
|
||||
create_profile_with_localpart (str): Optionally create a profile for
|
||||
the given localpart.
|
||||
Raises:
|
||||
StoreError if the user_id could not be registered.
|
||||
"""
|
||||
@@ -167,6 +195,7 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
was_guest,
|
||||
make_guest,
|
||||
appservice_id,
|
||||
create_profile_with_localpart,
|
||||
admin
|
||||
)
|
||||
|
||||
@@ -179,6 +208,7 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
was_guest,
|
||||
make_guest,
|
||||
appservice_id,
|
||||
create_profile_with_localpart,
|
||||
admin,
|
||||
):
|
||||
now = int(self.clock.time())
|
||||
@@ -243,6 +273,14 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
(next_id, user_id, token,)
|
||||
)
|
||||
|
||||
if create_profile_with_localpart:
|
||||
# set a default displayname serverside to avoid ugly race
|
||||
# between auto-joins and clients trying to set displaynames
|
||||
txn.execute(
|
||||
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
|
||||
(create_profile_with_localpart, create_profile_with_localpart)
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_user_by_id, (user_id,)
|
||||
)
|
||||
|
||||
@@ -186,6 +186,35 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
desc="is_room_blocked",
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(max_entries=10000)
|
||||
def get_ratelimit_for_user(self, user_id):
|
||||
"""Check if there are any overrides for ratelimiting for the given
|
||||
user
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
|
||||
Returns:
|
||||
RatelimitOverride if there is an override, else None. If the contents
|
||||
of RatelimitOverride are None or 0 then ratelimitng has been
|
||||
disabled for that user entirely.
|
||||
"""
|
||||
row = yield self._simple_select_one(
|
||||
table="ratelimit_override",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("messages_per_second", "burst_count"),
|
||||
allow_none=True,
|
||||
desc="get_ratelimit_for_user",
|
||||
)
|
||||
|
||||
if row:
|
||||
defer.returnValue(RatelimitOverride(
|
||||
messages_per_second=row["messages_per_second"],
|
||||
burst_count=row["burst_count"],
|
||||
))
|
||||
else:
|
||||
defer.returnValue(None)
|
||||
|
||||
|
||||
class RoomStore(RoomWorkerStore, SearchStore):
|
||||
|
||||
@@ -469,35 +498,6 @@ class RoomStore(RoomWorkerStore, SearchStore):
|
||||
"get_all_new_public_rooms", get_all_new_public_rooms
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(max_entries=10000)
|
||||
def get_ratelimit_for_user(self, user_id):
|
||||
"""Check if there are any overrides for ratelimiting for the given
|
||||
user
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
|
||||
Returns:
|
||||
RatelimitOverride if there is an override, else None. If the contents
|
||||
of RatelimitOverride are None or 0 then ratelimitng has been
|
||||
disabled for that user entirely.
|
||||
"""
|
||||
row = yield self._simple_select_one(
|
||||
table="ratelimit_override",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("messages_per_second", "burst_count"),
|
||||
allow_none=True,
|
||||
desc="get_ratelimit_for_user",
|
||||
)
|
||||
|
||||
if row:
|
||||
defer.returnValue(RatelimitOverride(
|
||||
messages_per_second=row["messages_per_second"],
|
||||
burst_count=row["burst_count"],
|
||||
))
|
||||
else:
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def block_room(self, room_id, user_id):
|
||||
yield self._simple_insert(
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
/* Copyright 2018 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*
|
||||
* Add a batch number to track changes to profiles and the
|
||||
* order they're made in so we can replicate user profiles
|
||||
* to other hosts as they change
|
||||
*/
|
||||
ALTER TABLE profiles ADD COLUMN batch BIGINT DEFAULT NULL;
|
||||
|
||||
/*
|
||||
* Index on the batch number so we can get profiles
|
||||
* by their batch
|
||||
*/
|
||||
CREATE INDEX profiles_batch_idx ON profiles(batch);
|
||||
|
||||
/*
|
||||
* A table to track what batch of user profiles has been
|
||||
* synced to what profile replication target.
|
||||
*/
|
||||
CREATE TABLE profile_replication_status (
|
||||
host TEXT NOT NULL,
|
||||
last_synced_batch BIGINT NOT NULL
|
||||
);
|
||||
@@ -1,23 +0,0 @@
|
||||
/* Copyright 2018 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*
|
||||
* A flag saying whether the user owning the profile has been deactivated
|
||||
* This really belongs on the users table, not here, but the users table
|
||||
* stores users by their full user_id and profiles stores them by localpart,
|
||||
* so we can't easily join between the two tables. Plus, the batch number
|
||||
* realy ought to represent data in this table that has changed.
|
||||
*/
|
||||
ALTER TABLE profiles ADD COLUMN active SMALLINT DEFAULT 1 NOT NULL;
|
||||
@@ -60,8 +60,43 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(StateGroupWorkerStore, self).__init__(db_conn, hs)
|
||||
|
||||
# Originally the state store used a single DictionaryCache to cache the
|
||||
# event IDs for the state types in a given state group to avoid hammering
|
||||
# on the state_group* tables.
|
||||
#
|
||||
# The point of using a DictionaryCache is that it can cache a subset
|
||||
# of the state events for a given state group (i.e. a subset of the keys for a
|
||||
# given dict which is an entry in the cache for a given state group ID).
|
||||
#
|
||||
# However, this poses problems when performing complicated queries
|
||||
# on the store - for instance: "give me all the state for this group, but
|
||||
# limit members to this subset of users", as DictionaryCache's API isn't
|
||||
# rich enough to say "please cache any of these fields, apart from this subset".
|
||||
# This is problematic when lazy loading members, which requires this behaviour,
|
||||
# as without it the cache has no choice but to speculatively load all
|
||||
# state events for the group, which negates the efficiency being sought.
|
||||
#
|
||||
# Rather than overcomplicating DictionaryCache's API, we instead split the
|
||||
# state_group_cache into two halves - one for tracking non-member events,
|
||||
# and the other for tracking member_events. This means that lazy loading
|
||||
# queries can be made in a cache-friendly manner by querying both caches
|
||||
# separately and then merging the result. So for the example above, you
|
||||
# would query the members cache for a specific subset of state keys
|
||||
# (which DictionaryCache will handle efficiently and fine) and the non-members
|
||||
# cache for all state (which DictionaryCache will similarly handle fine)
|
||||
# and then just merge the results together.
|
||||
#
|
||||
# We size the non-members cache to be smaller than the members cache as the
|
||||
# vast majority of state in Matrix (today) is member events.
|
||||
|
||||
self._state_group_cache = DictionaryCache(
|
||||
"*stateGroupCache*", 500000 * get_cache_factor_for("stateGroupCache")
|
||||
"*stateGroupCache*",
|
||||
# TODO: this hasn't been tuned yet
|
||||
50000 * get_cache_factor_for("stateGroupCache")
|
||||
)
|
||||
self._state_group_members_cache = DictionaryCache(
|
||||
"*stateGroupMembersCache*",
|
||||
500000 * get_cache_factor_for("stateGroupMembersCache")
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -275,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_state_groups_from_groups(self, groups, types):
|
||||
def _get_state_groups_from_groups(self, groups, types, members=None):
|
||||
"""Returns the state groups for a given set of groups, filtering on
|
||||
types of state events.
|
||||
|
||||
@@ -284,6 +319,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
types (Iterable[str, str|None]|None): list of 2-tuples of the form
|
||||
(`type`, `state_key`), where a `state_key` of `None` matches all
|
||||
state_keys for the `type`. If None, all types are returned.
|
||||
members (bool|None): If not None, then, in addition to any filtering
|
||||
implied by types, the results are also filtered to only include
|
||||
member events (if True), or to exclude member events (if False)
|
||||
|
||||
Returns:
|
||||
dictionary state_group -> (dict of (type, state_key) -> event id)
|
||||
@@ -294,14 +332,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
for chunk in chunks:
|
||||
res = yield self.runInteraction(
|
||||
"_get_state_groups_from_groups",
|
||||
self._get_state_groups_from_groups_txn, chunk, types,
|
||||
self._get_state_groups_from_groups_txn, chunk, types, members,
|
||||
)
|
||||
results.update(res)
|
||||
|
||||
defer.returnValue(results)
|
||||
|
||||
def _get_state_groups_from_groups_txn(
|
||||
self, txn, groups, types=None,
|
||||
self, txn, groups, types=None, members=None,
|
||||
):
|
||||
results = {group: {} for group in groups}
|
||||
|
||||
@@ -339,6 +377,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
%s
|
||||
""")
|
||||
|
||||
if members is True:
|
||||
sql += " AND type = '%s'" % (EventTypes.Member,)
|
||||
elif members is False:
|
||||
sql += " AND type <> '%s'" % (EventTypes.Member,)
|
||||
|
||||
# Turns out that postgres doesn't like doing a list of OR's and
|
||||
# is about 1000x slower, so we just issue a query for each specific
|
||||
# type seperately.
|
||||
@@ -386,6 +429,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
else:
|
||||
where_clause = ""
|
||||
|
||||
if members is True:
|
||||
where_clause += " AND type = '%s'" % EventTypes.Member
|
||||
elif members is False:
|
||||
where_clause += " AND type <> '%s'" % EventTypes.Member
|
||||
|
||||
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
|
||||
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
|
||||
for group in groups:
|
||||
@@ -580,10 +628,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
|
||||
|
||||
def _get_some_state_from_cache(self, group, types, filtered_types=None):
|
||||
def _get_some_state_from_cache(self, cache, group, types, filtered_types=None):
|
||||
"""Checks if group is in cache. See `_get_state_for_groups`
|
||||
|
||||
Args:
|
||||
cache(DictionaryCache): the state group cache to use
|
||||
group(int): The state group to lookup
|
||||
types(list[str, str|None]): List of 2-tuples of the form
|
||||
(`type`, `state_key`), where a `state_key` of `None` matches all
|
||||
@@ -597,11 +646,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
requests state from the cache, if False we need to query the DB for the
|
||||
missing state.
|
||||
"""
|
||||
is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
|
||||
is_all, known_absent, state_dict_ids = cache.get(group)
|
||||
|
||||
type_to_key = {}
|
||||
|
||||
# tracks whether any of ourrequested types are missing from the cache
|
||||
# tracks whether any of our requested types are missing from the cache
|
||||
missing_types = False
|
||||
|
||||
for typ, state_key in types:
|
||||
@@ -648,7 +697,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
if include(k[0], k[1])
|
||||
}, got_all
|
||||
|
||||
def _get_all_state_from_cache(self, group):
|
||||
def _get_all_state_from_cache(self, cache, group):
|
||||
"""Checks if group is in cache. See `_get_state_for_groups`
|
||||
|
||||
Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
|
||||
@@ -656,9 +705,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
cache, if False we need to query the DB for the missing state.
|
||||
|
||||
Args:
|
||||
cache(DictionaryCache): the state group cache to use
|
||||
group: The state group to lookup
|
||||
"""
|
||||
is_all, _, state_dict_ids = self._state_group_cache.get(group)
|
||||
is_all, _, state_dict_ids = cache.get(group)
|
||||
|
||||
return state_dict_ids, is_all
|
||||
|
||||
@@ -681,6 +731,62 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
list of event types. Other types of events are returned unfiltered.
|
||||
If None, `types` filtering is applied to all events.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[int, dict[(type, state_key), EventBase]]]
|
||||
a dictionary mapping from state group to state dictionary.
|
||||
"""
|
||||
if types is not None:
|
||||
non_member_types = [t for t in types if t[0] != EventTypes.Member]
|
||||
|
||||
if filtered_types is not None and EventTypes.Member not in filtered_types:
|
||||
# we want all of the membership events
|
||||
member_types = None
|
||||
else:
|
||||
member_types = [t for t in types if t[0] == EventTypes.Member]
|
||||
|
||||
else:
|
||||
non_member_types = None
|
||||
member_types = None
|
||||
|
||||
non_member_state = yield self._get_state_for_groups_using_cache(
|
||||
groups, self._state_group_cache, non_member_types, filtered_types,
|
||||
)
|
||||
# XXX: we could skip this entirely if member_types is []
|
||||
member_state = yield self._get_state_for_groups_using_cache(
|
||||
# we set filtered_types=None as member_state only ever contain members.
|
||||
groups, self._state_group_members_cache, member_types, None,
|
||||
)
|
||||
|
||||
state = non_member_state
|
||||
for group in groups:
|
||||
state[group].update(member_state[group])
|
||||
|
||||
defer.returnValue(state)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_state_for_groups_using_cache(
|
||||
self, groups, cache, types=None, filtered_types=None
|
||||
):
|
||||
"""Gets the state at each of a list of state groups, optionally
|
||||
filtering by type/state_key, querying from a specific cache.
|
||||
|
||||
Args:
|
||||
groups (iterable[int]): list of state groups for which we want
|
||||
to get the state.
|
||||
cache (DictionaryCache): the cache of group ids to state dicts which
|
||||
we will pass through - either the normal state cache or the specific
|
||||
members state cache.
|
||||
types (None|iterable[(str, None|str)]):
|
||||
indicates the state type/keys required. If None, the whole
|
||||
state is fetched and returned.
|
||||
|
||||
Otherwise, each entry should be a `(type, state_key)` tuple to
|
||||
include in the response. A `state_key` of None is a wildcard
|
||||
meaning that we require all state with that type.
|
||||
filtered_types(list[str]|None): Only apply filtering via `types` to this
|
||||
list of event types. Other types of events are returned unfiltered.
|
||||
If None, `types` filtering is applied to all events.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[int, dict[(type, state_key), EventBase]]]
|
||||
a dictionary mapping from state group to state dictionary.
|
||||
@@ -692,7 +798,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
if types is not None:
|
||||
for group in set(groups):
|
||||
state_dict_ids, got_all = self._get_some_state_from_cache(
|
||||
group, types, filtered_types
|
||||
cache, group, types, filtered_types
|
||||
)
|
||||
results[group] = state_dict_ids
|
||||
|
||||
@@ -701,7 +807,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
else:
|
||||
for group in set(groups):
|
||||
state_dict_ids, got_all = self._get_all_state_from_cache(
|
||||
group
|
||||
cache, group
|
||||
)
|
||||
|
||||
results[group] = state_dict_ids
|
||||
@@ -710,8 +816,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
missing_groups.append(group)
|
||||
|
||||
if missing_groups:
|
||||
# Okay, so we have some missing_types, lets fetch them.
|
||||
cache_seq_num = self._state_group_cache.sequence
|
||||
# Okay, so we have some missing_types, let's fetch them.
|
||||
cache_seq_num = cache.sequence
|
||||
|
||||
# the DictionaryCache knows if it has *all* the state, but
|
||||
# does not know if it has all of the keys of a particular type,
|
||||
@@ -725,7 +831,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
types_to_fetch = types
|
||||
|
||||
group_to_state_dict = yield self._get_state_groups_from_groups(
|
||||
missing_groups, types_to_fetch
|
||||
missing_groups, types_to_fetch, cache == self._state_group_members_cache,
|
||||
)
|
||||
|
||||
for group, group_state_dict in iteritems(group_to_state_dict):
|
||||
@@ -745,7 +851,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
# update the cache with all the things we fetched from the
|
||||
# database.
|
||||
self._state_group_cache.update(
|
||||
cache.update(
|
||||
cache_seq_num,
|
||||
key=group,
|
||||
value=group_state_dict,
|
||||
@@ -847,15 +953,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
],
|
||||
)
|
||||
|
||||
# Prefill the state group cache with this group.
|
||||
# Prefill the state group caches with this group.
|
||||
# It's fine to use the sequence like this as the state group map
|
||||
# is immutable. (If the map wasn't immutable then this prefill could
|
||||
# race with another update)
|
||||
|
||||
current_member_state_ids = {
|
||||
s: ev
|
||||
for (s, ev) in iteritems(current_state_ids)
|
||||
if s[0] == EventTypes.Member
|
||||
}
|
||||
txn.call_after(
|
||||
self._state_group_members_cache.update,
|
||||
self._state_group_members_cache.sequence,
|
||||
key=state_group,
|
||||
value=dict(current_member_state_ids),
|
||||
)
|
||||
|
||||
current_non_member_state_ids = {
|
||||
s: ev
|
||||
for (s, ev) in iteritems(current_state_ids)
|
||||
if s[0] != EventTypes.Member
|
||||
}
|
||||
txn.call_after(
|
||||
self._state_group_cache.update,
|
||||
self._state_group_cache.sequence,
|
||||
key=state_group,
|
||||
value=dict(current_state_ids),
|
||||
value=dict(current_non_member_state_ids),
|
||||
)
|
||||
|
||||
return state_group
|
||||
|
||||
@@ -228,18 +228,6 @@ def contains_invalid_mxid_characters(localpart):
|
||||
return any(c not in mxid_localpart_allowed_characters for c in localpart)
|
||||
|
||||
|
||||
def strip_invalid_mxid_characters(localpart):
|
||||
"""Removes any invalid characters from an mxid
|
||||
|
||||
Args:
|
||||
localpart (basestring): the localpart to be stripped
|
||||
|
||||
Returns:
|
||||
localpart (basestring): the localpart having been stripped
|
||||
"""
|
||||
return filter(lambda c: c in mxid_localpart_allowed_characters, localpart)
|
||||
|
||||
|
||||
class StreamToken(
|
||||
namedtuple("Token", (
|
||||
"room_key",
|
||||
|
||||
@@ -16,12 +16,9 @@
|
||||
import logging
|
||||
import re
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_3pid_allowed(hs, medium, address):
|
||||
"""Checks whether a given format of 3PID is allowed to be used on this HS
|
||||
|
||||
@@ -31,22 +28,9 @@ def check_3pid_allowed(hs, medium, address):
|
||||
address (str): address within that medium (e.g. "wotan@matrix.org")
|
||||
msisdns need to first have been canonicalised
|
||||
Returns:
|
||||
defered bool: whether the 3PID medium/address is allowed to be added to this HS
|
||||
bool: whether the 3PID medium/address is allowed to be added to this HS
|
||||
"""
|
||||
|
||||
if hs.config.check_is_for_allowed_local_3pids:
|
||||
data = yield hs.get_simple_http_client().get_json(
|
||||
"https://%s%s" % (
|
||||
hs.config.check_is_for_allowed_local_3pids,
|
||||
"/_matrix/identity/api/v1/info"
|
||||
),
|
||||
{'medium': medium, 'address': address}
|
||||
)
|
||||
if hs.config.allow_invited_3pids and data.get('invited'):
|
||||
defer.returnValue(True)
|
||||
else:
|
||||
defer.returnValue(data['hs'] == hs.config.server_name)
|
||||
|
||||
if hs.config.allowed_local_3pids:
|
||||
for constraint in hs.config.allowed_local_3pids:
|
||||
logger.debug(
|
||||
@@ -57,8 +41,8 @@ def check_3pid_allowed(hs, medium, address):
|
||||
medium == constraint['medium'] and
|
||||
re.match(constraint['pattern'], address)
|
||||
):
|
||||
defer.returnValue(True)
|
||||
return True
|
||||
else:
|
||||
defer.returnValue(True)
|
||||
return True
|
||||
|
||||
defer.returnValue(False)
|
||||
return False
|
||||
|
||||
@@ -458,7 +458,7 @@ class AuthTestCase(unittest.TestCase):
|
||||
with self.assertRaises(ResourceLimitError) as e:
|
||||
yield self.auth.check_auth_blocking()
|
||||
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
|
||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
|
||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
self.assertEquals(e.exception.code, 403)
|
||||
|
||||
# Ensure does not throw an error
|
||||
@@ -474,5 +474,13 @@ class AuthTestCase(unittest.TestCase):
|
||||
with self.assertRaises(ResourceLimitError) as e:
|
||||
yield self.auth.check_auth_blocking()
|
||||
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
|
||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
|
||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
self.assertEquals(e.exception.code, 403)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_server_notices_mxid_special_cased(self):
|
||||
self.hs.config.hs_disabled = True
|
||||
user = "@user:server"
|
||||
self.hs.config.server_notices_mxid = user
|
||||
self.hs.config.hs_disabled_message = "Reason for being disabled"
|
||||
yield self.auth.check_auth_blocking(user)
|
||||
|
||||
@@ -20,7 +20,7 @@ from twisted.internet import defer
|
||||
|
||||
import synapse.types
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.handlers.profile import ProfileHandler
|
||||
from synapse.handlers.profile import MasterProfileHandler
|
||||
from synapse.types import UserID
|
||||
|
||||
from tests import unittest
|
||||
@@ -29,7 +29,7 @@ from tests.utils import setup_test_homeserver
|
||||
|
||||
class ProfileHandlers(object):
|
||||
def __init__(self, hs):
|
||||
self.profile_handler = ProfileHandler(hs)
|
||||
self.profile_handler = MasterProfileHandler(hs)
|
||||
|
||||
|
||||
class ProfileTestCase(unittest.TestCase):
|
||||
@@ -67,13 +67,13 @@ class ProfileTestCase(unittest.TestCase):
|
||||
self.bob = UserID.from_string("@4567:test")
|
||||
self.alice = UserID.from_string("@alice:remote")
|
||||
|
||||
yield self.store.create_profile(self.frank.localpart)
|
||||
|
||||
self.handler = hs.get_profile_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_my_name(self):
|
||||
yield self.store.set_profile_displayname(
|
||||
self.frank.localpart, "Frank", 1,
|
||||
)
|
||||
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
||||
|
||||
displayname = yield self.handler.get_displayname(self.frank)
|
||||
|
||||
@@ -116,7 +116,8 @@ class ProfileTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_incoming_fed_query(self):
|
||||
yield self.store.set_profile_displayname("caroline", "Caroline", 1)
|
||||
yield self.store.create_profile("caroline")
|
||||
yield self.store.set_profile_displayname("caroline", "Caroline")
|
||||
|
||||
response = yield self.query_handlers["profile"](
|
||||
{"user_id": "@caroline:test", "field": "displayname"}
|
||||
@@ -127,7 +128,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def test_get_my_avatar(self):
|
||||
yield self.store.set_profile_avatar_url(
|
||||
self.frank.localpart, "http://my.server/me.png", 1,
|
||||
self.frank.localpart, "http://my.server/me.png"
|
||||
)
|
||||
|
||||
avatar_url = yield self.handler.get_avatar_url(self.frank)
|
||||
|
||||
@@ -51,7 +51,7 @@ class SyncTestCase(tests.unittest.TestCase):
|
||||
self.hs.config.hs_disabled = True
|
||||
with self.assertRaises(ResourceLimitError) as e:
|
||||
yield self.sync_handler.wait_for_sync_for_user(sync_config)
|
||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
|
||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
|
||||
self.hs.config.hs_disabled = False
|
||||
|
||||
@@ -59,7 +59,7 @@ class SyncTestCase(tests.unittest.TestCase):
|
||||
|
||||
with self.assertRaises(ResourceLimitError) as e:
|
||||
yield self.sync_handler.wait_for_sync_for_user(sync_config)
|
||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEED)
|
||||
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
|
||||
def _generate_sync_config(self, user_id):
|
||||
return SyncConfig(
|
||||
|
||||
@@ -112,6 +112,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_invites(self):
|
||||
yield self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
|
||||
event = yield self.persist(
|
||||
type="m.room.member", key=USER_ID_2, membership="invite"
|
||||
@@ -133,7 +134,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_push_actions_for_user(self):
|
||||
yield self.persist(type="m.room.create", creator=USER_ID)
|
||||
yield self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
yield self.persist(type="m.room.join", key=USER_ID, membership="join")
|
||||
yield self.persist(
|
||||
type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -1,101 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from tests import unittest
|
||||
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.rulecheck.domain_rule_checker import DomainRuleChecker
|
||||
|
||||
|
||||
class DomainRuleCheckerTestCase(unittest.TestCase):
|
||||
|
||||
def test_allowed(self):
|
||||
config = {
|
||||
"default": False,
|
||||
"domain_mapping": {
|
||||
"source_one": ["target_one", "target_two"],
|
||||
"source_two": ["target_two"]
|
||||
}
|
||||
}
|
||||
check = DomainRuleChecker(config)
|
||||
self.assertTrue(check.user_may_invite("test:source_one",
|
||||
"test:target_one", "room"))
|
||||
self.assertTrue(check.user_may_invite("test:source_one",
|
||||
"test:target_two", "room"))
|
||||
self.assertTrue(check.user_may_invite("test:source_two",
|
||||
"test:target_two", "room"))
|
||||
|
||||
def test_disallowed(self):
|
||||
config = {
|
||||
"default": True,
|
||||
"domain_mapping": {
|
||||
"source_one": ["target_one", "target_two"],
|
||||
"source_two": ["target_two"],
|
||||
"source_four": []
|
||||
}
|
||||
}
|
||||
check = DomainRuleChecker(config)
|
||||
self.assertFalse(check.user_may_invite("test:source_one",
|
||||
"test:target_three", "room"))
|
||||
self.assertFalse(check.user_may_invite("test:source_two",
|
||||
"test:target_three", "room"))
|
||||
self.assertFalse(check.user_may_invite("test:source_two",
|
||||
"test:target_one", "room"))
|
||||
self.assertFalse(check.user_may_invite("test:source_four",
|
||||
"test:target_one", "room"))
|
||||
|
||||
def test_default_allow(self):
|
||||
config = {
|
||||
"default": True,
|
||||
"domain_mapping": {
|
||||
"source_one": ["target_one", "target_two"],
|
||||
"source_two": ["target_two"]
|
||||
}
|
||||
}
|
||||
check = DomainRuleChecker(config)
|
||||
self.assertTrue(check.user_may_invite("test:source_three",
|
||||
"test:target_one", "room"))
|
||||
|
||||
def test_default_deny(self):
|
||||
config = {
|
||||
"default": False,
|
||||
"domain_mapping": {
|
||||
"source_one": ["target_one", "target_two"],
|
||||
"source_two": ["target_two"]
|
||||
}
|
||||
}
|
||||
check = DomainRuleChecker(config)
|
||||
self.assertFalse(check.user_may_invite("test:source_three",
|
||||
"test:target_one", "room"))
|
||||
|
||||
def test_config_parse(self):
|
||||
config = {
|
||||
"default": False,
|
||||
"domain_mapping": {
|
||||
"source_one": ["target_one", "target_two"],
|
||||
"source_two": ["target_two"]
|
||||
}
|
||||
}
|
||||
self.assertEquals(config, DomainRuleChecker.parse_config(config))
|
||||
|
||||
def test_config_parse_failure(self):
|
||||
config = {
|
||||
"domain_mapping": {
|
||||
"source_one": ["target_one", "target_two"],
|
||||
"source_two": ["target_two"]
|
||||
}
|
||||
}
|
||||
self.assertRaises(ConfigError, DomainRuleChecker.parse_config, config)
|
||||
@@ -5,7 +5,7 @@ from six import text_type
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet import threads
|
||||
from twisted.internet import address, threads
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.test.proto_helpers import MemoryReactorClock
|
||||
@@ -63,7 +63,9 @@ class FakeChannel(object):
|
||||
self.result["done"] = True
|
||||
|
||||
def getPeer(self):
|
||||
return None
|
||||
# We give an address so that getClientIP returns a non null entry,
|
||||
# causing us to record the MAU
|
||||
return address.IPv4Address(b"TCP", "127.0.0.1", 3423)
|
||||
|
||||
def getHost(self):
|
||||
return None
|
||||
@@ -91,7 +93,7 @@ class FakeSite:
|
||||
return FakeLogger()
|
||||
|
||||
|
||||
def make_request(method, path, content=b""):
|
||||
def make_request(method, path, content=b"", access_token=None):
|
||||
"""
|
||||
Make a web request using the given method and path, feed it the
|
||||
content, and return the Request and the Channel underneath.
|
||||
@@ -116,6 +118,11 @@ def make_request(method, path, content=b""):
|
||||
req = SynapseRequest(site, channel)
|
||||
req.process = lambda: b""
|
||||
req.content = BytesIO(content)
|
||||
|
||||
if access_token:
|
||||
req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token)
|
||||
|
||||
req.requestHeaders.addRawHeader(b"X-Forwarded-For", b"127.0.0.1")
|
||||
req.requestReceived(method, path, b"1.1")
|
||||
|
||||
return req, channel
|
||||
|
||||
212
tests/server_notices/test_resource_limits_server_notices.py
Normal file
212
tests/server_notices/test_resource_limits_server_notices.py
Normal file
@@ -0,0 +1,212 @@
|
||||
from mock import Mock
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, ServerNoticeMsgType
|
||||
from synapse.api.errors import ResourceLimitError
|
||||
from synapse.handlers.auth import AuthHandler
|
||||
from synapse.server_notices.resource_limits_server_notices import (
|
||||
ResourceLimitsServerNotices,
|
||||
)
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import setup_test_homeserver
|
||||
|
||||
|
||||
class AuthHandlers(object):
|
||||
def __init__(self, hs):
|
||||
self.auth_handler = AuthHandler(hs)
|
||||
|
||||
|
||||
class TestResourceLimitsServerNotices(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
|
||||
self.hs.handlers = AuthHandlers(self.hs)
|
||||
self.auth_handler = self.hs.handlers.auth_handler
|
||||
self.server_notices_sender = self.hs.get_server_notices_sender()
|
||||
|
||||
# relying on [1] is far from ideal, but the only case where
|
||||
# ResourceLimitsServerNotices class needs to be isolated is this test,
|
||||
# general code should never have a reason to do so ...
|
||||
self._rlsn = self.server_notices_sender._server_notices[1]
|
||||
if not isinstance(self._rlsn, ResourceLimitsServerNotices):
|
||||
raise Exception("Failed to find reference to ResourceLimitsServerNotices")
|
||||
|
||||
self._rlsn._store.user_last_seen_monthly_active = Mock(
|
||||
return_value=defer.succeed(1000)
|
||||
)
|
||||
self._send_notice = self._rlsn._server_notices_manager.send_notice
|
||||
self._rlsn._server_notices_manager.send_notice = Mock()
|
||||
self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None))
|
||||
self._rlsn._store.get_events = Mock(return_value=defer.succeed({}))
|
||||
|
||||
self._send_notice = self._rlsn._server_notices_manager.send_notice
|
||||
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.user_id = "@user_id:test"
|
||||
|
||||
# self.server_notices_mxid = "@server:test"
|
||||
# self.server_notices_mxid_display_name = None
|
||||
# self.server_notices_mxid_avatar_url = None
|
||||
# self.server_notices_room_name = "Server Notices"
|
||||
|
||||
self._rlsn._server_notices_manager.get_notice_room_for_user = Mock(
|
||||
returnValue=""
|
||||
)
|
||||
self._rlsn._store.add_tag_to_room = Mock()
|
||||
self.hs.config.admin_uri = "mailto:user@test.com"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_maybe_send_server_notice_to_user_flag_off(self):
|
||||
"""Tests cases where the flags indicate nothing to do"""
|
||||
# test hs disabled case
|
||||
self.hs.config.hs_disabled = True
|
||||
|
||||
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
|
||||
|
||||
self._send_notice.assert_not_called()
|
||||
# Test when mau limiting disabled
|
||||
self.hs.config.hs_disabled = False
|
||||
self.hs.limit_usage_by_mau = False
|
||||
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
|
||||
|
||||
self._send_notice.assert_not_called()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
|
||||
"""Test when user has blocked notice, but should have it removed"""
|
||||
|
||||
self._rlsn._auth.check_auth_blocking = Mock()
|
||||
mock_event = Mock(
|
||||
type=EventTypes.Message,
|
||||
content={"msgtype": ServerNoticeMsgType},
|
||||
)
|
||||
self._rlsn._store.get_events = Mock(return_value=defer.succeed(
|
||||
{"123": mock_event}
|
||||
))
|
||||
|
||||
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
|
||||
# Would be better to check the content, but once == remove blocking event
|
||||
self._send_notice.assert_called_once()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self):
|
||||
"""Test when user has blocked notice, but notice ought to be there (NOOP)"""
|
||||
self._rlsn._auth.check_auth_blocking = Mock(
|
||||
side_effect=ResourceLimitError(403, 'foo')
|
||||
)
|
||||
|
||||
mock_event = Mock(
|
||||
type=EventTypes.Message,
|
||||
content={"msgtype": ServerNoticeMsgType},
|
||||
)
|
||||
self._rlsn._store.get_events = Mock(return_value=defer.succeed(
|
||||
{"123": mock_event}
|
||||
))
|
||||
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
|
||||
|
||||
self._send_notice.assert_not_called()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_maybe_send_server_notice_to_user_add_blocked_notice(self):
|
||||
"""Test when user does not have blocked notice, but should have one"""
|
||||
|
||||
self._rlsn._auth.check_auth_blocking = Mock(
|
||||
side_effect=ResourceLimitError(403, 'foo')
|
||||
)
|
||||
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
|
||||
|
||||
# Would be better to check contents, but 2 calls == set blocking event
|
||||
self.assertTrue(self._send_notice.call_count == 2)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self):
|
||||
"""Test when user does not have blocked notice, nor should they (NOOP)"""
|
||||
|
||||
self._rlsn._auth.check_auth_blocking = Mock()
|
||||
|
||||
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
|
||||
|
||||
self._send_notice.assert_not_called()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self):
|
||||
|
||||
"""Test when user is not part of the MAU cohort - this should not ever
|
||||
happen - but ...
|
||||
"""
|
||||
|
||||
self._rlsn._auth.check_auth_blocking = Mock()
|
||||
self._rlsn._store.user_last_seen_monthly_active = Mock(
|
||||
return_value=defer.succeed(None)
|
||||
)
|
||||
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
|
||||
|
||||
self._send_notice.assert_not_called()
|
||||
|
||||
|
||||
class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.hs = yield setup_test_homeserver(self.addCleanup)
|
||||
self.store = self.hs.get_datastore()
|
||||
self.server_notices_sender = self.hs.get_server_notices_sender()
|
||||
self.server_notices_manager = self.hs.get_server_notices_manager()
|
||||
self.event_source = self.hs.get_event_sources()
|
||||
|
||||
# relying on [1] is far from ideal, but the only case where
|
||||
# ResourceLimitsServerNotices class needs to be isolated is this test,
|
||||
# general code should never have a reason to do so ...
|
||||
self._rlsn = self.server_notices_sender._server_notices[1]
|
||||
if not isinstance(self._rlsn, ResourceLimitsServerNotices):
|
||||
raise Exception("Failed to find reference to ResourceLimitsServerNotices")
|
||||
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.hs.config.hs_disabled = False
|
||||
self.hs.config.max_mau_value = 5
|
||||
self.hs.config.server_notices_mxid = "@server:test"
|
||||
self.hs.config.server_notices_mxid_display_name = None
|
||||
self.hs.config.server_notices_mxid_avatar_url = None
|
||||
self.hs.config.server_notices_room_name = "Test Server Notice Room"
|
||||
|
||||
self.user_id = "@user_id:test"
|
||||
|
||||
self.hs.config.admin_uri = "mailto:user@test.com"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_server_notice_only_sent_once(self):
|
||||
self.store.get_monthly_active_count = Mock(
|
||||
return_value=1000,
|
||||
)
|
||||
|
||||
self.store.user_last_seen_monthly_active = Mock(
|
||||
return_value=1000,
|
||||
)
|
||||
|
||||
# Call the function multiple times to ensure we only send the notice once
|
||||
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
|
||||
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
|
||||
yield self._rlsn.maybe_send_server_notice_to_user(self.user_id)
|
||||
|
||||
# Now lets get the last load of messages in the service notice room and
|
||||
# check that there is only one server notice
|
||||
room_id = yield self.server_notices_manager.get_notice_room_for_user(
|
||||
self.user_id,
|
||||
)
|
||||
|
||||
token = yield self.event_source.get_current_token()
|
||||
events, _ = yield self.store.get_recent_events_for_room(
|
||||
room_id, limit=100, end_token=token.room_key,
|
||||
)
|
||||
|
||||
count = 0
|
||||
for event in events:
|
||||
if event.type != EventTypes.Message:
|
||||
continue
|
||||
if event.content.get("msgtype") != ServerNoticeMsgType:
|
||||
continue
|
||||
|
||||
count += 1
|
||||
|
||||
self.assertEqual(count, 1)
|
||||
@@ -34,19 +34,20 @@ class ProfileStoreTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_displayname(self):
|
||||
yield self.store.set_profile_displayname(
|
||||
self.u_frank.localpart, "Frank", 1,
|
||||
)
|
||||
yield self.store.create_profile(self.u_frank.localpart)
|
||||
|
||||
yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
|
||||
|
||||
self.assertEquals(
|
||||
"Frank",
|
||||
(yield self.store.get_profile_displayname(self.u_frank.localpart))
|
||||
"Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_avatar_url(self):
|
||||
yield self.store.create_profile(self.u_frank.localpart)
|
||||
|
||||
yield self.store.set_profile_avatar_url(
|
||||
self.u_frank.localpart, "http://my.site/here", 1,
|
||||
self.u_frank.localpart, "http://my.site/here"
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
|
||||
@@ -22,7 +22,7 @@ from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.types import RoomID, UserID
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import setup_test_homeserver
|
||||
from tests.utils import create_room, setup_test_homeserver
|
||||
|
||||
|
||||
class RedactionTestCase(unittest.TestCase):
|
||||
@@ -41,6 +41,8 @@ class RedactionTestCase(unittest.TestCase):
|
||||
|
||||
self.room1 = RoomID.from_string("!abc123:test")
|
||||
|
||||
yield create_room(hs, self.room1.to_string(), self.u_alice.to_string())
|
||||
|
||||
self.depth = 1
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
||||
@@ -46,6 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
||||
"consent_version": None,
|
||||
"consent_server_notice_sent": None,
|
||||
"appservice_id": None,
|
||||
"creation_ts": 1000,
|
||||
},
|
||||
(yield self.store.get_user_by_id(self.user_id)),
|
||||
)
|
||||
|
||||
@@ -22,7 +22,7 @@ from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.types import RoomID, UserID
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import setup_test_homeserver
|
||||
from tests.utils import create_room, setup_test_homeserver
|
||||
|
||||
|
||||
class RoomMemberStoreTestCase(unittest.TestCase):
|
||||
@@ -45,6 +45,8 @@ class RoomMemberStoreTestCase(unittest.TestCase):
|
||||
|
||||
self.room = RoomID.from_string("!abc123:test")
|
||||
|
||||
yield create_room(hs, self.room.to_string(), self.u_alice.to_string())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def inject_room_member(self, room, user, membership, replaces_state=None):
|
||||
builder = self.event_builder_factory.new(
|
||||
|
||||
@@ -185,6 +185,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||
|
||||
# test _get_some_state_from_cache correctly filters out members with types=[]
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
self.store._state_group_cache,
|
||||
group, [], filtered_types=[EventTypes.Member]
|
||||
)
|
||||
|
||||
@@ -197,8 +198,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||
state_dict,
|
||||
)
|
||||
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
self.store._state_group_members_cache,
|
||||
group, [], filtered_types=[EventTypes.Member]
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, True)
|
||||
self.assertDictEqual(
|
||||
{},
|
||||
state_dict,
|
||||
)
|
||||
|
||||
# test _get_some_state_from_cache correctly filters in members with wildcard types
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
self.store._state_group_cache,
|
||||
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
|
||||
)
|
||||
|
||||
@@ -207,6 +220,18 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||
{
|
||||
(e1.type, e1.state_key): e1.event_id,
|
||||
(e2.type, e2.state_key): e2.event_id,
|
||||
},
|
||||
state_dict,
|
||||
)
|
||||
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
self.store._state_group_members_cache,
|
||||
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, True)
|
||||
self.assertDictEqual(
|
||||
{
|
||||
(e3.type, e3.state_key): e3.event_id,
|
||||
# e4 is overwritten by e5
|
||||
(e5.type, e5.state_key): e5.event_id,
|
||||
@@ -216,6 +241,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||
|
||||
# test _get_some_state_from_cache correctly filters in members with specific types
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
self.store._state_group_cache,
|
||||
group,
|
||||
[(EventTypes.Member, e5.state_key)],
|
||||
filtered_types=[EventTypes.Member],
|
||||
@@ -226,6 +252,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||
{
|
||||
(e1.type, e1.state_key): e1.event_id,
|
||||
(e2.type, e2.state_key): e2.event_id,
|
||||
},
|
||||
state_dict,
|
||||
)
|
||||
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
self.store._state_group_members_cache,
|
||||
group,
|
||||
[(EventTypes.Member, e5.state_key)],
|
||||
filtered_types=[EventTypes.Member],
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, True)
|
||||
self.assertDictEqual(
|
||||
{
|
||||
(e5.type, e5.state_key): e5.event_id,
|
||||
},
|
||||
state_dict,
|
||||
@@ -234,6 +274,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||
# test _get_some_state_from_cache correctly filters in members with specific types
|
||||
# and no filtered_types
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
self.store._state_group_members_cache,
|
||||
group, [(EventTypes.Member, e5.state_key)], filtered_types=None
|
||||
)
|
||||
|
||||
@@ -254,9 +295,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||
{
|
||||
(e1.type, e1.state_key): e1.event_id,
|
||||
(e2.type, e2.state_key): e2.event_id,
|
||||
(e3.type, e3.state_key): e3.event_id,
|
||||
# e4 is overwritten by e5
|
||||
(e5.type, e5.state_key): e5.event_id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -269,8 +307,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||
# list fetched keys so it knows it's partial
|
||||
fetched_keys=(
|
||||
(e1.type, e1.state_key),
|
||||
(e3.type, e3.state_key),
|
||||
(e5.type, e5.state_key),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -284,8 +320,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||
set(
|
||||
[
|
||||
(e1.type, e1.state_key),
|
||||
(e3.type, e3.state_key),
|
||||
(e5.type, e5.state_key),
|
||||
]
|
||||
),
|
||||
)
|
||||
@@ -293,8 +327,6 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||
state_dict_ids,
|
||||
{
|
||||
(e1.type, e1.state_key): e1.event_id,
|
||||
(e3.type, e3.state_key): e3.event_id,
|
||||
(e5.type, e5.state_key): e5.event_id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -304,14 +336,25 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||
# test _get_some_state_from_cache correctly filters out members with types=[]
|
||||
room_id = self.room.to_string()
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
self.store._state_group_cache,
|
||||
group, [], filtered_types=[EventTypes.Member]
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, False)
|
||||
self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
|
||||
|
||||
room_id = self.room.to_string()
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
self.store._state_group_members_cache,
|
||||
group, [], filtered_types=[EventTypes.Member]
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, True)
|
||||
self.assertDictEqual({}, state_dict)
|
||||
|
||||
# test _get_some_state_from_cache correctly filters in members wildcard types
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
self.store._state_group_cache,
|
||||
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
|
||||
)
|
||||
|
||||
@@ -319,8 +362,19 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||
self.assertDictEqual(
|
||||
{
|
||||
(e1.type, e1.state_key): e1.event_id,
|
||||
},
|
||||
state_dict,
|
||||
)
|
||||
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
self.store._state_group_members_cache,
|
||||
group, [(EventTypes.Member, None)], filtered_types=[EventTypes.Member]
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, True)
|
||||
self.assertDictEqual(
|
||||
{
|
||||
(e3.type, e3.state_key): e3.event_id,
|
||||
# e4 is overwritten by e5
|
||||
(e5.type, e5.state_key): e5.event_id,
|
||||
},
|
||||
state_dict,
|
||||
@@ -328,6 +382,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||
|
||||
# test _get_some_state_from_cache correctly filters in members with specific types
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
self.store._state_group_cache,
|
||||
group,
|
||||
[(EventTypes.Member, e5.state_key)],
|
||||
filtered_types=[EventTypes.Member],
|
||||
@@ -337,6 +392,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||
self.assertDictEqual(
|
||||
{
|
||||
(e1.type, e1.state_key): e1.event_id,
|
||||
},
|
||||
state_dict,
|
||||
)
|
||||
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
self.store._state_group_members_cache,
|
||||
group,
|
||||
[(EventTypes.Member, e5.state_key)],
|
||||
filtered_types=[EventTypes.Member],
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, True)
|
||||
self.assertDictEqual(
|
||||
{
|
||||
(e5.type, e5.state_key): e5.event_id,
|
||||
},
|
||||
state_dict,
|
||||
@@ -345,8 +414,22 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||
# test _get_some_state_from_cache correctly filters in members with specific types
|
||||
# and no filtered_types
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
self.store._state_group_cache,
|
||||
group, [(EventTypes.Member, e5.state_key)], filtered_types=None
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, False)
|
||||
self.assertDictEqual({}, state_dict)
|
||||
|
||||
(state_dict, is_all) = yield self.store._get_some_state_from_cache(
|
||||
self.store._state_group_members_cache,
|
||||
group, [(EventTypes.Member, e5.state_key)], filtered_types=None
|
||||
)
|
||||
|
||||
self.assertEqual(is_all, True)
|
||||
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
|
||||
self.assertDictEqual(
|
||||
{
|
||||
(e5.type, e5.state_key): e5.event_id,
|
||||
},
|
||||
state_dict,
|
||||
)
|
||||
|
||||
217
tests/test_mau.py
Normal file
217
tests/test_mau.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests REST events for /rooms paths."""
|
||||
|
||||
import json
|
||||
|
||||
from mock import Mock, NonCallableMock
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.rest.client.v2_alpha import register, sync
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import (
|
||||
ThreadedMemoryReactorClock,
|
||||
make_request,
|
||||
render,
|
||||
setup_test_homeserver,
|
||||
)
|
||||
|
||||
|
||||
class TestMauLimit(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.reactor = ThreadedMemoryReactorClock()
|
||||
self.clock = Clock(self.reactor)
|
||||
|
||||
self.hs = setup_test_homeserver(
|
||||
self.addCleanup,
|
||||
"red",
|
||||
http_client=None,
|
||||
clock=self.clock,
|
||||
reactor=self.reactor,
|
||||
federation_client=Mock(),
|
||||
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
||||
)
|
||||
|
||||
self.store = self.hs.get_datastore()
|
||||
|
||||
self.hs.config.registrations_require_3pid = []
|
||||
self.hs.config.enable_registration_captcha = False
|
||||
self.hs.config.recaptcha_public_key = []
|
||||
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.hs.config.hs_disabled = False
|
||||
self.hs.config.max_mau_value = 2
|
||||
self.hs.config.mau_trial_days = 0
|
||||
self.hs.config.server_notices_mxid = "@server:red"
|
||||
self.hs.config.server_notices_mxid_display_name = None
|
||||
self.hs.config.server_notices_mxid_avatar_url = None
|
||||
self.hs.config.server_notices_room_name = "Test Server Notice Room"
|
||||
|
||||
self.resource = JsonResource(self.hs)
|
||||
register.register_servlets(self.hs, self.resource)
|
||||
sync.register_servlets(self.hs, self.resource)
|
||||
|
||||
def test_simple_deny_mau(self):
|
||||
# Create and sync so that the MAU counts get updated
|
||||
token1 = self.create_user("kermit1")
|
||||
self.do_sync_for_user(token1)
|
||||
token2 = self.create_user("kermit2")
|
||||
self.do_sync_for_user(token2)
|
||||
|
||||
# We've created and activated two users, we shouldn't be able to
|
||||
# register new users
|
||||
with self.assertRaises(SynapseError) as cm:
|
||||
self.create_user("kermit3")
|
||||
|
||||
e = cm.exception
|
||||
self.assertEqual(e.code, 403)
|
||||
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
|
||||
def test_allowed_after_a_month_mau(self):
|
||||
# Create and sync so that the MAU counts get updated
|
||||
token1 = self.create_user("kermit1")
|
||||
self.do_sync_for_user(token1)
|
||||
token2 = self.create_user("kermit2")
|
||||
self.do_sync_for_user(token2)
|
||||
|
||||
# Advance time by 31 days
|
||||
self.reactor.advance(31 * 24 * 60 * 60)
|
||||
|
||||
self.store.reap_monthly_active_users()
|
||||
|
||||
self.reactor.advance(0)
|
||||
|
||||
# We should be able to register more users
|
||||
token3 = self.create_user("kermit3")
|
||||
self.do_sync_for_user(token3)
|
||||
|
||||
def test_trial_delay(self):
|
||||
self.hs.config.mau_trial_days = 1
|
||||
|
||||
# We should be able to register more than the limit initially
|
||||
token1 = self.create_user("kermit1")
|
||||
self.do_sync_for_user(token1)
|
||||
token2 = self.create_user("kermit2")
|
||||
self.do_sync_for_user(token2)
|
||||
token3 = self.create_user("kermit3")
|
||||
self.do_sync_for_user(token3)
|
||||
|
||||
# Advance time by 2 days
|
||||
self.reactor.advance(2 * 24 * 60 * 60)
|
||||
|
||||
# Two users should be able to sync
|
||||
self.do_sync_for_user(token1)
|
||||
self.do_sync_for_user(token2)
|
||||
|
||||
# But the third should fail
|
||||
with self.assertRaises(SynapseError) as cm:
|
||||
self.do_sync_for_user(token3)
|
||||
|
||||
e = cm.exception
|
||||
self.assertEqual(e.code, 403)
|
||||
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
|
||||
# And new registrations are now denied too
|
||||
with self.assertRaises(SynapseError) as cm:
|
||||
self.create_user("kermit4")
|
||||
|
||||
e = cm.exception
|
||||
self.assertEqual(e.code, 403)
|
||||
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
|
||||
def test_trial_users_cant_come_back(self):
|
||||
self.hs.config.mau_trial_days = 1
|
||||
|
||||
# We should be able to register more than the limit initially
|
||||
token1 = self.create_user("kermit1")
|
||||
self.do_sync_for_user(token1)
|
||||
token2 = self.create_user("kermit2")
|
||||
self.do_sync_for_user(token2)
|
||||
token3 = self.create_user("kermit3")
|
||||
self.do_sync_for_user(token3)
|
||||
|
||||
# Advance time by 2 days
|
||||
self.reactor.advance(2 * 24 * 60 * 60)
|
||||
|
||||
# Two users should be able to sync
|
||||
self.do_sync_for_user(token1)
|
||||
self.do_sync_for_user(token2)
|
||||
|
||||
# Advance by 2 months so everyone falls out of MAU
|
||||
self.reactor.advance(60 * 24 * 60 * 60)
|
||||
self.store.reap_monthly_active_users()
|
||||
self.reactor.advance(0)
|
||||
|
||||
# We can create as many new users as we want
|
||||
token4 = self.create_user("kermit4")
|
||||
self.do_sync_for_user(token4)
|
||||
token5 = self.create_user("kermit5")
|
||||
self.do_sync_for_user(token5)
|
||||
token6 = self.create_user("kermit6")
|
||||
self.do_sync_for_user(token6)
|
||||
|
||||
# users 2 and 3 can come back to bring us back up to MAU limit
|
||||
self.do_sync_for_user(token2)
|
||||
self.do_sync_for_user(token3)
|
||||
|
||||
# New trial users can still sync
|
||||
self.do_sync_for_user(token4)
|
||||
self.do_sync_for_user(token5)
|
||||
self.do_sync_for_user(token6)
|
||||
|
||||
# But old user cant
|
||||
with self.assertRaises(SynapseError) as cm:
|
||||
self.do_sync_for_user(token1)
|
||||
|
||||
e = cm.exception
|
||||
self.assertEqual(e.code, 403)
|
||||
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
|
||||
def create_user(self, localpart):
|
||||
request_data = json.dumps({
|
||||
"username": localpart,
|
||||
"password": "monkey",
|
||||
"auth": {"type": LoginType.DUMMY},
|
||||
})
|
||||
|
||||
request, channel = make_request(b"POST", b"/register", request_data)
|
||||
render(request, self.resource, self.reactor)
|
||||
|
||||
if channel.result["code"] != b"200":
|
||||
raise HttpResponseException(
|
||||
int(channel.result["code"]),
|
||||
channel.result["reason"],
|
||||
channel.result["body"],
|
||||
).to_synapse_error()
|
||||
|
||||
access_token = channel.json_body["access_token"]
|
||||
|
||||
return access_token
|
||||
|
||||
def do_sync_for_user(self, token):
|
||||
request, channel = make_request(b"GET", b"/sync", access_token=token)
|
||||
render(request, self.resource, self.reactor)
|
||||
|
||||
if channel.result["code"] != b"200":
|
||||
raise HttpResponseException(
|
||||
int(channel.result["code"]),
|
||||
channel.result["reason"],
|
||||
channel.result["body"],
|
||||
).to_synapse_error()
|
||||
@@ -18,7 +18,7 @@ from mock import Mock
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.constants import EventTypes, Membership, RoomVersions
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.state import StateHandler, StateResolutionHandler
|
||||
|
||||
@@ -117,6 +117,9 @@ class StateGroupStore(object):
|
||||
def register_event_id_state_group(self, event_id, state_group):
|
||||
self._event_to_state_group[event_id] = state_group
|
||||
|
||||
def get_room_version(self, room_id):
|
||||
return RoomVersions.V1
|
||||
|
||||
|
||||
class DictObj(dict):
|
||||
def __init__(self, **kwargs):
|
||||
@@ -176,7 +179,9 @@ class StateTestCase(unittest.TestCase):
|
||||
def test_branch_no_conflict(self):
|
||||
graph = Graph(
|
||||
nodes={
|
||||
"START": DictObj(type=EventTypes.Create, state_key="", depth=1),
|
||||
"START": DictObj(
|
||||
type=EventTypes.Create, state_key="", content={}, depth=1,
|
||||
),
|
||||
"A": DictObj(type=EventTypes.Message, depth=2),
|
||||
"B": DictObj(type=EventTypes.Message, depth=3),
|
||||
"C": DictObj(type=EventTypes.Name, state_key="", depth=3),
|
||||
|
||||
@@ -21,7 +21,7 @@ from synapse.events import FrozenEvent
|
||||
from synapse.visibility import filter_events_for_server
|
||||
|
||||
import tests.unittest
|
||||
from tests.utils import setup_test_homeserver
|
||||
from tests.utils import create_room, setup_test_homeserver
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,6 +36,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
||||
self.event_builder_factory = self.hs.get_event_builder_factory()
|
||||
self.store = self.hs.get_datastore()
|
||||
|
||||
yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_filtering(self):
|
||||
#
|
||||
|
||||
@@ -24,6 +24,7 @@ from six.moves.urllib import parse as urlparse
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import CodeMessageException, cs_error
|
||||
from synapse.federation.transport import server
|
||||
from synapse.http.server import HttpServer
|
||||
@@ -131,7 +132,6 @@ def setup_test_homeserver(
|
||||
config.federation_rc_concurrent = 10
|
||||
config.filter_timeline_limit = 5000
|
||||
config.user_directory_search_all_users = False
|
||||
config.replicate_user_profiles_to = []
|
||||
config.user_consent_server_notice_content = None
|
||||
config.block_events_without_consent_error = None
|
||||
config.media_storage_providers = []
|
||||
@@ -540,3 +540,32 @@ class DeferredMockCallable(object):
|
||||
"Expected not to received any calls, got:\n"
|
||||
+ "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
|
||||
)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_room(hs, room_id, creator_id):
|
||||
"""Creates and persist a creation event for the given room
|
||||
|
||||
Args:
|
||||
hs
|
||||
room_id (str)
|
||||
creator_id (str)
|
||||
"""
|
||||
|
||||
store = hs.get_datastore()
|
||||
event_builder_factory = hs.get_event_builder_factory()
|
||||
event_creation_handler = hs.get_event_creation_handler()
|
||||
|
||||
builder = event_builder_factory.new({
|
||||
"type": EventTypes.Create,
|
||||
"state_key": "",
|
||||
"sender": creator_id,
|
||||
"room_id": room_id,
|
||||
"content": {},
|
||||
})
|
||||
|
||||
event, context = yield event_creation_handler.create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
yield store.persist_event(event, context)
|
||||
|
||||
Reference in New Issue
Block a user