1
0

Compare commits

..

1 Commits

Author SHA1 Message Date
Erik Johnston ce7051df61 When daemonizing, restart synapse process if it dies 2015-09-29 11:30:19 +01:00
130 changed files with 1946 additions and 6532 deletions
+1 -4
View File
@@ -44,7 +44,4 @@ Eric Myhre <hash at exultant.us>
repository API.
Muthu Subramanian <muthu.subramanian.karunanidhi at ericsson.com>
* Add SAML2 support for registration and login.
Steven Hammerton <steven.hammerton at openmarket.com>
* Add CAS support for registration and login.
* Add SAML2 support for registration and logins.
-46
View File
@@ -1,49 +1,3 @@
Changes in synapse v0.11.0 (2015-11-17)
=======================================
* Change CAS login API (PR #349)
Changes in synapse v0.11.0-rc2 (2015-11-13)
===========================================
* Various changes to /sync API response format (PR #373)
* Fix regression when setting display name in newly joined room over
federation (PR #368)
* Fix problem where /search was slow when using SQLite (PR #366)
Changes in synapse v0.11.0-rc1 (2015-11-11)
===========================================
* Add Search API (PR #307, #324, #327, #336, #350, #359)
* Add 'archived' state to v2 /sync API (PR #316)
* Add ability to reject invites (PR #317)
* Add config option to disable password login (PR #322)
* Add the login fallback API (PR #330)
* Add room context API (PR #334)
* Add room tagging support (PR #335)
* Update v2 /sync API to match spec (PR #305, #316, #321, #332, #337, #341)
* Change retry schedule for application services (PR #320)
* Change retry schedule for remote servers (PR #340)
* Fix bug where we hosted static content in the incorrect place (PR #329)
* Fix bug where we didn't increment retry interval for remote servers (PR #343)
Changes in synapse v0.10.1-rc1 (2015-10-15)
===========================================
* Add support for CAS, thanks to Steven Hammerton (PR #295, #296)
* Add support for using macaroons for ``access_token`` (PR #256, #229)
* Add support for ``m.room.canonical_alias`` (PR #287)
* Add support for viewing the history of rooms that they have left. (PR #276,
#294)
* Add support for refresh tokens (PR #240)
* Add flag on creation which disables federation of the room (PR #279)
* Add some room state to invites. (PR #275)
* Atomically persist events when joining a room over federation (PR #283)
* Change default history visibility for private rooms (PR #271)
* Allow users to redact their own sent events (PR #262)
* Use tox for tests (PR #247)
* Split up syutil into separate libraries (PR #243)
Changes in synapse v0.10.0-r2 (2015-09-16)
==========================================
+3 -6
View File
@@ -15,11 +15,8 @@ recursive-include scripts *
recursive-include scripts-dev *
recursive-include tests *.py
recursive-include synapse/static *.css
recursive-include synapse/static *.gif
recursive-include synapse/static *.html
recursive-include synapse/static *.js
exclude jenkins.sh
recursive-include static *.css
recursive-include static *.html
recursive-include static *.js
prune demo/etc
+7 -7
View File
@@ -20,8 +20,8 @@ The overall architecture is::
https://somewhere.org/_matrix https://elsewhere.net/_matrix
``#matrix:matrix.org`` is the official support room for Matrix, and can be
accessed by any client from https://matrix.org/blog/try-matrix-now or via IRC
bridge at irc://irc.freenode.net/matrix.
accessed by the web client at http://matrix.org/beta or via an IRC bridge at
irc://irc.freenode.net/matrix.
Synapse is currently in rapid development, but as of version 0.5 we believe it
is sufficiently stable to be run as an internet-facing service for real usage!
@@ -77,14 +77,14 @@ Meanwhile, iOS and Android SDKs and clients are available from:
- https://github.com/matrix-org/matrix-android-sdk
We'd like to invite you to join #matrix:matrix.org (via
https://matrix.org/blog/try-matrix-now), run a homeserver, take a look at the
Matrix spec at https://matrix.org/docs/spec and API docs at
https://matrix.org/docs/api, experiment with the APIs and the demo clients, and
report any bugs via https://matrix.org/jira.
https://matrix.org/beta), run a homeserver, take a look at the Matrix spec at
https://matrix.org/docs/spec and API docs at https://matrix.org/docs/api,
experiment with the APIs and the demo clients, and report any bugs via
https://matrix.org/jira.
Thanks for using Matrix!
[1] End-to-end encryption is currently in development - see https://matrix.org/git/olm
[1] End-to-end encryption is currently in development
Synapse Installation
====================
-7
View File
@@ -38,13 +38,6 @@ for port in 8080 8081 8082; do
perl -p -i -e 's/^enable_registration:.*/enable_registration: true/g' $DIR/etc/$port.config
if ! grep -F "full_twisted_stacktraces" -q $DIR/etc/$port.config; then
echo "full_twisted_stacktraces: true" >> $DIR/etc/$port.config
fi
if ! grep -F "report_stats" -q $DIR/etc/$port.config ; then
echo "report_stats: false" >> $DIR/etc/$port.config
fi
python -m synapse.app.homeserver \
--config-path "$DIR/etc/$port.config" \
-D \
-39
View File
@@ -1,39 +0,0 @@
#!/bin/bash -eu
export PYTHONDONTWRITEBYTECODE=yep
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Output coverage to coverage.xml
export DUMP_COVERAGE_COMMAND="coverage xml -o coverage.xml"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
tox
: ${GIT_BRANCH:="$(git rev-parse --abbrev-ref HEAD)"}
set +u
. .tox/py27/bin/activate
set -u
rm -rf sytest
git clone https://github.com/matrix-org/sytest.git sytest
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PERL5LIB:=$WORKSPACE/perl5/lib/perl5}
: ${PERL_MB_OPT:=--install_base=$WORKSPACE/perl5}
: ${PERL_MM_OPT:=INSTALL_BASE=$WORKSPACE/perl5}
export PERL5LIB PERL_MB_OPT PERL_MM_OPT
./install-deps.pl
./run-tests.pl -O tap --synapse-directory .. --all > results.tap
+1 -1
View File
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.11.0"
__version__ = "0.10.0-r2"
+52 -193
View File
@@ -14,17 +14,13 @@
# limitations under the License.
"""This module contains classes for authenticating the user."""
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import RoomID, UserID, EventID
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.util.logutils import log_function
from unpaddedbase64 import decode_base64
from synapse.types import UserID, EventID
import logging
import pymacaroons
@@ -35,7 +31,6 @@ logger = logging.getLogger(__name__)
AuthEventTypes = (
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
EventTypes.ThirdPartyInvite,
)
@@ -48,7 +43,6 @@ class Auth(object):
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
self._KNOWN_CAVEAT_PREFIXES = set([
"gen = ",
"guest = ",
"type = ",
"time < ",
"user_id = ",
@@ -65,8 +59,6 @@ class Auth(object):
Returns:
True if the auth checks pass.
"""
self.check_size_limits(event)
try:
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
@@ -88,15 +80,6 @@ class Auth(object):
"Room %r does not exist" % (event.room_id,)
)
creating_domain = RoomID.from_string(event.room_id).domain
originating_domain = UserID.from_string(event.sender).domain
if creating_domain != originating_domain:
if not self.can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# FIXME: Temp hack
if event.type == EventTypes.Aliases:
return True
@@ -134,23 +117,6 @@ class Auth(object):
logger.info("Denying! %s", event)
raise
def check_size_limits(self, event):
def too_big(field):
raise EventSizeError("%s too large" % (field,))
if len(event.user_id) > 255:
too_big("user_id")
if len(event.room_id) > 255:
too_big("room_id")
if event.is_state() and len(event.state_key) > 255:
too_big("state_key")
if len(event.type) > 255:
too_big("type")
if len(event.event_id) > 255:
too_big("event_id")
if len(encode_canonical_json(event.get_pdu_json())) > 65536:
too_big("event")
@defer.inlineCallbacks
def check_joined_room(self, room_id, user_id, current_state=None):
"""Check if the user is currently joined in the room
@@ -183,11 +149,15 @@ class Auth(object):
defer.returnValue(member)
@defer.inlineCallbacks
def check_user_was_in_room(self, room_id, user_id):
def check_user_was_in_room(self, room_id, user_id, current_state=None):
"""Check if the user was in the room at some point.
Args:
room_id(str): The room to check.
user_id(str): The user to check.
current_state(dict): Optional map of the current state of the room.
If provided then that map is used to check whether they are a
member of the room. Otherwise the current membership is
loaded from the database.
Raises:
AuthError if the user was never in the room.
Returns:
@@ -195,11 +165,17 @@ class Auth(object):
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
"""
member = yield self.state.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
if current_state:
member = current_state.get(
(EventTypes.Member, user_id),
None
)
else:
member = yield self.state.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
membership = member.membership if member else None
if membership not in (Membership.JOIN, Membership.LEAVE):
@@ -243,11 +219,6 @@ class Auth(object):
user_id, room_id, repr(member)
))
def can_federate(self, event, auth_events):
creation_event = auth_events.get((EventTypes.Create, ""))
return creation_event.content.get("m.federate", True) is True
@log_function
def is_membership_change_allowed(self, event, auth_events):
membership = event.content["membership"]
@@ -263,15 +234,6 @@ class Auth(object):
target_user_id = event.state_key
creating_domain = RoomID.from_string(event.room_id).domain
target_domain = UserID.from_string(target_user_id).domain
if creating_domain != target_domain:
if not self.can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# get info about the caller
key = (EventTypes.Member, event.user_id, )
caller = auth_events.get(key)
@@ -317,17 +279,8 @@ class Auth(object):
}
)
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not self._verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
return True
if Membership.JOIN != membership:
if (caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id):
return True
# JOIN is the only action you can perform if you're not in the room
if not caller_in_room: # caller isn't joined
raise AuthError(
403,
@@ -391,66 +344,6 @@ class Auth(object):
return True
def _verify_third_party_invite(self, event, auth_events):
"""
Validates that the invite event is authorized by a previous third-party invite.
Checks that the public key, and keyserver, match those in the third party invite,
and that the invite event has a signature issued using that public key.
Args:
event: The m.room.member join event being validated.
auth_events: All relevant previous context events which may be used
for authorization decisions.
Return:
True if the event fulfills the expectations of a previous third party
invite event.
"""
if "third_party_invite" not in event.content:
return False
if "signed" not in event.content["third_party_invite"]:
return False
signed = event.content["third_party_invite"]["signed"]
for key in {"mxid", "token"}:
if key not in signed:
return False
token = signed["token"]
invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,)
)
if not invite_event:
return False
if event.user_id != invite_event.user_id:
return False
try:
public_key = invite_event.content["public_key"]
if signed["mxid"] != event.state_key:
return False
if signed["token"] != token:
return False
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"):
return False
verify_key = decode_verify_key_bytes(
key_name,
decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
# We got the public key from the invite, so we know that the
# correct server signed the signed bundle.
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True
return False
except (KeyError, SignatureVerifyException,):
return False
def _get_power_level_event(self, auth_events):
key = (EventTypes.PowerLevels, "", )
return auth_events.get(key)
@@ -489,7 +382,7 @@ class Auth(object):
return default
@defer.inlineCallbacks
def get_user_by_req(self, request, allow_guest=False):
def get_user_by_req(self, request):
""" Get a registered user's ID.
Args:
@@ -527,7 +420,7 @@ class Auth(object):
request.authenticated_entity = user_id
defer.returnValue((UserID.from_string(user_id), "", False))
defer.returnValue((UserID.from_string(user_id), ""))
return
except KeyError:
pass # normal users won't have the user_id query parameter set.
@@ -535,7 +428,6 @@ class Auth(object):
user_info = yield self._get_user_by_access_token(access_token)
user = user_info["user"]
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
@@ -550,14 +442,9 @@ class Auth(object):
user_agent=user_agent
)
if is_guest and not allow_guest:
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
)
request.authenticated_entity = user.to_string()
defer.returnValue((user, token_id, is_guest,))
defer.returnValue((user, token_id,))
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@@ -587,76 +474,59 @@ class Auth(object):
def _get_user_from_macaroon(self, macaroon_str):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self.validate_macaroon(
macaroon, "access",
[lambda c: c.startswith("time < ")]
)
self._validate_macaroon(macaroon)
user_prefix = "user_id = "
user = None
guest = False
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
elif caveat.caveat_id == "guest = true":
guest = True
if user is None:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
if guest:
ret = {
"user": user,
"is_guest": True,
"token_id": None,
}
else:
# This codepath exists so that we can actually return a
# token ID, because we use token IDs in place of device
# identifiers throughout the codebase.
# TODO(daniel): Remove this fallback when device IDs are
# properly implemented.
ret = yield self._look_up_user_by_access_token(macaroon_str)
if ret["user"] != user:
logger.error(
"Macaroon user (%s) != DB user (%s)",
user,
ret["user"]
)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"User mismatch in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
defer.returnValue(ret)
# This codepath exists so that we can actually return a
# token ID, because we use token IDs in place of device
# identifiers throughout the codebase.
# TODO(daniel): Remove this fallback when device IDs are
# properly implemented.
ret = yield self._look_up_user_by_access_token(macaroon_str)
if ret["user"] != user:
logger.error(
"Macaroon user (%s) != DB user (%s)",
user,
ret["user"]
)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"User mismatch in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
defer.returnValue(ret)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
errcode=Codes.UNKNOWN_TOKEN
)
def validate_macaroon(self, macaroon, type_string, additional_validation_functions):
def _validate_macaroon(self, macaroon):
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = " + type_string)
v.satisfy_exact("type = access")
v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_exact("guest = true")
for validation_function in additional_validation_functions:
v.satisfy_general(validation_function)
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
v = pymacaroons.Verifier()
v.satisfy_general(self._verify_recognizes_caveats)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
def verify_expiry(self, caveat):
def _verify_expiry(self, caveat):
prefix = "time < "
if not caveat.startswith(prefix):
return False
# TODO(daniel): Enable expiry check when clients actually know how to
# refresh tokens. (And remember to enable the tests)
return True
expiry = int(caveat[len(prefix):])
now = self.hs.get_clock().time_msec()
return now < expiry
@@ -681,7 +551,6 @@ class Auth(object):
user_info = {
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
"is_guest": False,
}
defer.returnValue(user_info)
@@ -757,16 +626,6 @@ class Auth(object):
else:
if member_event:
auth_ids.append(member_event.event_id)
if e_type == Membership.INVITE:
if "third_party_invite" in event.content:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["token"]
)
third_party_invite = current_state.get(key)
if third_party_invite:
auth_ids.append(third_party_invite.event_id)
elif member_event:
if member_event.content["membership"] == Membership.JOIN:
auth_ids.append(member_event.event_id)
-3
View File
@@ -63,12 +63,10 @@ class EventTypes(object):
PowerLevels = "m.room.power_levels"
Aliases = "m.room.aliases"
Redaction = "m.room.redaction"
ThirdPartyInvite = "m.room.third_party_invite"
RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias"
RoomAvatar = "m.room.avatar"
GuestAccess = "m.room.guest_access"
# These are used for validation
Message = "m.room.message"
@@ -85,4 +83,3 @@ class RejectedReason(object):
class RoomCreationPreset(object):
PRIVATE_CHAT = "private_chat"
PUBLIC_CHAT = "public_chat"
TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
+1 -10
View File
@@ -33,7 +33,6 @@ class Codes(object):
NOT_FOUND = "M_NOT_FOUND"
MISSING_TOKEN = "M_MISSING_TOKEN"
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
GUEST_ACCESS_FORBIDDEN = "M_GUEST_ACCESS_FORBIDDEN"
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
@@ -48,6 +47,7 @@ class CodeMessageException(RuntimeError):
"""An exception with integer code and message string attributes."""
def __init__(self, code, msg):
logger.info("%s: %s, %s", type(self).__name__, code, msg)
super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
self.code = code
self.msg = msg
@@ -120,15 +120,6 @@ class AuthError(SynapseError):
super(AuthError, self).__init__(*args, **kwargs)
class EventSizeError(SynapseError):
"""An error raised when an event is too big."""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.TOO_LARGE
super(EventSizeError, self).__init__(413, *args, **kwargs)
class EventStreamError(SynapseError):
"""An error raised when there a problem with the event stream."""
def __init__(self, *args, **kwargs):
+98 -116
View File
@@ -24,7 +24,7 @@ class Filtering(object):
def get_user_filter(self, user_localpart, filter_id):
result = self.store.get_user_filter(user_localpart, filter_id)
result.addCallback(FilterCollection)
result.addCallback(Filter)
return result
def add_user_filter(self, user_localpart, user_filter):
@@ -50,11 +50,11 @@ class Filtering(object):
# many definitions.
top_level_definitions = [
"presence"
"public_user_data", "private_user_data", "server_data"
]
room_level_definitions = [
"state", "timeline", "ephemeral", "private_user_data"
"state", "events", "ephemeral"
]
for key in top_level_definitions:
@@ -114,134 +114,116 @@ class Filtering(object):
if not isinstance(event_type, basestring):
raise SynapseError(400, "Event type should be a string")
if "format" in definition:
event_format = definition["format"]
if event_format not in ["federation", "events"]:
raise SynapseError(400, "Invalid format: %s" % (event_format,))
class FilterCollection(object):
def __init__(self, filter_json):
self.filter_json = filter_json
if "select" in definition:
event_select_list = definition["select"]
for select_key in event_select_list:
if select_key not in ["event_id", "origin_server_ts",
"thread_id", "content", "content.body"]:
raise SynapseError(400, "Bad select: %s" % (select_key,))
self.room_timeline_filter = Filter(
self.filter_json.get("room", {}).get("timeline", {})
)
self.room_state_filter = Filter(
self.filter_json.get("room", {}).get("state", {})
)
self.room_ephemeral_filter = Filter(
self.filter_json.get("room", {}).get("ephemeral", {})
)
self.room_private_user_data = Filter(
self.filter_json.get("room", {}).get("private_user_data", {})
)
self.presence_filter = Filter(
self.filter_json.get("presence", {})
)
def timeline_limit(self):
return self.room_timeline_filter.limit()
def presence_limit(self):
return self.presence_filter.limit()
def ephemeral_limit(self):
return self.room_ephemeral_filter.limit()
def filter_presence(self, events):
return self.presence_filter.filter(events)
def filter_room_state(self, events):
return self.room_state_filter.filter(events)
def filter_room_timeline(self, events):
return self.room_timeline_filter.filter(events)
def filter_room_ephemeral(self, events):
return self.room_ephemeral_filter.filter(events)
def filter_room_private_user_data(self, events):
return self.room_private_user_data.filter(events)
if ("bundle_updates" in definition and
type(definition["bundle_updates"]) != bool):
raise SynapseError(400, "Bad bundle_updates: expected bool.")
class Filter(object):
def __init__(self, filter_json):
self.filter_json = filter_json
def check(self, event):
"""Checks whether the filter matches the given event.
def filter_public_user_data(self, events):
return self._filter_on_key(events, ["public_user_data"])
def filter_private_user_data(self, events):
return self._filter_on_key(events, ["private_user_data"])
def filter_room_state(self, events):
return self._filter_on_key(events, ["room", "state"])
def filter_room_events(self, events):
return self._filter_on_key(events, ["room", "events"])
def filter_room_ephemeral(self, events):
return self._filter_on_key(events, ["room", "ephemeral"])
def _filter_on_key(self, events, keys):
filter_json = self.filter_json
if not filter_json:
return events
try:
# extract the right definition from the filter
definition = filter_json
for key in keys:
definition = definition[key]
return self._filter_with_definition(events, definition)
except KeyError:
# return all events if definition isn't specified.
return events
def _filter_with_definition(self, events, definition):
return [e for e in events if self._passes_definition(definition, e)]
def _passes_definition(self, definition, event):
"""Check if the event passes through the given definition.
Args:
definition(dict): The definition to check against.
event(Event): The event to check.
Returns:
bool: True if the event matches
True if the event passes through the filter.
"""
if isinstance(event, dict):
return self.check_fields(
event.get("room_id", None),
event.get("sender", None),
event.get("type", None),
)
else:
return self.check_fields(
getattr(event, "room_id", None),
getattr(event, "sender", None),
event.type,
)
# Algorithm notes:
# For each key in the definition, check the event meets the criteria:
# * For types: Literal match or prefix match (if ends with wildcard)
# * For senders/rooms: Literal match only
# * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
# and 'not_types' then it is treated as only being in 'not_types')
def check_fields(self, room_id, sender, event_type):
"""Checks whether the filter matches the given event fields.
Returns:
bool: True if the event fields match
"""
literal_keys = {
"rooms": lambda v: room_id == v,
"senders": lambda v: sender == v,
"types": lambda v: _matches_wildcard(event_type, v)
}
for name, match_func in literal_keys.items():
not_name = "not_%s" % (name,)
disallowed_values = self.filter_json.get(not_name, [])
if any(map(match_func, disallowed_values)):
# room checks
if hasattr(event, "room_id"):
room_id = event.room_id
allow_rooms = definition.get("rooms", None)
reject_rooms = definition.get("not_rooms", None)
if reject_rooms and room_id in reject_rooms:
return False
if allow_rooms and room_id not in allow_rooms:
return False
allowed_values = self.filter_json.get(name, None)
if allowed_values is not None:
if not any(map(match_func, allowed_values)):
# sender checks
if hasattr(event, "sender"):
# Should we be including event.state_key for some event types?
sender = event.sender
allow_senders = definition.get("senders", None)
reject_senders = definition.get("not_senders", None)
if reject_senders and sender in reject_senders:
return False
if allow_senders and sender not in allow_senders:
return False
# type checks
if "not_types" in definition:
for def_type in definition["not_types"]:
if self._event_matches_type(event, def_type):
return False
if "types" in definition:
included = False
for def_type in definition["types"]:
if self._event_matches_type(event, def_type):
included = True
break
if not included:
return False
return True
def filter_rooms(self, room_ids):
"""Apply the 'rooms' filter to a given list of rooms.
Args:
room_ids (list): A list of room_ids.
Returns:
list: A list of room_ids that match the filter
"""
room_ids = set(room_ids)
disallowed_rooms = set(self.filter_json.get("not_rooms", []))
room_ids -= disallowed_rooms
allowed_rooms = self.filter_json.get("rooms", None)
if allowed_rooms is not None:
room_ids &= set(allowed_rooms)
return room_ids
def filter(self, events):
return filter(self.check, events)
def limit(self):
return self.filter_json.get("limit", 10)
def _matches_wildcard(actual_value, filter_value):
if filter_value.endswith("*"):
type_prefix = filter_value[:-1]
return actual_value.startswith(type_prefix)
else:
return actual_value == filter_value
def _event_matches_type(self, event, def_type):
if def_type.endswith("*"):
type_prefix = def_type[:-1]
return event.type.startswith(type_prefix)
else:
return event.type == def_type
+98 -55
View File
@@ -33,9 +33,11 @@ if __name__ == '__main__':
sys.stderr.writelines(message)
sys.exit(1)
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import are_all_users_on_domain
from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.storage import (
are_all_users_on_domain, UpgradeDatabaseException,
)
from synapse.server import HomeServer
@@ -69,6 +71,8 @@ from synapse import events
from daemonize import Daemonize
import twisted.manhole.telnet
from multiprocessing import Process
import synapse
import contextlib
@@ -76,6 +80,7 @@ import logging
import os
import re
import resource
import signal
import subprocess
import time
@@ -132,9 +137,7 @@ class SynapseHomeServer(HomeServer):
def build_resource_for_static_content(self):
# This is old and should go away: not going to bother adding gzip
return File(
os.path.join(os.path.dirname(synapse.__file__), "static")
)
return File("static")
def build_resource_for_content_repo(self):
return ContentRepoResource(
@@ -368,15 +371,16 @@ def change_resource_limit(soft_file_no):
logger.warn("Failed to set file limit: %s", e)
def setup(config_options):
def load_config(config_options):
"""
Args:
config_options_options: The options passed to Synapse. Usually
`sys.argv[1:]`.
Returns:
HomeServer
HomeServerConfig
"""
config = HomeServerConfig.load_config(
"Synapse Homeserver",
config_options,
@@ -385,9 +389,17 @@ def setup(config_options):
config.setup_logging()
# check any extra requirements we have now we have a config
check_requirements(config)
return config
def setup(config):
"""
Args:
config (Homeserver)
Returns:
HomeServer
"""
version_string = get_version_string()
logger.info("Server hostname: %s", config.server_name)
@@ -439,9 +451,44 @@ def setup(config_options):
hs.get_pusherpool().start()
hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling()
hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache()
start_time = time.time()
@defer.inlineCallbacks
def phone_stats_home():
now = int(time.time())
uptime = int(now - start_time)
if uptime < 0:
uptime = 0
stats = {}
stats["homeserver"] = config.server_name
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
stats["total_users"] = yield hs.get_datastore().count_all_users()
all_rooms = yield hs.get_datastore().get_rooms(False)
stats["total_room_count"] = len(all_rooms)
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
daily_messages = yield hs.get_datastore().count_daily_messages()
if daily_messages is not None:
stats["daily_messages"] = daily_messages
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:
yield hs.get_simple_http_client().put_json(
"https://matrix.org/report-usage-stats/push",
stats
)
except Exception as e:
logger.warn("Error reporting stats: %s", e)
if hs.config.report_stats:
phone_home_task = task.LoopingCall(phone_stats_home)
phone_home_task.start(60 * 60 * 24, now=False)
return hs
@@ -650,7 +697,7 @@ def _resource_id(resource, path_seg):
return "%s-%s" % (resource, path_seg)
def run(hs):
def run(config):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:
def profile(func):
@@ -664,7 +711,7 @@ def run(hs):
profile.disable()
ident = current_thread().ident
profile.dump_stats("/tmp/%s.%s.%i.pstat" % (
hs.hostname, func.__name__, ident
config.server_name, func.__name__, ident
))
return profiled
@@ -673,56 +720,52 @@ def run(hs):
ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run)
start_time = hs.get_clock().time()
@defer.inlineCallbacks
def phone_stats_home():
now = int(hs.get_clock().time())
uptime = int(now - start_time)
if uptime < 0:
uptime = 0
stats = {}
stats["homeserver"] = hs.config.server_name
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
stats["total_users"] = yield hs.get_datastore().count_all_users()
all_rooms = yield hs.get_datastore().get_rooms(False)
stats["total_room_count"] = len(all_rooms)
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
daily_messages = yield hs.get_datastore().count_daily_messages()
if daily_messages is not None:
stats["daily_messages"] = daily_messages
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:
yield hs.get_simple_http_client().put_json(
"https://matrix.org/report-usage-stats/push",
stats
)
except Exception as e:
logger.warn("Error reporting stats: %s", e)
if hs.config.report_stats:
phone_home_task = task.LoopingCall(phone_stats_home)
phone_home_task.start(60 * 60 * 24, now=False)
def in_thread():
hs = setup(config)
with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit)
reactor.run()
if hs.config.daemonize:
def start_in_process_checker():
p = None
should_restart = [True]
if hs.config.print_pidfile:
print hs.config.pid_file
def proxy_signal(signum, stack):
logger.info("Got signal: %r", signum)
if p is not None:
os.kill(p.pid, signum)
if signum == signal.SIGTERM:
should_restart[0] = False
if getattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, proxy_signal)
signal.signal(signal.SIGTERM, proxy_signal)
last_start = 0
next_delay = 1
while should_restart[0]:
last_start = time.time()
p = Process(target=in_thread, args=())
p.start()
p.join()
if time.time() - last_start < 120:
next_delay = min(next_delay * 5, 5 * 60)
else:
next_delay = 1
time.sleep(next_delay)
if config.daemonize:
if config.print_pidfile:
print config.pid_file
daemon = Daemonize(
app="synapse-homeserver",
pid=hs.config.pid_file,
action=lambda: in_thread(),
pid=config.pid_file,
action=lambda: start_in_process_checker(),
auto_close_fds=False,
verbose=True,
logger=logger,
@@ -737,8 +780,8 @@ def main():
with LoggingContext("main"):
# check base requirements
check_requirements()
hs = setup(sys.argv[1:])
run(hs)
config = load_config(sys.argv[1:])
run(config)
if __name__ == '__main__':
+2 -2
View File
@@ -32,9 +32,9 @@ def start(configfile):
print "Starting ...",
args = SYNAPSE
args.extend(["--daemonize", "-c", configfile])
cwd = os.path.dirname(os.path.abspath(__file__))
try:
subprocess.check_call(args)
subprocess.check_call(args, cwd=cwd)
print GREEN + "started" + NORMAL
except subprocess.CalledProcessError as e:
print (
+2 -2
View File
@@ -224,8 +224,8 @@ class _Recoverer(object):
self.clock.call_later((2 ** self.backoff_counter), self.retry)
def _backoff(self):
# cap the backoff to be around 8.5min => (2^9) = 512 secs
if self.backoff_counter < 9:
# cap the backoff to be around 18h => (2^16) = 65536 secs
if self.backoff_counter < 16:
self.backoff_counter += 1
self.recover()
+1 -5
View File
@@ -14,7 +14,6 @@
# limitations under the License.
import argparse
import errno
import os
import yaml
import sys
@@ -92,11 +91,8 @@ class Config(object):
@classmethod
def ensure_directory(cls, dir_path):
dir_path = cls.abspath(dir_path)
try:
if not os.path.exists(dir_path):
os.makedirs(dir_path)
except OSError, e:
if e.errno != errno.EEXIST:
raise
if not os.path.isdir(dir_path):
raise ConfigError(
"%s is not a directory" % (dir_path,)
-47
View File
@@ -1,47 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class CasConfig(Config):
"""Cas Configuration
cas_server_url: URL of CAS server
"""
def read_config(self, config):
cas_config = config.get("cas_config", None)
if cas_config:
self.cas_enabled = cas_config.get("enabled", True)
self.cas_server_url = cas_config["server_url"]
self.cas_service_url = cas_config["service_url"]
self.cas_required_attributes = cas_config.get("required_attributes", {})
else:
self.cas_enabled = False
self.cas_server_url = None
self.cas_service_url = None
self.cas_required_attributes = {}
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable CAS for registration and login.
#cas_config:
# enabled: true
# server_url: "https://cas-server.com"
# service_url: "https://homesever.domain.com:8448"
# #required_attributes:
# # name: value
"""
+1 -4
View File
@@ -26,15 +26,12 @@ from .metrics import MetricsConfig
from .appservice import AppServiceConfig
from .key import KeyConfig
from .saml2 import SAML2Config
from .cas import CasConfig
from .password import PasswordConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
PasswordConfig,):
AppServiceConfig, KeyConfig, SAML2Config, ):
pass
-8
View File
@@ -22,7 +22,6 @@ import yaml
from string import Template
import os
import signal
from synapse.util.debug import debug_deferreds
DEFAULT_LOG_CONFIG = Template("""
@@ -70,8 +69,6 @@ class LoggingConfig(Config):
self.verbosity = config.get("verbose", 0)
self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file"))
if config.get("full_twisted_stacktraces"):
debug_deferreds()
def default_config(self, config_dir_path, server_name, **kwargs):
log_file = self.abspath("homeserver.log")
@@ -87,11 +84,6 @@ class LoggingConfig(Config):
# A yaml python logging config file
log_config: "%(log_config)s"
# Stop twisted from discarding the stack traces of exceptions in
# deferreds by waiting a reactor tick before running a deferred's
# callbacks.
# full_twisted_stacktraces: true
""" % locals()
def read_arguments(self, args):
-32
View File
@@ -1,32 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class PasswordConfig(Config):
"""Password login configuration
"""
def read_config(self, config):
password_config = config.get("password_config", {})
self.password_enabled = password_config.get("enabled", True)
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable password for login.
password_config:
enabled: true
"""
-12
View File
@@ -33,8 +33,6 @@ class RegistrationConfig(Config):
self.registration_shared_secret = config.get("registration_shared_secret")
self.macaroon_secret_key = config.get("macaroon_secret_key")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.allow_guest_access = config.get("allow_guest_access", False)
def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50)
@@ -50,16 +48,6 @@ class RegistrationConfig(Config):
registration_shared_secret: "%(registration_shared_secret)s"
macaroon_secret_key: "%(macaroon_secret_key)s"
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12.
bcrypt_rounds: 12
# Allows users to register as guests without a password/email/etc, and
# participate in rooms hosted on this server which have been made
# accessible to anonymous users.
allow_guest_access: False
""" % locals()
def add_arguments(self, parser):
+1 -2
View File
@@ -33,7 +33,7 @@ class SAML2Config(Config):
def read_config(self, config):
saml2_config = config.get("saml2_config", None)
if saml2_config:
self.saml2_enabled = saml2_config.get("enabled", True)
self.saml2_enabled = True
self.saml2_config_path = saml2_config["config_path"]
self.saml2_idp_redirect_url = saml2_config["idp_redirect_url"]
else:
@@ -49,7 +49,6 @@ class SAML2Config(Config):
# the user back to /login/saml2 with proper info.
# See pysaml2 docs for format of config.
#saml2_config:
# enabled: true
# config_path: "%s/sp_conf.py"
# idp_redirect_url: "http://%s/idp"
""" % (config_dir_path, server_name)
-1
View File
@@ -26,7 +26,6 @@ class ServerConfig(Config):
self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize")
self.print_pidfile = config.get("print_pidfile")
self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.listeners = config.get("listeners", [])
+3 -6
View File
@@ -66,6 +66,7 @@ def prune_event(event):
"users_default",
"events",
"events_default",
"events_default",
"state_default",
"ban",
"kick",
@@ -102,10 +103,7 @@ def format_event_raw(d):
def format_event_for_client_v1(d):
d["user_id"] = d.pop("sender", None)
move_keys = (
"age", "redacted_because", "replaces_state", "prev_content",
"invite_room_state",
)
move_keys = ("age", "redacted_because", "replaces_state", "prev_content")
for key in move_keys:
if key in d["unsigned"]:
d[key] = d["unsigned"][key]
@@ -154,8 +152,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
if "redacted_because" in e.unsigned:
d["unsigned"]["redacted_because"] = serialize_event(
e.unsigned["redacted_because"], time_now_ms,
event_format=event_format
e.unsigned["redacted_because"], time_now_ms
)
if token_id is not None:
+6 -93
View File
@@ -17,7 +17,6 @@
from twisted.internet import defer
from .federation_base import FederationBase
from synapse.api.constants import Membership
from .units import Edu
from synapse.api.errors import (
@@ -357,55 +356,19 @@ class FederationClient(FederationBase):
defer.returnValue(signed_auth)
@defer.inlineCallbacks
def make_membership_event(self, destinations, room_id, user_id, membership,
content={},):
"""
Creates an m.room.member event, with context, without participating in the room.
Does so by asking one of the already participating servers to create an
event with proper context.
Note that this does not append any events to any graphs.
Args:
destinations (str): Candidate homeservers which are probably
participating in the room.
room_id (str): The room in which the event will happen.
user_id (str): The user whose membership is being evented.
membership (str): The "membership" property of the event. Must be
one of "join" or "leave".
content (object): Any additional data to put into the content field
of the event.
Return:
A tuple of (origin (str), event (object)) where origin is the remote
homeserver which generated the event.
"""
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
def make_join(self, destinations, room_id, user_id):
for destination in destinations:
if destination == self.server_name:
continue
try:
ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, membership
ret = yield self.transport_layer.make_join(
destination, room_id, user_id
)
pdu_dict = ret["event"]
logger.debug("Got response to make_%s: %s", membership, pdu_dict)
pdu_dict["content"].update(content)
# The protoevent received over the JSON wire may not have all
# the required fields. Lets just gloss over that because
# there's some we never care about
if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = []
logger.debug("Got response to make_join: %s", pdu_dict)
defer.returnValue(
(destination, self.event_from_pdu_json(pdu_dict))
@@ -415,8 +378,8 @@ class FederationClient(FederationBase):
raise
except Exception as e:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
"Failed to make_join via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
@@ -522,33 +485,6 @@ class FederationClient(FederationBase):
defer.returnValue(pdu)
@defer.inlineCallbacks
def send_leave(self, destinations, pdu):
for destination in destinations:
if destination == self.server_name:
continue
try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_leave(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
logger.debug("Got content: %s", content)
defer.returnValue(None)
except CodeMessageException:
raise
except Exception as e:
logger.exception(
"Failed to send_leave via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks
def query_auth(self, destination, room_id, event_id, local_auth):
"""
@@ -707,26 +643,3 @@ class FederationClient(FederationBase):
event.internal_metadata.outlier = outlier
return event
@defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations:
if destination == self.server_name:
continue
try:
yield self.transport_layer.exchange_third_party_invite(
destination=destination,
room_id=room_id,
event_dict=event_dict,
)
defer.returnValue(None)
except CodeMessageException:
raise
except Exception as e:
logger.exception(
"Failed to send_third_party_invite via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
-26
View File
@@ -254,20 +254,6 @@ class FederationServer(FederationBase):
],
}))
@defer.inlineCallbacks
def on_make_leave_request(self, room_id, user_id):
pdu = yield self.handler.on_make_leave_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@defer.inlineCallbacks
def on_send_leave_request(self, origin, content):
logger.debug("on_send_leave_request: content: %s", content)
pdu = self.event_from_pdu_json(content)
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
yield self.handler.on_send_leave_request(origin, pdu)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
time_now = self._clock.time_msec()
@@ -543,15 +529,3 @@ class FederationServer(FederationBase):
event.internal_metadata.outlier = outlier
return event
@defer.inlineCallbacks
def exchange_third_party_invite(self, invite):
ret = yield self.handler.exchange_third_party_invite(invite)
defer.returnValue(ret)
@defer.inlineCallbacks
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
ret = yield self.handler.on_exchange_third_party_invite_request(
origin, room_id, event_dict
)
defer.returnValue(ret)
+14 -14
View File
@@ -202,7 +202,6 @@ class TransactionQueue(object):
@defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
# list of (pending_pdu, deferred, order)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing.
@@ -214,6 +213,9 @@ class TransactionQueue(object):
)
return
logger.debug("TX [%s] _attempt_new_transaction", destination)
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])
@@ -226,22 +228,20 @@ class TransactionQueue(object):
logger.debug("TX [%s] Nothing to send", destination)
return
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[2])
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
failures = [x[0].get_dict() for x in pending_failures]
deferreds = [
x[1]
for x in pending_pdus + pending_edus + pending_failures
]
try:
self.pending_transactions[destination] = 1
logger.debug("TX [%s] _attempt_new_transaction", destination)
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[2])
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
failures = [x[0].get_dict() for x in pending_failures]
deferreds = [
x[1]
for x in pending_pdus + pending_edus + pending_failures
]
txn_id = str(self._next_txn_id)
limiter = yield get_retry_limiter(
+3 -36
View File
@@ -14,7 +14,6 @@
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.util.logutils import log_function
@@ -161,19 +160,13 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def make_membership_event(self, destination, room_id, user_id, membership):
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id)
def make_join(self, destination, room_id, user_id, retry_on_dns_fail=True):
path = PREFIX + "/make_join/%s/%s" % (room_id, user_id)
content = yield self.client.get_json(
destination=destination,
path=path,
retry_on_dns_fail=True,
retry_on_dns_fail=retry_on_dns_fail,
)
defer.returnValue(content)
@@ -191,19 +184,6 @@ class TransportLayerClient(object):
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def send_leave(self, destination, room_id, event_id, content):
path = PREFIX + "/send_leave/%s/%s" % (room_id, event_id)
response = yield self.client.put_json(
destination=destination,
path=path,
data=content,
)
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def send_invite(self, destination, room_id, event_id, content):
@@ -217,19 +197,6 @@ class TransportLayerClient(object):
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):
path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,)
response = yield self.client.put_json(
destination=destination,
path=path,
data=event_dict,
)
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
-57
View File
@@ -296,24 +296,6 @@ class FederationMakeJoinServlet(BaseFederationServlet):
defer.returnValue((200, content))
class FederationMakeLeaveServlet(BaseFederationServlet):
PATH = "/make_leave/([^/]*)/([^/]*)"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id):
content = yield self.handler.on_make_leave_request(context, user_id)
defer.returnValue((200, content))
class FederationSendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/([^/]*)/([^/]*)"
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, room_id, txid):
content = yield self.handler.on_send_leave_request(origin, content)
defer.returnValue((200, content))
class FederationEventAuthServlet(BaseFederationServlet):
PATH = "/event_auth/([^/]*)/([^/]*)"
@@ -343,17 +325,6 @@ class FederationInviteServlet(BaseFederationServlet):
defer.returnValue((200, content))
class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/([^/]*)"
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, room_id):
content = yield self.handler.on_exchange_third_party_invite_request(
origin, room_id, content
)
defer.returnValue((200, content))
class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query"
@@ -407,30 +378,6 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
defer.returnValue((200, content))
class On3pidBindServlet(BaseFederationServlet):
PATH = "/3pid/onbind"
@defer.inlineCallbacks
def on_POST(self, request):
content_bytes = request.content.read()
content = json.loads(content_bytes)
if "invites" in content:
last_exception = None
for invite in content["invites"]:
try:
yield self.handler.exchange_third_party_invite(invite)
except Exception as e:
last_exception = e
if last_exception:
raise last_exception
defer.returnValue((200, {}))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
SERVLET_CLASSES = (
FederationPullServlet,
FederationEventServlet,
@@ -438,16 +385,12 @@ SERVLET_CLASSES = (
FederationBackfillServlet,
FederationQueryServlet,
FederationMakeJoinServlet,
FederationMakeLeaveServlet,
FederationEventServlet,
FederationSendJoinServlet,
FederationSendLeaveServlet,
FederationInviteServlet,
FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
FederationClientKeysQueryServlet,
FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
)
+1 -4
View File
@@ -17,7 +17,7 @@ from synapse.appservice.scheduler import AppServiceScheduler
from synapse.appservice.api import ApplicationServiceApi
from .register import RegistrationHandler
from .room import (
RoomCreationHandler, RoomMemberHandler, RoomListHandler, RoomContextHandler,
RoomCreationHandler, RoomMemberHandler, RoomListHandler
)
from .message import MessageHandler
from .events import EventStreamHandler, EventHandler
@@ -32,7 +32,6 @@ from .sync import SyncHandler
from .auth import AuthHandler
from .identity import IdentityHandler
from .receipts import ReceiptsHandler
from .search import SearchHandler
class Handlers(object):
@@ -69,5 +68,3 @@ class Handlers(object):
self.sync_handler = SyncHandler(hs)
self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs)
self.room_context_handler = RoomContextHandler(hs)
+4 -157
View File
@@ -29,12 +29,6 @@ logger = logging.getLogger(__name__)
class BaseHandler(object):
"""
Common base class for the event handlers.
:type store: synapse.storage.events.StateStore
:type state_handler: synapse.state.StateHandler
"""
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -51,74 +45,6 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_guest=False,
require_all_visible_for_guests=True):
# Assumes that user has at some point joined the room if not is_guest.
def allowed(event, membership, visibility):
if visibility == "world_readable":
return True
if is_guest:
return False
if membership == Membership.JOIN:
return True
if event.type == EventTypes.RoomHistoryVisibility:
return not is_guest
if visibility == "shared":
return True
elif visibility == "joined":
return membership == Membership.JOIN
elif visibility == "invited":
return membership == Membership.INVITE
return True
event_id_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
)
events_to_return = []
for event in events:
state = event_id_to_state[event.event_id]
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
membership = membership_event.membership
else:
membership = None
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared")
else:
visibility = "shared"
should_include = allowed(event, membership, visibility)
if should_include:
events_to_return.append(event)
if (require_all_visible_for_guests
and is_guest
and len(events_to_return) < len(events)):
# This indicates that some events in the requested range were not
# visible to guest users. To be safe, we reject the entire request,
# so that we don't have to worry about interpreting visibility
# boundaries.
raise AuthError(403, "User %s does not have permission" % (
user_id
))
defer.returnValue(events_to_return)
def ratelimit(self, user_id):
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message(
@@ -181,8 +107,6 @@ class BaseHandler(object):
if not suppress_auth:
self.auth.check(event, auth_events=context.current_state)
yield self.maybe_kick_guest_users(event, context.current_state.values())
if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least)
room_alias_str = event.content.get("alias", None)
@@ -199,39 +123,24 @@ class BaseHandler(object):
)
)
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
federation_handler = self.hs.get_handlers().federation_handler
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
event.unsigned["invite_room_state"] = [
{
"type": e.type,
"state_key": e.state_key,
"content": e.content,
"sender": e.sender,
}
for k, e in context.current_state.items()
if e.type in (
EventTypes.JoinRules,
EventTypes.CanonicalAlias,
EventTypes.RoomAvatar,
EventTypes.Name,
)
]
invitee = UserID.from_string(event.state_key)
if not self.hs.is_mine(invitee):
# TODO: Can we add signature from remote server in a nicer
# way? If we have been invited by a remote server, we need
# to get them to sign the event.
returned_invite = yield federation_handler.send_invite(
invitee.domain,
event,
)
event.unsigned.pop("room_state", None)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(
returned_invite.signatures
@@ -252,10 +161,6 @@ class BaseHandler(object):
"You don't have permission to redact events"
)
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
destinations = set(extra_destinations)
for k, s in context.current_state.items():
try:
@@ -284,64 +189,6 @@ class BaseHandler(object):
notify_d.addErrback(log_failure)
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)
federation_handler.handle_new_event(
event, destinations=destinations,
)
@defer.inlineCallbacks
def maybe_kick_guest_users(self, event, current_state):
# Technically this function invalidates current_state by changing it.
# Hopefully this isn't that important to the caller.
if event.type == EventTypes.GuestAccess:
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
yield self.kick_guest_users(current_state)
@defer.inlineCallbacks
def kick_guest_users(self, current_state):
for member_event in current_state:
try:
if member_event.type != EventTypes.Member:
continue
if not self.hs.is_mine(UserID.from_string(member_event.state_key)):
continue
if member_event.content["membership"] not in {
Membership.JOIN,
Membership.INVITE
}:
continue
if (
"kind" not in member_event.content
or member_event.content["kind"] != "guest"
):
continue
# We make the user choose to leave, rather than have the
# event-sender kick them. This is partially because we don't
# need to worry about power levels, and partially because guest
# users are a concept which doesn't hugely work over federation,
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
message_handler = self.hs.get_handlers().message_handler
yield message_handler.create_and_send_event(
{
"type": EventTypes.Member,
"state_key": member_event.state_key,
"content": {
"membership": Membership.LEAVE,
"kind": "guest"
},
"room_id": member_event.room_id,
"sender": member_event.state_key
},
ratelimit=False,
)
except Exception as e:
logger.warn("Error kicking guest user: %s" % (e,))
+3 -68
View File
@@ -18,7 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes
from synapse.api.errors import LoginError, Codes
from synapse.util.async import run_on_reactor
from twisted.web.client import PartialDownloadError
@@ -44,9 +44,7 @@ class AuthHandler(BaseHandler):
LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.DUMMY: self._check_dummy_auth,
}
self.bcrypt_rounds = hs.config.bcrypt_rounds
self.sessions = {}
self.INVALID_TOKEN_HTTP_STATUS = 401
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
@@ -297,39 +295,6 @@ class AuthHandler(BaseHandler):
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id):
"""
Gets login tuple for the user with the given user ID.
The user is assumed to have been authenticated by some other
machanism (e.g. CAS)
Args:
user_id (str): User ID
Returns:
A tuple of:
The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks
def does_user_exist(self, user_id):
try:
yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(True)
except LoginError:
defer.returnValue(False)
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
"""Checks to see if a user with the given id exists. Will check case
@@ -374,15 +339,12 @@ class AuthHandler(BaseHandler):
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
defer.returnValue(refresh_token)
def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
def generate_access_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
now = self.hs.get_clock().time_msec()
expiry = now + (60 * 60 * 1000)
macaroon.add_first_party_caveat("time < %d" % (expiry,))
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
def generate_refresh_token(self, user_id):
@@ -395,23 +357,6 @@ class AuthHandler(BaseHandler):
))
return m.serialize()
def generate_short_term_login_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + (2 * 60 * 1000)
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token):
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)
auth_api = self.hs.get_auth()
auth_api.validate_macaroon(macaroon, "login", [auth_api.verify_expiry])
return self._get_user_from_macaroon(macaroon)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
@@ -421,16 +366,6 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
def _get_user_from_macaroon(self, macaroon):
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix):]
raise AuthError(
self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
errcode=Codes.UNKNOWN_TOKEN
)
@defer.inlineCallbacks
def set_password(self, user_id, newpassword):
password_hash = self.hash(newpassword)
@@ -465,7 +400,7 @@ class AuthHandler(BaseHandler):
Returns:
Hashed password (str).
"""
return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds))
return bcrypt.hashpw(password, bcrypt.gensalt())
def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
+49 -61
View File
@@ -46,61 +46,11 @@ class EventStreamHandler(BaseHandler):
self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def started_stream(self, user):
"""Tells the presence handler that we have started an eventstream for
the user:
Args:
user (User): The user who started a stream.
Returns:
A deferred that completes once their presence has been updated.
"""
if user not in self._streams_per_user:
self._streams_per_user[user] = 0
if user in self._stop_timer_per_user:
try:
self.clock.cancel_call_later(
self._stop_timer_per_user.pop(user)
)
except:
logger.exception("Failed to cancel event timer")
else:
yield self.distributor.fire("started_user_eventstream", user)
self._streams_per_user[user] += 1
def stopped_stream(self, user):
"""If there are no streams for a user this starts a timer that will
notify the presence handler that we haven't got an event stream for
the user unless the user starts a new stream in 30 seconds.
Args:
user (User): The user who stopped a stream.
"""
self._streams_per_user[user] -= 1
if not self._streams_per_user[user]:
del self._streams_per_user[user]
# 30 seconds of grace to allow the client to reconnect again
# before we think they're gone
def _later():
logger.debug("_later stopped_user_eventstream %s", user)
self._stop_timer_per_user.pop(user, None)
return self.distributor.fire("stopped_user_eventstream", user)
logger.debug("Scheduling _later: for %s", user)
self._stop_timer_per_user[user] = (
self.clock.call_later(30, _later)
)
@defer.inlineCallbacks
@log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0,
as_client_event=True, affect_presence=True,
only_room_events=False, room_id=None, is_guest=False):
only_room_events=False):
"""Fetches the events stream for a given user.
If `only_room_events` is `True` only room events will be returned.
@@ -109,7 +59,31 @@ class EventStreamHandler(BaseHandler):
try:
if affect_presence:
yield self.started_stream(auth_user)
if auth_user not in self._streams_per_user:
self._streams_per_user[auth_user] = 0
if auth_user in self._stop_timer_per_user:
try:
self.clock.cancel_call_later(
self._stop_timer_per_user.pop(auth_user)
)
except:
logger.exception("Failed to cancel event timer")
else:
yield self.distributor.fire(
"started_user_eventstream", auth_user
)
self._streams_per_user[auth_user] += 1
rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id(
auth_user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)
if timeout:
# If they've set a timeout set a minimum limit.
@@ -119,15 +93,9 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
if is_guest:
yield self.distributor.fire(
"user_joined_room", user=auth_user, room_id=room_id
)
events, tokens = yield self.notifier.get_events_for(
auth_user, pagin_config, timeout,
only_room_events=only_room_events,
is_guest=is_guest, guest_room_id=room_id
auth_user, room_ids, pagin_config, timeout,
only_room_events=only_room_events
)
time_now = self.clock.time_msec()
@@ -146,7 +114,27 @@ class EventStreamHandler(BaseHandler):
finally:
if affect_presence:
self.stopped_stream(auth_user)
self._streams_per_user[auth_user] -= 1
if not self._streams_per_user[auth_user]:
del self._streams_per_user[auth_user]
# 10 seconds of grace to allow the client to reconnect again
# before we think they're gone
def _later():
logger.debug(
"_later stopped_user_eventstream %s", auth_user
)
self._stop_timer_per_user.pop(auth_user, None)
return self.distributor.fire(
"stopped_user_eventstream", auth_user
)
logger.debug("Scheduling _later: for %s", auth_user)
self._stop_timer_per_user[auth_user] = (
self.clock.call_later(30, _later)
)
class EventHandler(BaseHandler):
+132 -389
View File
@@ -21,7 +21,6 @@ from synapse.api.errors import (
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
)
from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
@@ -41,6 +40,7 @@ from twisted.internet import defer
import itertools
import logging
logger = logging.getLogger(__name__)
@@ -58,8 +58,6 @@ class FederationHandler(BaseHandler):
def __init__(self, hs):
super(FederationHandler, self).__init__(hs)
self.hs = hs
self.distributor.observe(
"user_joined_room",
self._on_user_joined
@@ -70,9 +68,12 @@ class FederationHandler(BaseHandler):
self.store = hs.get_datastore()
self.replication_layer = hs.get_replication_layer()
self.state_handler = hs.get_state_handler()
# self.auth_handler = gs.get_auth_handler()
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.lock_manager = hs.get_room_lock_manager()
self.replication_layer.set_handler(self)
# When joining a room we need to queue any events for that room up
@@ -124,72 +125,60 @@ class FederationHandler(BaseHandler):
)
if not is_in_room and not event.internal_metadata.is_outlier():
logger.debug("Got event for room we're not in.")
current_state = state
try:
event_stream_id, max_stream_id = yield self._persist_auth_tree(
auth_chain, state, event
)
except AuthError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event.event_id,
)
event_ids = set()
if state:
event_ids |= {e.event_id for e in state}
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
else:
event_ids = set()
if state:
event_ids |= {e.event_id for e in state}
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
seen_ids = set(
(yield self.store.have_events(event_ids)).keys()
)
seen_ids = set(
(yield self.store.have_events(event_ids)).keys()
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)
event_infos = []
for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
e.internal_metadata.outlier = True
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
event_infos.append({
"event": e,
"auth_events": auth,
})
seen_ids.add(e.event_id)
yield self._handle_new_events(
origin,
event_infos,
outliers=True
)
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)
event_infos = []
for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
e.internal_metadata.outlier = True
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
event_infos.append({
"event": e,
"auth_events": auth,
})
seen_ids.add(e.event_id)
yield self._handle_new_events(
origin,
event_infos,
outliers=True
)
try:
_, event_stream_id, max_stream_id = yield self._handle_new_event(
origin,
event,
state=state,
backfilled=backfilled,
current_state=current_state,
)
except AuthError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event.event_id,
)
try:
_, event_stream_id, max_stream_id = yield self._handle_new_event(
origin,
event,
state=state,
backfilled=backfilled,
current_state=current_state,
)
except AuthError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event.event_id,
)
# if we're receiving valid events from an origin,
# it's probably a good idea to mark it as not in retry-state
@@ -241,7 +230,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events):
event_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
room_id, frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, None),
@@ -564,7 +553,7 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def do_invite_join(self, target_hosts, room_id, joinee, content):
def do_invite_join(self, target_hosts, room_id, joinee, content, snapshot):
""" Attempts to join the `joinee` to the room `room_id` via the
server `target_host`.
@@ -580,19 +569,49 @@ class FederationHandler(BaseHandler):
yield self.store.clean_room_for_join(room_id)
origin, event = yield self._make_and_verify_event(
origin, pdu = yield self.replication_layer.make_join(
target_hosts,
room_id,
joinee,
"join",
content,
joinee
)
logger.debug("Got response to make_join: %s", pdu)
event = pdu
# We should assert some things.
# FIXME: Do this in a nicer way
assert(event.type == EventTypes.Member)
assert(event.user_id == joinee)
assert(event.state_key == joinee)
assert(event.room_id == room_id)
event.internal_metadata.outlier = False
self.room_queues[room_id] = []
builder = self.event_builder_factory.new(
unfreeze(event.get_pdu_json())
)
handled_events = set()
try:
new_event = self._sign_event(event)
builder.event_id = self.event_builder_factory.create_event_id()
builder.origin = self.hs.hostname
builder.content = content
if not hasattr(event, "signatures"):
builder.signatures = {}
add_hashes_and_signatures(
builder,
self.hs.hostname,
self.hs.config.signing_key[0],
)
new_event = builder.build()
# Try the host we successfully got a response to /make_join/
# request first.
try:
@@ -600,7 +619,11 @@ class FederationHandler(BaseHandler):
target_hosts.insert(0, origin)
except ValueError:
pass
ret = yield self.replication_layer.send_join(target_hosts, new_event)
ret = yield self.replication_layer.send_join(
target_hosts,
new_event
)
origin = ret["origin"]
state = ret["state"]
@@ -626,8 +649,35 @@ class FederationHandler(BaseHandler):
# FIXME
pass
event_stream_id, max_stream_id = yield self._persist_auth_tree(
auth_chain, state, event
ev_infos = []
for e in itertools.chain(state, auth_chain):
if e.event_id == event.event_id:
continue
e.internal_metadata.outlier = True
auth_ids = [e_id for e_id, _ in e.auth_events]
ev_infos.append({
"event": e,
"auth_events": {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
})
yield self._handle_new_events(origin, ev_infos, outliers=True)
auth_ids = [e_id for e_id, _ in event.auth_events]
auth_events = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
_, event_stream_id, max_stream_id = yield self._handle_new_event(
origin,
new_event,
state=state,
current_state=state,
auth_events=auth_events,
)
with PreserveLoggingContext():
@@ -664,14 +714,12 @@ class FederationHandler(BaseHandler):
@log_function
def on_make_join_request(self, room_id, user_id):
""" We've received a /make_join/ request, so we create a partial
join event for the room and return that. We do *not* persist or
join event for the room and return that. We don *not* persist or
process it until the other server has signed it and sent it back.
"""
event_content = {"membership": Membership.JOIN}
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"content": event_content,
"content": {"membership": Membership.JOIN},
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
@@ -816,168 +864,6 @@ class FederationHandler(BaseHandler):
defer.returnValue(event)
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
origin, event = yield self._make_and_verify_event(
target_hosts,
room_id,
user_id,
"leave"
)
signed_event = self._sign_event(event)
# Try the host we successfully got a response to /make_join/
# request first.
try:
target_hosts.remove(origin)
target_hosts.insert(0, origin)
except ValueError:
pass
yield self.replication_layer.send_leave(
target_hosts,
signed_event
)
defer.returnValue(None)
@defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
content={},):
origin, pdu = yield self.replication_layer.make_membership_event(
target_hosts,
room_id,
user_id,
membership,
content,
)
logger.debug("Got response to make_%s: %s", membership, pdu)
event = pdu
# We should assert some things.
# FIXME: Do this in a nicer way
assert(event.type == EventTypes.Member)
assert(event.user_id == user_id)
assert(event.state_key == user_id)
assert(event.room_id == room_id)
defer.returnValue((origin, event))
def _sign_event(self, event):
event.internal_metadata.outlier = False
builder = self.event_builder_factory.new(
unfreeze(event.get_pdu_json())
)
builder.event_id = self.event_builder_factory.create_event_id()
builder.origin = self.hs.hostname
if not hasattr(event, "signatures"):
builder.signatures = {}
add_hashes_and_signatures(
builder,
self.hs.hostname,
self.hs.config.signing_key[0],
)
return builder.build()
@defer.inlineCallbacks
@log_function
def on_make_leave_request(self, room_id, user_id):
""" We've received a /make_leave/ request, so we create a partial
join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
"""
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"content": {"membership": Membership.LEAVE},
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
})
event, context = yield self._create_new_client_event(
builder=builder,
)
self.auth.check(event, auth_events=context.current_state)
defer.returnValue(event)
@defer.inlineCallbacks
@log_function
def on_send_leave_request(self, origin, pdu):
""" We have received a leave event for a room. Fully process it."""
event = pdu
logger.debug(
"on_send_leave_request: Got event: %s, signatures: %s",
event.event_id,
event.signatures,
)
event.internal_metadata.outlier = False
context, event_stream_id, max_stream_id = yield self._handle_new_event(
origin, event
)
logger.debug(
"on_send_leave_request: After _handle_new_event: %s, sigs: %s",
event.event_id,
event.signatures,
)
extra_users = []
if event.type == EventTypes.Member:
target_user_id = event.state_key
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
new_pdu = event
destinations = set()
for k, s in context.current_state.items():
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.LEAVE:
destinations.add(
UserID.from_string(s.state_key).domain
)
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
destinations.discard(origin)
logger.debug(
"on_send_leave_request: Sending event: %s, signatures: %s",
event.event_id,
event.signatures,
)
self.replication_layer.send_pdu(new_pdu, destinations)
defer.returnValue(None)
@defer.inlineCallbacks
def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
yield run_on_reactor()
@@ -1100,6 +986,8 @@ class FederationHandler(BaseHandler):
context = yield self._prep_event(
origin, event,
state=state,
backfilled=backfilled,
current_state=current_state,
auth_events=auth_events,
)
@@ -1122,6 +1010,7 @@ class FederationHandler(BaseHandler):
origin,
ev_info["event"],
state=ev_info.get("state"),
backfilled=backfilled,
auth_events=ev_info.get("auth_events"),
)
for ev_info in event_infos
@@ -1138,77 +1027,8 @@ class FederationHandler(BaseHandler):
)
@defer.inlineCallbacks
def _persist_auth_tree(self, auth_events, state, event):
"""Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists the auth chain and state atomically.
Persists the event seperately.
Returns:
2-tuple of (event_stream_id, max_stream_id) from the persist_event
call for `event`
"""
events_to_context = {}
for e in itertools.chain(auth_events, state):
ctx = yield self.state_handler.compute_event_context(
e, outlier=True,
)
events_to_context[e.event_id] = ctx
e.internal_metadata.outlier = True
event_map = {
e.event_id: e
for e in auth_events
}
create_event = None
for e in auth_events:
if (e.type, e.state_key) == (EventTypes.Create, ""):
create_event = e
break
for e in itertools.chain(auth_events, state, [event]):
auth_for_e = {
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
for e_id, _ in e.auth_events
}
if create_event:
auth_for_e[(EventTypes.Create, "")] = create_event
try:
self.auth.check(e, auth_events=auth_for_e)
except AuthError as err:
logger.warn(
"Rejecting %s because %s",
e.event_id, err.msg
)
if e == event:
raise
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
yield self.store.persist_events(
[
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
],
is_new_state=False,
)
new_event_context = yield self.state_handler.compute_event_context(
event, old_state=state, outlier=False,
)
event_stream_id, max_stream_id = yield self.store.persist_event(
event, new_event_context,
backfilled=False,
is_new_state=True,
current_state=state,
)
defer.returnValue((event_stream_id, max_stream_id))
@defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, auth_events=None):
def _prep_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
context = yield self.state_handler.compute_event_context(
@@ -1241,10 +1061,6 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR
if event.type == EventTypes.GuestAccess:
full_context = yield self.store.get_current_state(room_id=event.room_id)
yield self.maybe_kick_guest_users(event, full_context)
defer.returnValue(context)
@defer.inlineCallbacks
@@ -1350,7 +1166,7 @@ class FederationHandler(BaseHandler):
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in remote_auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
if e.event_id in auth_ids
}
e.internal_metadata.outlier = True
@@ -1468,7 +1284,6 @@ class FederationHandler(BaseHandler):
(e.type, e.state_key): e
for e in result["auth_chain"]
if e.event_id in auth_ids
or event.type == EventTypes.Create
}
ev.internal_metadata.outlier = True
@@ -1641,75 +1456,3 @@ class FederationHandler(BaseHandler):
},
"missing": [e.event_id for e in missing_locals],
})
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, invite):
sender = invite["sender"]
room_id = invite["room_id"]
event_dict = {
"type": EventTypes.Member,
"content": {
"membership": Membership.INVITE,
"third_party_invite": invite,
},
"room_id": room_id,
"sender": sender,
"state_key": invite["mxid"],
}
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
event, context = yield self._create_new_client_event(builder=builder)
self.auth.check(event, context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context)
else:
destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)])
yield self.replication_layer.forward_third_party_invite(
destinations,
room_id,
event_dict,
)
@defer.inlineCallbacks
@log_function
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
builder = self.event_builder_factory.new(event_dict)
event, context = yield self._create_new_client_event(
builder=builder,
)
self.auth.check(event, auth_events=context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state)
returned_invite = yield self.send_invite(origin, event)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context)
@defer.inlineCallbacks
def _validate_keyserver(self, event, auth_events):
token = event.content["third_party_invite"]["signed"]["token"]
invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,)
)
try:
response = yield self.hs.get_simple_http_client().get_json(
invite_event.content["key_validity_url"],
{"public_key": invite_event.content["public_key"]}
)
except Exception:
raise SynapseError(
502,
"Third party certificate could not be checked"
)
if "valid" not in response or not response["valid"]:
raise AuthError(403, "Third party certificate was invalid")
+113 -139
View File
@@ -16,7 +16,7 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError, AuthError, Codes
from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
@@ -71,20 +71,20 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def get_messages(self, user_id=None, room_id=None, pagin_config=None,
as_client_event=True, is_guest=False):
as_client_event=True):
"""Get messages in a room.
Args:
user_id (str): The user requesting messages.
room_id (str): The room they want messages from.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any.
config rules to apply, if any.
as_client_event (bool): True to get events in client-server format.
is_guest (bool): Whether the requesting user is a guest (as opposed
to a fully registered user).
Returns:
dict: Pagination API results
"""
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
data_source = self.hs.get_event_sources().sources["room"]
if pagin_config.from_token:
@@ -107,27 +107,23 @@ class MessageHandler(BaseHandler):
source_config = pagin_config.get_source_config("room")
if not is_guest:
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if member_event.membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before
# they left the room.
# If they're a guest, we'll just 403 them if they're asking for
# events they can't see.
leave_token = yield self.store.get_topological_token_for_event(
member_event.event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < room_token.topological:
source_config.from_key = str(leave_token)
if member_event.membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before
# they left the room
leave_token = yield self.store.get_topological_token_for_event(
member_event.event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < room_token.topological:
source_config.from_key = str(leave_token)
if source_config.direction == "f":
if source_config.to_key is None:
if source_config.direction == "f":
if source_config.to_key is None:
source_config.to_key = str(leave_token)
else:
to_token = RoomStreamToken.parse(source_config.to_key)
if leave_token.topological < to_token.topological:
source_config.to_key = str(leave_token)
else:
to_token = RoomStreamToken.parse(source_config.to_key)
if leave_token.topological < to_token.topological:
source_config.to_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, room_token.topological
@@ -150,7 +146,7 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(),
})
events = yield self._filter_events_for_client(user_id, events, is_guest=is_guest)
events = yield self._filter_events_for_client(user_id, room_id, events)
time_now = self.clock.time_msec()
@@ -165,9 +161,55 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, room_id, events):
event_id_to_state = yield self.store.get_state_for_events(
room_id, frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
)
def allowed(event, state):
if event.type == EventTypes.RoomHistoryVisibility:
return True
membership_ev = state.get((EventTypes.Member, user_id), None)
if membership_ev:
membership = membership_ev.membership
else:
membership = Membership.LEAVE
if membership == Membership.JOIN:
return True
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility == "public":
return True
elif visibility == "shared":
return True
elif visibility == "joined":
return membership == Membership.JOIN
elif visibility == "invited":
return membership == Membership.INVITE
return True
defer.returnValue([
event
for event in events
if allowed(event, event_id_to_state[event.event_id])
])
@defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True,
token_id=None, txn_id=None, is_guest=False):
token_id=None, txn_id=None):
""" Given a dict from a client, create and handle a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
@@ -213,7 +255,7 @@ class MessageHandler(BaseHandler):
if event.type == EventTypes.Member:
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context, is_guest=is_guest)
yield member_handler.change_membership(event, context)
else:
yield self.handle_new_client_event(
event=event,
@@ -229,7 +271,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None,
event_type=None, state_key="", is_guest=False):
event_type=None, state_key=""):
""" Get data from a room.
Args:
@@ -239,52 +281,23 @@ class MessageHandler(BaseHandler):
Raises:
SynapseError if something went wrong.
"""
membership, membership_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id, is_guest
)
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if membership == Membership.JOIN:
if member_event.membership == Membership.JOIN:
data = yield self.state_handler.get_current_state(
room_id, event_type, state_key
)
elif membership == Membership.LEAVE:
elif member_event.membership == Membership.LEAVE:
key = (event_type, state_key)
room_state = yield self.store.get_state_for_events(
[membership_event_id], [key]
room_id, [member_event.event_id], [key]
)
data = room_state[membership_event_id].get(key)
data = room_state[member_event.event_id].get(key)
defer.returnValue(data)
@defer.inlineCallbacks
def _check_in_room_or_world_readable(self, room_id, user_id, is_guest):
try:
# check_user_was_in_room will return the most recent membership
# event for the user if:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
defer.returnValue((member_event.membership, member_event.event_id))
return
except AuthError, auth_error:
visibility = yield self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
visibility and
visibility.content["history_visibility"] == "world_readable"
):
defer.returnValue((Membership.JOIN, None))
return
if not is_guest:
raise auth_error
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
)
@defer.inlineCallbacks
def get_state_events(self, user_id, room_id, is_guest=False):
def get_state_events(self, user_id, room_id):
"""Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has
left the room return the state events from when they left.
@@ -295,17 +308,15 @@ class MessageHandler(BaseHandler):
Returns:
A list of dicts representing state events. [{}, {}, {}]
"""
membership, membership_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id, is_guest
)
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if membership == Membership.JOIN:
if member_event.membership == Membership.JOIN:
room_state = yield self.state_handler.get_current_state(room_id)
elif membership == Membership.LEAVE:
elif member_event.membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
[membership_event_id], None
room_id, [member_event.event_id], None
)
room_state = room_state[membership_event_id]
room_state = room_state[member_event.event_id]
now = self.clock.time_msec()
defer.returnValue(
@@ -313,8 +324,7 @@ class MessageHandler(BaseHandler):
)
@defer.inlineCallbacks
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
def snapshot_all_rooms(self, user_id=None, pagin_config=None, as_client_event=True):
"""Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is
@@ -325,19 +335,17 @@ class MessageHandler(BaseHandler):
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config used to determine how many messages *PER ROOM* to return.
as_client_event (bool): True to get events in client-server format.
include_archived (bool): True to get rooms that the user has left
Returns:
A list of dicts with "room_id" and "membership" keys for all rooms
the user is currently invited or joined in on. Rooms where the user
is joined on, may return a "messages" key with messages, depending
on the specified PaginationConfig.
"""
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
memberships.append(Membership.LEAVE)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, membership_list=memberships
user_id=user_id,
membership_list=[
Membership.INVITE, Membership.JOIN, Membership.LEAVE
]
)
user = UserID.from_string(user_id)
@@ -357,8 +365,6 @@ class MessageHandler(BaseHandler):
user, pagination_config.get_source_config("receipt"), None
)
tags_by_room = yield self.store.get_tags_for_user(user_id)
public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit
@@ -377,12 +383,8 @@ class MessageHandler(BaseHandler):
}
if event.membership == Membership.INVITE:
time_now = self.clock.time_msec()
d["inviter"] = event.sender
invite_event = yield self.store.get_event(event.event_id)
d["invite"] = serialize_event(invite_event, time_now, as_client_event)
rooms_ret.append(d)
if event.membership not in (Membership.JOIN, Membership.LEAVE):
@@ -397,7 +399,7 @@ class MessageHandler(BaseHandler):
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = self.store.get_state_for_events(
[event.event_id], None
event.room_id, [event.event_id], None
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
@@ -415,7 +417,7 @@ class MessageHandler(BaseHandler):
).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client(
user_id, messages
user_id, event.room_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0])
@@ -435,15 +437,6 @@ class MessageHandler(BaseHandler):
serialize_event(c, time_now, as_client_event)
for c in current_state.values()
]
private_user_data = []
tags = tags_by_room.get(event.room_id)
if tags:
private_user_data.append({
"type": "m.tag",
"content": {"tags": tags},
})
d["private_user_data"] = private_user_data
except:
logger.exception("Failed to get snapshot")
@@ -466,7 +459,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(ret)
@defer.inlineCallbacks
def room_initial_sync(self, user_id, room_id, pagin_config=None, is_guest=False):
def room_initial_sync(self, user_id, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
@@ -483,47 +476,33 @@ class MessageHandler(BaseHandler):
A JSON serialisable dict with the snapshot of the room.
"""
membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id,
user_id,
is_guest
)
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if membership == Membership.JOIN:
if member_event.membership == Membership.JOIN:
result = yield self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_guest
user_id, room_id, pagin_config, member_event
)
elif membership == Membership.LEAVE:
elif member_event.membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_guest
user_id, room_id, pagin_config, member_event
)
private_user_data = []
tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags:
private_user_data.append({
"type": "m.tag",
"content": {"tags": tags},
})
result["private_user_data"] = private_user_data
defer.returnValue(result)
@defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
membership, member_event_id, is_guest):
member_event):
room_state = yield self.store.get_state_for_events(
[member_event_id], None
member_event.room_id, [member_event.event_id], None
)
room_state = room_state[member_event_id]
room_state = room_state[member_event.event_id]
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
stream_token = yield self.store.get_stream_token_for_event(
member_event_id
member_event.event_id
)
messages, token = yield self.store.get_recent_events_for_room(
@@ -533,16 +512,16 @@ class MessageHandler(BaseHandler):
)
messages = yield self._filter_events_for_client(
user_id, messages, is_guest=is_guest
user_id, room_id, messages
)
start_token = StreamToken(token[0], 0, 0, 0, 0)
end_token = StreamToken(token[1], 0, 0, 0, 0)
start_token = StreamToken(token[0], 0, 0, 0)
end_token = StreamToken(token[1], 0, 0, 0)
time_now = self.clock.time_msec()
defer.returnValue({
"membership": membership,
"membership": member_event.membership,
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
@@ -556,7 +535,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
membership, is_guest):
member_event):
current_state = yield self.state.get_current_state(
room_id=room_id,
)
@@ -588,14 +567,12 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def get_presence():
states = {}
if not is_guest:
states = yield presence_handler.get_states(
target_users=[UserID.from_string(m.user_id) for m in room_members],
auth_user=auth_user,
as_event=True,
check_auth=False,
)
states = yield presence_handler.get_states(
target_users=[UserID.from_string(m.user_id) for m in room_members],
auth_user=auth_user,
as_event=True,
check_auth=False,
)
defer.returnValue(states.values())
@@ -615,7 +592,7 @@ class MessageHandler(BaseHandler):
).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client(
user_id, messages, is_guest=is_guest, require_all_visible_for_guests=False
user_id, room_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0])
@@ -623,7 +600,8 @@ class MessageHandler(BaseHandler):
time_now = self.clock.time_msec()
ret = {
defer.returnValue({
"membership": member_event.membership,
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
@@ -633,8 +611,4 @@ class MessageHandler(BaseHandler):
"state": state,
"presence": presence,
"receipts": receipts,
}
if not is_guest:
ret["membership"] = membership
defer.returnValue(ret)
})
+6 -12
View File
@@ -378,7 +378,7 @@ class PresenceHandler(BaseHandler):
# TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time
yield self.changed_presencelike_data(target_user, state)
self.changed_presencelike_data(target_user, state)
def bump_presence_active_time(self, user, now=None):
if now is None:
@@ -422,12 +422,12 @@ class PresenceHandler(BaseHandler):
@log_function
def started_user_eventstream(self, user):
# TODO(paul): Use "last online" state
return self.set_state(user, user, {"presence": PresenceState.ONLINE})
self.set_state(user, user, {"presence": PresenceState.ONLINE})
@log_function
def stopped_user_eventstream(self, user):
# TODO(paul): Save current state as "last online" state
return self.set_state(user, user, {"presence": PresenceState.OFFLINE})
self.set_state(user, user, {"presence": PresenceState.OFFLINE})
@defer.inlineCallbacks
def user_joined_room(self, user, room_id):
@@ -950,8 +950,7 @@ class PresenceHandler(BaseHandler):
)
while len(self._remote_offline_serials) > MAX_OFFLINE_SERIALS:
self._remote_offline_serials.pop() # remove the oldest
if user in self._user_cachemap:
del self._user_cachemap[user]
del self._user_cachemap[user]
else:
# Remove the user from remote_offline_serials now that they're
# no longer offline
@@ -1143,9 +1142,8 @@ class PresenceEventSource(object):
@defer.inlineCallbacks
@log_function
def get_new_events(self, user, from_key, room_ids=None, **kwargs):
def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key)
room_ids = room_ids or []
presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap
@@ -1163,6 +1161,7 @@ class PresenceEventSource(object):
user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list
)
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] > from_key:
joined = yield presence.get_joined_users_for_room_id(room_id)
@@ -1264,11 +1263,6 @@ class UserPresenceCache(object):
self.state = {"presence": PresenceState.OFFLINE}
self.serial = None
def __repr__(self):
return "UserPresenceCache(state=%r, serial=%r)" % (
self.state, self.serial
)
def update(self, state, serial):
assert("mtime_age" not in state)
-46
View File
@@ -1,46 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
class PrivateUserDataEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
def get_current_key(self, direction='f'):
return self.store.get_max_private_user_data_stream_id()
@defer.inlineCallbacks
def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string()
last_stream_id = from_key
current_stream_id = yield self.store.get_max_private_user_data_stream_id()
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
results = []
for room_id, room_tags in tags.items():
results.append({
"type": "m.tag",
"content": {"tags": room_tags},
"room_id": room_id,
})
defer.returnValue((results, current_stream_id))
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
defer.returnValue(([], config.to_id))
+11 -3
View File
@@ -156,7 +156,13 @@ class ReceiptsHandler(BaseHandler):
if not result:
defer.returnValue([])
defer.returnValue(result)
event = {
"type": "m.receipt",
"room_id": room_id,
"content": result,
}
defer.returnValue([event])
class ReceiptEventSource(object):
@@ -164,15 +170,17 @@ class ReceiptEventSource(object):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_new_events(self, from_key, room_ids, **kwargs):
def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key)
to_key = yield self.get_current_key()
if from_key == to_key:
defer.returnValue(([], to_key))
rooms = yield self.store.get_rooms_for_user(user.to_string())
rooms = [room.room_id for room in rooms]
events = yield self.store.get_linearized_receipts_for_rooms(
room_ids,
rooms,
from_key=from_key,
to_key=to_key,
)
+5 -7
View File
@@ -64,7 +64,7 @@ class RegistrationHandler(BaseHandler):
)
@defer.inlineCallbacks
def register(self, localpart=None, password=None, generate_token=True):
def register(self, localpart=None, password=None):
"""Registers a new client on the server.
Args:
@@ -89,9 +89,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
token = None
if generate_token:
token = self.auth_handler().generate_access_token(user_id)
token = self.auth_handler().generate_access_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
@@ -104,14 +102,14 @@ class RegistrationHandler(BaseHandler):
attempts = 0
user_id = None
token = None
while not user_id:
while not user_id and not token:
try:
localpart = self._generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
if generate_token:
token = self.auth_handler().generate_access_token(user_id)
token = self.auth_handler().generate_access_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
+40 -330
View File
@@ -22,24 +22,16 @@ from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset,
)
from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.api.errors import StoreError, SynapseError
from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
from collections import OrderedDict
from unpaddedbase64 import decode_base64
import logging
import math
import string
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
class RoomCreationHandler(BaseHandler):
@@ -49,11 +41,6 @@ class RoomCreationHandler(BaseHandler):
"history_visibility": "shared",
"original_invitees_have_ops": False,
},
RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
"history_visibility": "shared",
"original_invitees_have_ops": True,
},
RoomCreationPreset.PUBLIC_CHAT: {
"join_rules": JoinRules.PUBLIC,
"history_visibility": "shared",
@@ -162,16 +149,12 @@ class RoomCreationHandler(BaseHandler):
for val in raw_initial_state:
initial_state[(val["type"], val.get("state_key", ""))] = val["content"]
creation_content = config.get("creation_content", {})
user = UserID.from_string(user_id)
creation_events = self._create_events_for_new_room(
user, room_id,
preset_config=preset_config,
invite_list=invite_list,
initial_state=initial_state,
creation_content=creation_content,
room_alias=room_alias,
)
msg_handler = self.hs.get_handlers().message_handler
@@ -219,8 +202,7 @@ class RoomCreationHandler(BaseHandler):
defer.returnValue(result)
def _create_events_for_new_room(self, creator, room_id, preset_config,
invite_list, initial_state, creation_content,
room_alias):
invite_list, initial_state):
config = RoomCreationHandler.PRESETS_DICT[preset_config]
creator_id = creator.to_string()
@@ -242,10 +224,9 @@ class RoomCreationHandler(BaseHandler):
return e
creation_content.update({"creator": creator.to_string()})
creation_event = create(
etype=EventTypes.Create,
content=creation_content,
content={"creator": creator.to_string()},
)
join_event = create(
@@ -290,14 +271,6 @@ class RoomCreationHandler(BaseHandler):
returned_events.append(power_levels_event)
if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state:
room_alias_event = create(
etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()},
)
returned_events.append(room_alias_event)
if (EventTypes.JoinRules, '') not in initial_state:
join_rules_event = create(
etype=EventTypes.JoinRules,
@@ -369,7 +342,7 @@ class RoomMemberHandler(BaseHandler):
remotedomains.add(member.domain)
@defer.inlineCallbacks
def change_membership(self, event, context, do_auth=True, is_guest=False):
def change_membership(self, event, context, do_auth=True):
""" Change the membership status of a user in a room.
Args:
@@ -390,38 +363,9 @@ class RoomMemberHandler(BaseHandler):
# if this HS is not currently in the room, i.e. we have to do the
# invite/join dance.
if event.membership == Membership.JOIN:
if is_guest:
guest_access = context.current_state.get(
(EventTypes.GuestAccess, ""),
None
)
is_guest_access_allowed = (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
and guest_access.content["guest_access"] == "can_join"
)
if not is_guest_access_allowed:
raise AuthError(403, "Guest access not allowed")
yield self._do_join(event, context, do_auth=do_auth)
else:
if event.membership == Membership.LEAVE:
is_host_in_room = yield self.is_host_in_room(room_id, context)
if not is_host_in_room:
# Rejecting an invite, rather than leaving a joined room
handler = self.hs.get_handlers().federation_handler
inviter = yield self.get_inviter(event)
if not inviter:
# return the same error as join_room_alias does
raise SynapseError(404, "No known servers")
yield handler.do_remotely_reject_invite(
[inviter.domain],
room_id,
event.user_id
)
defer.returnValue({"room_id": room_id})
return
# This is not a JOIN, so we can handle it normally.
# FIXME: This isn't idempotency.
if prev_state and prev_state.membership == event.membership:
@@ -445,7 +389,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks
def join_room_alias(self, joinee, room_alias, content={}):
def join_room_alias(self, joinee, room_alias, do_auth=True, content={}):
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
@@ -479,6 +423,8 @@ class RoomMemberHandler(BaseHandler):
@defer.inlineCallbacks
def _do_join(self, event, context, room_hosts=None, do_auth=True):
joinee = UserID.from_string(event.state_key)
# room_id = RoomID.from_string(event.room_id, self.hs)
room_id = event.room_id
# XXX: We don't do an auth check if we are doing an invite
@@ -486,18 +432,41 @@ class RoomMemberHandler(BaseHandler):
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
is_host_in_room = yield self.is_host_in_room(room_id, context)
is_host_in_room = yield self.auth.check_host_in_room(
event.room_id,
self.hs.hostname
)
if not is_host_in_room:
# is *anyone* in the room?
room_member_keys = [
v for (k, v) in context.current_state.keys() if (
k == "m.room.member"
)
]
if len(room_member_keys) == 0:
# has the room been created so we can join it?
create_event = context.current_state.get(("m.room.create", ""))
if create_event:
is_host_in_room = True
if is_host_in_room:
should_do_dance = False
elif room_hosts: # TODO: Shouldn't this be remote_room_host?
should_do_dance = True
else:
inviter = yield self.get_inviter(event)
if not inviter:
# TODO(markjh): get prev_state from snapshot
prev_state = yield self.store.get_room_member(
joinee.to_string(), room_id
)
if prev_state and prev_state.membership == Membership.INVITE:
inviter = UserID.from_string(prev_state.user_id)
should_do_dance = not self.hs.is_mine(inviter)
room_hosts = [inviter.domain]
else:
# return the same error as join_room_alias does
raise SynapseError(404, "No known servers")
should_do_dance = not self.hs.is_mine(inviter)
room_hosts = [inviter.domain]
if should_do_dance:
handler = self.hs.get_handlers().federation_handler
@@ -505,7 +474,8 @@ class RoomMemberHandler(BaseHandler):
room_hosts,
room_id,
event.user_id,
event.content,
event.content, # FIXME To get a non-frozen dict
context
)
else:
logger.debug("Doing normal join")
@@ -522,44 +492,6 @@ class RoomMemberHandler(BaseHandler):
"user_joined_room", user=user, room_id=room_id
)
@defer.inlineCallbacks
def get_inviter(self, event):
# TODO(markjh): get prev_state from snapshot
prev_state = yield self.store.get_room_member(
event.user_id, event.room_id
)
if prev_state and prev_state.membership == Membership.INVITE:
defer.returnValue(UserID.from_string(prev_state.user_id))
return
elif "third_party_invite" in event.content:
if "sender" in event.content["third_party_invite"]:
inviter = UserID.from_string(
event.content["third_party_invite"]["sender"]
)
defer.returnValue(inviter)
defer.returnValue(None)
@defer.inlineCallbacks
def is_host_in_room(self, room_id, context):
is_host_in_room = yield self.auth.check_host_in_room(
room_id,
self.hs.hostname
)
if not is_host_in_room:
# is *anyone* in the room?
room_member_keys = [
v for (k, v) in context.current_state.keys() if (
k == "m.room.member"
)
]
if len(room_member_keys) == 0:
# has the room been created so we can join it?
create_event = context.current_state.get(("m.room.create", ""))
if create_event:
is_host_in_room = True
defer.returnValue(is_host_in_room)
@defer.inlineCallbacks
def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
@@ -589,160 +521,6 @@ class RoomMemberHandler(BaseHandler):
suppress_auth=(not do_auth),
)
@defer.inlineCallbacks
def do_3pid_invite(
self,
room_id,
inviter,
medium,
address,
id_server,
token_id,
txn_id
):
invitee = yield self._lookup_3pid(
id_server, medium, address
)
if invitee:
# make sure it looks like a user ID; it'll throw if it's invalid.
UserID.from_string(invitee)
yield self.hs.get_handlers().message_handler.create_and_send_event(
{
"type": EventTypes.Member,
"content": {
"membership": unicode("invite")
},
"room_id": room_id,
"sender": inviter.to_string(),
"state_key": invitee,
},
token_id=token_id,
txn_id=txn_id,
)
else:
yield self._make_and_store_3pid_invite(
id_server,
medium,
address,
room_id,
inviter,
token_id,
txn_id=txn_id
)
@defer.inlineCallbacks
def _lookup_3pid(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
(str) the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{
"medium": medium,
"address": address,
}
)
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
self.verify_any_signature(data, id_server)
defer.returnValue(data["mxid"])
except IOError as e:
logger.warn("Error from identity server lookup: %s" % (e,))
defer.returnValue(None)
@defer.inlineCallbacks
def verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" %
(id_server_scheme, server_hostname, key_name,),
)
if "public_key" not in key_data:
raise AuthError(401, "No public key named %s from %s" %
(key_name, server_hostname,))
verify_signed_json(
data,
server_hostname,
decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
)
return
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
self,
id_server,
medium,
address,
room_id,
user,
token_id,
txn_id
):
token, public_key, key_validity_url, display_name = (
yield self._ask_id_server_for_third_party_invite(
id_server,
medium,
address,
room_id,
user.to_string()
)
)
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_event(
{
"type": EventTypes.ThirdPartyInvite,
"content": {
"display_name": display_name,
"key_validity_url": key_validity_url,
"public_key": public_key,
},
"room_id": room_id,
"sender": user.to_string(),
"state_key": token,
},
token_id=token_id,
txn_id=txn_id,
)
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
self, id_server, medium, address, room_id, sender):
is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme, id_server,
)
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
is_url,
{
"medium": medium,
"address": address,
"room_id": room_id,
"sender": sender,
}
)
# TODO: Check for success
token = data["token"]
public_key = data["public_key"]
display_name = data["display_name"]
key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme, id_server,
)
defer.returnValue((token, public_key, key_validity_url, display_name))
class RoomListHandler(BaseHandler):
@@ -764,79 +542,12 @@ class RoomListHandler(BaseHandler):
defer.returnValue({"start": "START", "end": "END", "chunk": chunk})
class RoomContextHandler(BaseHandler):
@defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit, is_guest):
"""Retrieves events, pagination tokens and state around a given event
in a room.
Args:
user (UserID)
room_id (str)
event_id (str)
limit (int): The maximum number of events to return in total
(excluding state).
Returns:
dict
"""
before_limit = math.floor(limit/2.)
after_limit = limit - before_limit
now_token = yield self.hs.get_event_sources().get_current_token()
results = yield self.store.get_events_around(
room_id, event_id, before_limit, after_limit
)
results["events_before"] = yield self._filter_events_for_client(
user.to_string(),
results["events_before"],
is_guest=is_guest,
require_all_visible_for_guests=False
)
results["events_after"] = yield self._filter_events_for_client(
user.to_string(),
results["events_after"],
is_guest=is_guest,
require_all_visible_for_guests=False
)
if results["events_after"]:
last_event_id = results["events_after"][-1].event_id
else:
last_event_id = event_id
state = yield self.store.get_state_for_events(
[last_event_id], None
)
results["state"] = state[last_event_id].values()
results["start"] = now_token.copy_and_replace(
"room_key", results["start"]
).to_string()
results["end"] = now_token.copy_and_replace(
"room_key", results["end"]
).to_string()
defer.returnValue(results)
class RoomEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_new_events(
self,
user,
from_key,
limit,
room_ids,
is_guest,
):
def get_new_events_for_user(self, user, from_key, limit):
# We just ignore the key for now.
to_key = yield self.get_current_key()
@@ -856,9 +567,8 @@ class RoomEventSource(object):
user_id=user.to_string(),
from_key=from_key,
to_key=to_key,
room_id=None,
limit=limit,
room_ids=room_ids,
is_guest=is_guest,
)
defer.returnValue((events, end_key))
-319
View File
@@ -1,319 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import Membership
from synapse.api.filtering import Filter
from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event
from unpaddedbase64 import decode_base64, encode_base64
import logging
logger = logging.getLogger(__name__)
class SearchHandler(BaseHandler):
def __init__(self, hs):
super(SearchHandler, self).__init__(hs)
@defer.inlineCallbacks
def search(self, user, content, batch=None):
"""Performs a full text search for a user.
Args:
user (UserID)
content (dict): Search parameters
batch (str): The next_batch parameter. Used for pagination.
Returns:
dict to be returned to the client with results of search
"""
batch_group = None
batch_group_key = None
batch_token = None
if batch:
try:
b = decode_base64(batch)
batch_group, batch_group_key, batch_token = b.split("\n")
assert batch_group is not None
assert batch_group_key is not None
assert batch_token is not None
except:
raise SynapseError(400, "Invalid batch")
try:
room_cat = content["search_categories"]["room_events"]
# The actual thing to query in FTS
search_term = room_cat["search_term"]
# Which "keys" to search over in FTS query
keys = room_cat.get("keys", [
"content.body", "content.name", "content.topic",
])
# Filter to apply to results
filter_dict = room_cat.get("filter", {})
# What to order results by (impacts whether pagination can be doen)
order_by = room_cat.get("order_by", "rank")
# Include context around each event?
event_context = room_cat.get(
"event_context", None
)
# Group results together? May allow clients to paginate within a
# group
group_by = room_cat.get("groupings", {}).get("group_by", {})
group_keys = [g["key"] for g in group_by]
if event_context is not None:
before_limit = int(event_context.get(
"before_limit", 5
))
after_limit = int(event_context.get(
"after_limit", 5
))
except KeyError:
raise SynapseError(400, "Invalid search query")
if order_by not in ("rank", "recent"):
raise SynapseError(400, "Invalid order by: %r" % (order_by,))
if set(group_keys) - {"room_id", "sender"}:
raise SynapseError(
400,
"Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},)
)
search_filter = Filter(filter_dict)
# TODO: Search through left rooms too
rooms = yield self.store.get_rooms_for_user_where_membership_is(
user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
)
room_ids = set(r.room_id for r in rooms)
room_ids = search_filter.filter_rooms(room_ids)
if batch_group == "room_id":
room_ids.intersection_update({batch_group_key})
rank_map = {} # event_id -> rank of event
allowed_events = []
room_groups = {} # Holds result of grouping by room, if applicable
sender_group = {} # Holds result of grouping by sender, if applicable
# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None
if order_by == "rank":
results = yield self.store.search_msgs(
room_ids, search_term, keys
)
results_map = {r["event"].event_id: r for r in results}
rank_map.update({r["event"].event_id: r["rank"] for r in results})
filtered_events = search_filter.filter([r["event"] for r in results])
events = yield self._filter_events_for_client(
user.to_string(), filtered_events
)
events.sort(key=lambda e: -rank_map[e.event_id])
allowed_events = events[:search_filter.limit()]
for e in allowed_events:
rm = room_groups.setdefault(e.room_id, {
"results": [],
"order": rank_map[e.event_id],
})
rm["results"].append(e.event_id)
s = sender_group.setdefault(e.sender, {
"results": [],
"order": rank_map[e.event_id],
})
s["results"].append(e.event_id)
elif order_by == "recent":
# In this case we specifically loop through each room as the given
# limit applies to each room, rather than a global list.
# This is not necessarilly a good idea.
for room_id in room_ids:
room_events = []
if batch_group == "room_id" and batch_group_key == room_id:
pagination_token = batch_token
else:
pagination_token = None
i = 0
# We keep looping and we keep filtering until we reach the limit
# or we run out of things.
# But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit() and i < 5:
i += 1
results = yield self.store.search_room(
room_id, search_term, keys, search_filter.limit() * 2,
pagination_token=pagination_token,
)
results_map = {r["event"].event_id: r for r in results}
rank_map.update({r["event"].event_id: r["rank"] for r in results})
filtered_events = search_filter.filter([
r["event"] for r in results
])
events = yield self._filter_events_for_client(
user.to_string(), filtered_events
)
room_events.extend(events)
room_events = room_events[:search_filter.limit()]
if len(results) < search_filter.limit() * 2:
pagination_token = None
break
else:
pagination_token = results[-1]["pagination_token"]
if room_events:
res = results_map[room_events[-1].event_id]
pagination_token = res["pagination_token"]
group = room_groups.setdefault(room_id, {})
if pagination_token:
next_batch = encode_base64("%s\n%s\n%s" % (
"room_id", room_id, pagination_token
))
group["next_batch"] = next_batch
if batch_token:
global_next_batch = next_batch
group["results"] = [e.event_id for e in room_events]
group["order"] = max(
e.origin_server_ts/1000 for e in room_events
if hasattr(e, "origin_server_ts")
)
allowed_events.extend(room_events)
# Normalize the group orders
if room_groups:
if len(room_groups) > 1:
mx = max(g["order"] for g in room_groups.values())
mn = min(g["order"] for g in room_groups.values())
for g in room_groups.values():
g["order"] = (g["order"] - mn) * 1.0 / (mx - mn)
else:
room_groups.values()[0]["order"] = 1
else:
# We should never get here due to the guard earlier.
raise NotImplementedError()
# If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that
if event_context is not None:
now_token = yield self.hs.get_event_sources().get_current_token()
contexts = {}
for event in allowed_events:
res = yield self.store.get_events_around(
event.room_id, event.event_id, before_limit, after_limit
)
res["events_before"] = yield self._filter_events_for_client(
user.to_string(), res["events_before"]
)
res["events_after"] = yield self._filter_events_for_client(
user.to_string(), res["events_after"]
)
res["start"] = now_token.copy_and_replace(
"room_key", res["start"]
).to_string()
res["end"] = now_token.copy_and_replace(
"room_key", res["end"]
).to_string()
contexts[event.event_id] = res
else:
contexts = {}
# TODO: Add a limit
time_now = self.clock.time_msec()
for context in contexts.values():
context["events_before"] = [
serialize_event(e, time_now)
for e in context["events_before"]
]
context["events_after"] = [
serialize_event(e, time_now)
for e in context["events_after"]
]
results = {
e.event_id: {
"rank": rank_map[e.event_id],
"result": serialize_event(e, time_now),
"context": contexts.get(e.event_id, {}),
}
for e in allowed_events
}
logger.info("Found %d results", len(results))
rooms_cat_res = {
"results": results,
"count": len(results)
}
if room_groups and "room_id" in group_keys:
rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
if sender_group and "sender" in group_keys:
rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
if global_next_batch:
rooms_cat_res["next_batch"] = global_next_batch
defer.returnValue({
"search_categories": {
"room_events": rooms_cat_res
}
})
+236 -468
View File
@@ -28,30 +28,22 @@ logger = logging.getLogger(__name__)
SyncConfig = collections.namedtuple("SyncConfig", [
"user",
"limit",
"gap",
"sort",
"backfill",
"filter",
])
class TimelineBatch(collections.namedtuple("TimelineBatch", [
"prev_batch",
"events",
class RoomSyncResult(collections.namedtuple("RoomSyncResult", [
"room_id",
"limited",
])):
__slots__ = []
def __nonzero__(self):
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
return bool(self.events)
class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"room_id", # str
"timeline", # TimelineBatch
"state", # dict[(str, str), FrozenEvent]
"published",
"events",
"state",
"prev_batch",
"ephemeral",
"private_user_data",
])):
__slots__ = []
@@ -59,50 +51,14 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
return bool(
self.timeline
or self.state
or self.ephemeral
or self.private_user_data
)
class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
"room_id", # str
"timeline", # TimelineBatch
"state", # dict[(str, str), FrozenEvent]
"private_user_data",
])):
__slots__ = []
def __nonzero__(self):
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
return bool(
self.timeline
or self.state
or self.private_user_data
)
class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
"room_id", # str
"invite", # FrozenEvent: the invite event
])):
__slots__ = []
def __nonzero__(self):
"""Invited rooms should always be reported to the client"""
return True
return bool(self.events or self.state or self.ephemeral)
class SyncResult(collections.namedtuple("SyncResult", [
"next_batch", # Token for the next sync
"presence", # List of presence events for the user.
"joined", # JoinedSyncResult for each joined room.
"invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room.
"private_user_data", # List of private events for the user.
"public_user_data", # List of public events for all users.
"rooms", # RoomSyncResult for each room.
])):
__slots__ = []
@@ -112,7 +68,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
events.
"""
return bool(
self.presence or self.joined or self.invited
self.private_user_data or self.public_user_data or self.rooms
)
@@ -124,57 +80,66 @@ class SyncHandler(BaseHandler):
self.clock = hs.get_clock()
@defer.inlineCallbacks
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False):
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0):
"""Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result.
Returns:
A Deferred SyncResult.
"""
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
result = yield self.current_sync_for_user(sync_config, since_token,
full_state=full_state)
if timeout == 0 or since_token is None:
result = yield self.current_sync_for_user(sync_config, since_token)
defer.returnValue(result)
else:
def current_sync_callback(before_token, after_token):
return self.current_sync_for_user(sync_config, since_token)
rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id(
sync_config.user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user
)
result = yield self.notifier.wait_for_events(
sync_config.user, timeout, current_sync_callback,
from_token=since_token
sync_config.user, room_ids,
sync_config.filter, timeout, current_sync_callback
)
defer.returnValue(result)
def current_sync_for_user(self, sync_config, since_token=None,
full_state=False):
def current_sync_for_user(self, sync_config, since_token=None):
"""Get the sync for client needed to match what the server has now.
Returns:
A Deferred SyncResult.
"""
if since_token is None or full_state:
return self.full_state_sync(sync_config, since_token)
if since_token is None:
return self.initial_sync(sync_config)
else:
return self.incremental_sync_with_gap(sync_config, since_token)
if sync_config.gap:
return self.incremental_sync_with_gap(sync_config, since_token)
else:
# TODO(mjark): Handle gapless sync
raise NotImplementedError()
@defer.inlineCallbacks
def full_state_sync(self, sync_config, timeline_since_token):
"""Get a sync for a client which is starting without any state.
If a 'message_since_token' is given, only timeline events which have
happened since that token will be returned.
def initial_sync(self, sync_config):
"""Get a sync for a client which is starting without any state
Returns:
A Deferred SyncResult.
"""
now_token = yield self.event_sources.get_current_token()
if sync_config.sort == "timeline,desc":
# TODO(mjark): Handle going through events in reverse order?.
# What does "most recent events" mean when applying the limits mean
# in this case?
raise NotImplementedError()
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token
)
now_token = yield self.event_sources.get_current_token()
presence_stream = self.event_sources.sources["presence"]
# TODO (mjark): This looks wrong, shouldn't we be getting the presence
@@ -187,179 +152,52 @@ class SyncHandler(BaseHandler):
)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=sync_config.user.to_string(),
membership_list=(
Membership.INVITE,
Membership.JOIN,
Membership.LEAVE,
Membership.BAN
)
membership_list=[Membership.INVITE, Membership.JOIN]
)
tags_by_room = yield self.store.get_tags_for_user(
sync_config.user.to_string()
)
# TODO (mjark): Does public mean "published"?
published_rooms = yield self.store.get_rooms(is_public=True)
published_room_ids = set(r["room_id"] for r in published_rooms)
joined = []
invited = []
archived = []
rooms = []
for event in room_list:
if event.membership == Membership.JOIN:
room_sync = yield self.full_state_sync_for_joined_room(
room_id=event.room_id,
sync_config=sync_config,
now_token=now_token,
timeline_since_token=timeline_since_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
)
joined.append(room_sync)
elif event.membership == Membership.INVITE:
invite = yield self.store.get_event(event.event_id)
invited.append(InvitedSyncResult(
room_id=event.room_id,
invite=invite,
))
elif event.membership in (Membership.LEAVE, Membership.BAN):
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
)
room_sync = yield self.full_state_sync_for_archived_room(
sync_config=sync_config,
room_id=event.room_id,
leave_event_id=event.event_id,
leave_token=leave_token,
timeline_since_token=timeline_since_token,
tags_by_room=tags_by_room,
)
archived.append(room_sync)
room_sync = yield self.initial_sync_for_room(
event.room_id, sync_config, now_token, published_room_ids
)
rooms.append(room_sync)
defer.returnValue(SyncResult(
presence=presence,
joined=joined,
invited=invited,
archived=archived,
public_user_data=presence,
private_user_data=[],
rooms=rooms,
next_batch=now_token,
))
@defer.inlineCallbacks
def full_state_sync_for_joined_room(self, room_id, sync_config,
now_token, timeline_since_token,
ephemeral_by_room, tags_by_room):
def initial_sync_for_room(self, room_id, sync_config, now_token,
published_room_ids):
"""Sync a room for a client which is starting without any state
Returns:
A Deferred JoinedSyncResult.
A Deferred RoomSyncResult.
"""
batch = yield self.load_filtered_recents(
room_id, sync_config, now_token, since_token=timeline_since_token
recents, prev_batch_token, limited = yield self.load_filtered_recents(
room_id, sync_config, now_token,
)
current_state = yield self.get_state_at(room_id, now_token)
current_state = yield self.state_handler.get_current_state(
room_id
)
current_state_events = current_state.values()
defer.returnValue(JoinedSyncResult(
defer.returnValue(RoomSyncResult(
room_id=room_id,
timeline=batch,
state=current_state,
ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
))
def private_user_data_for_room(self, room_id, tags_by_room):
private_user_data = []
tags = tags_by_room.get(room_id)
if tags is not None:
private_user_data.append({
"type": "m.tag",
"content": {"tags": tags},
})
return private_user_data
@defer.inlineCallbacks
def ephemeral_by_room(self, sync_config, now_token, since_token=None):
"""Get the ephemeral events for each room the user is in
Args:
sync_config (SyncConfig): The flags, filters and user for the sync.
now_token (StreamToken): Where the server is currently up to.
since_token (StreamToken): Where the server was when the client
last synced.
Returns:
A tuple of the now StreamToken, updated to reflect the which typing
events are included, and a dict mapping from room_id to a list of
typing events for that room.
"""
typing_key = since_token.typing_key if since_token else "0"
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
room_ids = [room.room_id for room in rooms]
typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events(
user=sync_config.user,
from_key=typing_key,
limit=sync_config.filter.ephemeral_limit(),
room_ids=room_ids,
is_guest=False,
)
now_token = now_token.copy_and_replace("typing_key", typing_key)
ephemeral_by_room = {}
for event in typing:
# we want to exclude the room_id from the event, but modifying the
# result returned by the event source is poor form (it might cache
# the object)
room_id = event["room_id"]
event_copy = {k: v for (k, v) in event.iteritems()
if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
receipt_key = since_token.receipt_key if since_token else "0"
receipt_source = self.event_sources.sources["receipt"]
receipts, receipt_key = yield receipt_source.get_new_events(
user=sync_config.user,
from_key=receipt_key,
limit=sync_config.filter.ephemeral_limit(),
room_ids=room_ids,
# /sync doesn't support guest access, they can't get to this point in code
is_guest=False,
)
now_token = now_token.copy_and_replace("receipt_key", receipt_key)
for event in receipts:
room_id = event["room_id"]
# exclude room id, as above
event_copy = {k: v for (k, v) in event.iteritems()
if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
defer.returnValue((now_token, ephemeral_by_room))
@defer.inlineCallbacks
def full_state_sync_for_archived_room(self, room_id, sync_config,
leave_event_id, leave_token,
timeline_since_token, tags_by_room):
"""Sync a room for a client which is starting without any state
Returns:
A Deferred JoinedSyncResult.
"""
batch = yield self.load_filtered_recents(
room_id, sync_config, leave_token, since_token=timeline_since_token
)
leave_state = yield self.store.get_state_for_event(leave_event_id)
defer.returnValue(ArchivedSyncResult(
room_id=room_id,
timeline=batch,
state=leave_state,
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
published=room_id in published_room_ids,
events=recents,
prev_batch=prev_batch_token,
state=current_state_events,
limited=limited,
ephemeral=[],
))
@defer.inlineCallbacks
@@ -369,25 +207,34 @@ class SyncHandler(BaseHandler):
Returns:
A Deferred SyncResult.
"""
if sync_config.sort == "timeline,desc":
# TODO(mjark): Handle going through events in reverse order?.
# What does "most recent events" mean when applying the limits mean
# in this case?
raise NotImplementedError()
now_token = yield self.event_sources.get_current_token()
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
room_ids = [room.room_id for room in rooms]
presence_source = self.event_sources.sources["presence"]
presence, presence_key = yield presence_source.get_new_events(
presence, presence_key = yield presence_source.get_new_events_for_user(
user=sync_config.user,
from_key=since_token.presence_key,
limit=sync_config.filter.presence_limit(),
room_ids=room_ids,
# /sync doesn't support guest access, they can't get to this point in code
is_guest=False,
limit=sync_config.limit,
)
now_token = now_token.copy_and_replace("presence_key", presence_key)
now_token, ephemeral_by_room = yield self.ephemeral_by_room(
sync_config, now_token, since_token
typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events_for_user(
user=sync_config.user,
from_key=since_token.typing_key,
limit=sync_config.limit,
)
now_token = now_token.copy_and_replace("typing_key", typing_key)
typing_by_room = {event["room_id"]: [event] for event in typing}
for event in typing:
event.pop("room_id")
logger.debug("Typing %r", typing_by_room)
rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id(
@@ -395,55 +242,35 @@ class SyncHandler(BaseHandler):
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
joined_room_ids = set(r.room_id for r in rooms)
room_ids = set(r.room_id for r in rooms)
else:
joined_room_ids = yield rm_handler.get_joined_rooms_for_user(
room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user
)
timeline_limit = sync_config.filter.timeline_limit()
# TODO (mjark): Does public mean "published"?
published_rooms = yield self.store.get_rooms(is_public=True)
published_room_ids = set(r["room_id"] for r in published_rooms)
room_events, _ = yield self.store.get_room_events_stream(
sync_config.user.to_string(),
from_key=since_token.room_key,
to_key=now_token.room_key,
limit=timeline_limit + 1,
room_id=None,
limit=sync_config.limit + 1,
)
tags_by_room = yield self.store.get_updated_tags(
sync_config.user.to_string(),
since_token.private_user_data_key,
)
joined = []
archived = []
if len(room_events) <= timeline_limit:
rooms = []
if len(room_events) <= sync_config.limit:
# There is no gap in any of the rooms. Therefore we can just
# partition the new events by room and return them.
logger.debug("Got %i events for incremental sync - not limited",
len(room_events))
invite_events = []
leave_events = []
events_by_room_id = {}
for event in room_events:
events_by_room_id.setdefault(event.room_id, []).append(event)
if event.room_id not in joined_room_ids:
if (event.type == EventTypes.Member
and event.state_key == sync_config.user.to_string()):
if event.membership == Membership.INVITE:
invite_events.append(event)
elif event.membership in (Membership.LEAVE, Membership.BAN):
leave_events.append(event)
for room_id in joined_room_ids:
for room_id in room_ids:
recents = events_by_room_id.get(room_id, [])
logger.debug("Events for room %s: %r", room_id, recents)
state = {
(event.type, event.state_key): event
for event in recents if event.is_state()}
limited = False
state = [event for event in recents if event.is_state()]
if recents:
prev_batch = now_token.copy_and_replace(
"room_key", recents[0].internal_metadata.before
@@ -451,87 +278,95 @@ class SyncHandler(BaseHandler):
else:
prev_batch = now_token
just_joined = yield self.check_joined_room(sync_config, state)
if just_joined:
logger.debug("User has just joined %s: needs full state",
room_id)
state = yield self.get_state_at(room_id, now_token)
# the timeline is inherently limited if we've just joined
limited = True
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=TimelineBatch(
events=recents,
prev_batch=prev_batch,
limited=limited,
),
state=state,
ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
state = yield self.check_joined_room(
sync_config, room_id, state
)
logger.debug("Result for room %s: %r", room_id, room_sync)
room_sync = RoomSyncResult(
room_id=room_id,
published=room_id in published_room_ids,
events=recents,
prev_batch=prev_batch,
state=state,
limited=False,
ephemeral=typing_by_room.get(room_id, [])
)
if room_sync:
joined.append(room_sync)
rooms.append(room_sync)
else:
logger.debug("Got %i events for incremental sync - hit limit",
len(room_events))
invite_events = yield self.store.get_invites_for_user(
sync_config.user.to_string()
)
leave_events = yield self.store.get_leave_and_ban_events_for_user(
sync_config.user.to_string()
)
for room_id in joined_room_ids:
for room_id in room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token,
ephemeral_by_room, tags_by_room
published_room_ids, typing_by_room
)
if room_sync:
joined.append(room_sync)
for leave_event in leave_events:
room_sync = yield self.incremental_sync_for_archived_room(
sync_config, leave_event, since_token, tags_by_room
)
archived.append(room_sync)
invited = [
InvitedSyncResult(room_id=event.room_id, invite=event)
for event in invite_events
]
rooms.append(room_sync)
defer.returnValue(SyncResult(
presence=presence,
joined=joined,
invited=invited,
archived=archived,
public_user_data=presence,
private_user_data=[],
rooms=rooms,
next_batch=now_token,
))
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, room_id, events):
event_id_to_state = yield self.store.get_state_for_events(
room_id, frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
)
def allowed(event, state):
if event.type == EventTypes.RoomHistoryVisibility:
return True
membership_ev = state.get((EventTypes.Member, user_id), None)
if membership_ev:
membership = membership_ev.membership
else:
membership = Membership.LEAVE
if membership == Membership.JOIN:
return True
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility == "public":
return True
elif visibility == "shared":
return True
elif visibility == "joined":
return membership == Membership.JOIN
elif visibility == "invited":
return membership == Membership.INVITE
return True
defer.returnValue([
event
for event in events
if allowed(event, event_id_to_state[event.event_id])
])
@defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None):
"""
:returns a Deferred TimelineBatch
"""
limited = True
recents = []
filtering_factor = 2
timeline_limit = sync_config.filter.timeline_limit()
load_limit = max(timeline_limit * filtering_factor, 100)
load_limit = max(sync_config.limit * filtering_factor, 100)
max_repeat = 3 # Only try a few times per room, otherwise
room_key = now_token.room_key
end_key = room_key
while limited and len(recents) < timeline_limit and max_repeat:
while limited and len(recents) < sync_config.limit and max_repeat:
events, keys = yield self.store.get_recent_events_for_room(
room_id,
limit=load_limit + 1,
@@ -540,9 +375,9 @@ class SyncHandler(BaseHandler):
)
(room_key, _) = keys
end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter.filter_room_timeline(events)
loaded_recents = sync_config.filter.filter_room_events(events)
loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(), loaded_recents,
sync_config.user.to_string(), room_id, loaded_recents,
)
loaded_recents.extend(recents)
recents = loaded_recents
@@ -550,112 +385,64 @@ class SyncHandler(BaseHandler):
limited = False
max_repeat -= 1
if len(recents) > timeline_limit:
limited = True
recents = recents[-timeline_limit:]
if len(recents) > sync_config.limit:
recents = recents[-sync_config.limit:]
room_key = recents[0].internal_metadata.before
prev_batch_token = now_token.copy_and_replace(
"room_key", room_key
)
defer.returnValue(TimelineBatch(
events=recents, prev_batch=prev_batch_token, limited=limited
))
defer.returnValue((recents, prev_batch_token, limited))
@defer.inlineCallbacks
def incremental_sync_with_gap_for_room(self, room_id, sync_config,
since_token, now_token,
ephemeral_by_room, tags_by_room):
published_room_ids, typing_by_room):
""" Get the incremental delta needed to bring the client up to date for
the room. Gives the client the most recent events and the changes to
state.
Returns:
A Deferred JoinedSyncResult
A Deferred RoomSyncResult
"""
logger.debug("Doing incremental sync for room %s between %s and %s",
room_id, since_token, now_token)
# TODO(mjark): Check for redactions we might have missed.
batch = yield self.load_filtered_recents(
recents, prev_batch_token, limited = yield self.load_filtered_recents(
room_id, sync_config, now_token, since_token,
)
logging.debug("Recents %r", batch)
logging.debug("Recents %r", recents)
current_state = yield self.get_state_at(room_id, now_token)
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
# TODO(mjark): This seems racy since this isn't being passed a
# token to indicate what point in the stream this is
current_state = yield self.state_handler.get_current_state(
room_id
)
current_state_events = current_state.values()
state = yield self.compute_state_delta(
since_token=since_token,
previous_state=state_at_previous_sync,
current_state=current_state,
)
just_joined = yield self.check_joined_room(sync_config, state)
if just_joined:
state = yield self.get_state_at(room_id, now_token)
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
state=state,
ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
)
logging.debug("Room sync: %r", room_sync)
defer.returnValue(room_sync)
@defer.inlineCallbacks
def incremental_sync_for_archived_room(self, sync_config, leave_event,
since_token, tags_by_room):
""" Get the incremental delta needed to bring the client up to date for
the archived room.
Returns:
A Deferred ArchivedSyncResult
"""
stream_token = yield self.store.get_stream_token_for_event(
leave_event.event_id
)
leave_token = since_token.copy_and_replace("room_key", stream_token)
batch = yield self.load_filtered_recents(
leave_event.room_id, sync_config, leave_token, since_token,
)
logging.debug("Recents %r", batch)
state_events_at_leave = yield self.store.get_state_for_event(
leave_event.event_id
)
state_at_previous_sync = yield self.get_state_at(
leave_event.room_id, stream_position=since_token
state_at_previous_sync = yield self.get_state_at_previous_sync(
room_id, since_token=since_token
)
state_events_delta = yield self.compute_state_delta(
since_token=since_token,
previous_state=state_at_previous_sync,
current_state=state_events_at_leave,
current_state=current_state_events,
)
room_sync = ArchivedSyncResult(
room_id=leave_event.room_id,
timeline=batch,
state_events_delta = yield self.check_joined_room(
sync_config, room_id, state_events_delta
)
room_sync = RoomSyncResult(
room_id=room_id,
published=room_id in published_room_ids,
events=recents,
prev_batch=prev_batch_token,
state=state_events_delta,
private_user_data=self.private_user_data_for_room(
leave_event.room_id, tags_by_room
),
limited=limited,
ephemeral=typing_by_room.get(room_id, [])
)
logging.debug("Room sync: %r", room_sync)
@@ -663,77 +450,58 @@ class SyncHandler(BaseHandler):
defer.returnValue(room_sync)
@defer.inlineCallbacks
def get_state_after_event(self, event):
"""
Get the room state after the given event
:param synapse.events.EventBase event: event of interest
:return: A Deferred map from ((type, state_key)->Event)
"""
state = yield self.store.get_state_for_event(event.event_id)
if event.is_state():
state = state.copy()
state[(event.type, event.state_key)] = event
defer.returnValue(state)
@defer.inlineCallbacks
def get_state_at(self, room_id, stream_position):
""" Get the room state at a particular stream position
:param str room_id: room for which to get state
:param StreamToken stream_position: point at which to get state
:returns: A Deferred map from ((type, state_key)->Event)
def get_state_at_previous_sync(self, room_id, since_token):
""" Get the room state at the previous sync the client made.
Returns:
A Deferred list of Events.
"""
last_events, token = yield self.store.get_recent_events_for_room(
room_id, end_token=stream_position.room_key, limit=1,
room_id, end_token=since_token.room_key, limit=1,
)
if last_events:
last_event = last_events[-1]
state = yield self.get_state_after_event(last_event)
last_event = last_events[0]
last_context = yield self.state_handler.compute_event_context(
last_event
)
if last_event.is_state():
state = [last_event] + last_context.current_state.values()
else:
state = last_context.current_state.values()
else:
# no events in this room - so presumably no state
state = {}
state = ()
defer.returnValue(state)
def compute_state_delta(self, since_token, previous_state, current_state):
""" Works out the differnce in state between the current state and the
state the client got when it last performed a sync.
:param str since_token: the point we are comparing against
:param dict[(str,str), synapse.events.FrozenEvent] previous_state: the
state to compare to
:param dict[(str,str), synapse.events.FrozenEvent] current_state: the
new state
:returns A new event dictionary
Returns:
A list of events.
"""
# TODO(mjark) Check if the state events were received by the server
# after the previous sync, since we need to include those state
# updates even if they occured logically before the previous event.
# TODO(mjark) Check for new redactions in the state events.
state_delta = {}
for key, event in current_state.iteritems():
if (key not in previous_state or
previous_state[key].event_id != event.event_id):
state_delta[key] = event
previous_dict = {event.event_id: event for event in previous_state}
state_delta = []
for event in current_state:
if event.event_id not in previous_dict:
state_delta.append(event)
return state_delta
def check_joined_room(self, sync_config, state_delta):
"""
Check if the user has just joined the given room (so should
be given the full state)
@defer.inlineCallbacks
def check_joined_room(self, sync_config, room_id, state_delta):
joined = False
for event in state_delta:
if (
event.type == EventTypes.Member
and event.state_key == sync_config.user.to_string()
):
if event.content["membership"] == Membership.JOIN:
joined = True
:param sync_config:
:param dict[(str,str), synapse.events.FrozenEvent] state_delta: the
difference in state since the last sync
if joined:
res = yield self.state_handler.get_current_state(room_id)
state_delta = res.values()
:returns A deferred Tuple (state_delta, limited)
"""
join_event = state_delta.get((
EventTypes.Member, sync_config.user.to_string()), None)
if join_event is not None:
if join_event.content["membership"] == Membership.JOIN:
return True
return False
defer.returnValue(state_delta)
+8 -3
View File
@@ -246,12 +246,17 @@ class TypingNotificationEventSource(object):
},
}
def get_new_events(self, from_key, room_ids, **kwargs):
@defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key)
handler = self.handler()
joined_room_ids = (
yield self.room_member_handler().get_joined_rooms_for_user(user)
)
events = []
for room_id in room_ids:
for room_id in joined_room_ids:
if room_id not in handler._room_serials:
continue
if handler._room_serials[room_id] <= from_key:
@@ -259,7 +264,7 @@ class TypingNotificationEventSource(object):
events.append(self._make_event_for(room_id))
return events, handler._latest_room_serial
defer.returnValue((events, handler._latest_room_serial))
def get_current_key(self):
return self.handler()._latest_room_serial
+30 -46
View File
@@ -24,6 +24,7 @@ from canonicaljson import encode_canonical_json
from twisted.internet import defer, reactor, ssl
from twisted.web.client import (
Agent, readBody, FileBodyProducer, PartialDownloadError,
HTTPConnectionPool,
)
from twisted.web.http_headers import Headers
@@ -58,14 +59,15 @@ class SimpleHttpClient(object):
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10
self.agent = Agent(
reactor,
pool=pool,
connectTimeout=15,
contextFactory=hs.get_http_client_context_factory()
)
self.user_agent = hs.version_string
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,)
self.version_string = hs.version_string
def request(self, method, uri, *args, **kwargs):
# A small wrapper around self.agent.request() so we can easily attach
@@ -110,7 +112,7 @@ class SimpleHttpClient(object):
uri.encode("ascii"),
headers=Headers({
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
b"User-Agent": [self.version_string],
}),
bodyProducer=FileBodyProducer(StringIO(query_bytes))
)
@@ -129,8 +131,7 @@ class SimpleHttpClient(object):
"POST",
uri.encode("ascii"),
headers=Headers({
b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
"Content-Type": ["application/json"]
}),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
@@ -156,8 +157,27 @@ class SimpleHttpClient(object):
On a non-2xx HTTP response. The response body will be used as the
error message.
"""
body = yield self.get_raw(uri, args)
defer.returnValue(json.loads(body))
if len(args):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
response = yield self.request(
"GET",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.version_string],
})
)
body = yield preserve_context_over_fn(readBody, response)
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
else:
# NB: This is explicitly not json.loads(body)'d because the contract
# of CodeMessageException is a *string* message. Callers can always
# load it into JSON if they want.
raise CodeMessageException(response.code, body)
@defer.inlineCallbacks
def put_json(self, uri, json_body, args={}):
@@ -186,7 +206,7 @@ class SimpleHttpClient(object):
"PUT",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.user_agent],
b"User-Agent": [self.version_string],
"Content-Type": ["application/json"]
}),
bodyProducer=FileBodyProducer(StringIO(json_str))
@@ -202,42 +222,6 @@ class SimpleHttpClient(object):
# load it into JSON if they want.
raise CodeMessageException(response.code, body)
@defer.inlineCallbacks
def get_raw(self, uri, args={}):
""" Gets raw text from the given URI.
Args:
uri (str): The URI to request, not including query parameters
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body at text.
Raises:
On a non-2xx HTTP response. The response body will be used as the
error message.
"""
if len(args):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
response = yield self.request(
"GET",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.user_agent],
})
)
body = yield preserve_context_over_fn(readBody, response)
if 200 <= response.code < 300:
defer.returnValue(body)
else:
raise CodeMessageException(response.code, body)
class CaptchaServerHttpClient(SimpleHttpClient):
"""
@@ -257,7 +241,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
bodyProducer=FileBodyProducer(StringIO(query_bytes)),
headers=Headers({
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
b"User-Agent": [self.version_string],
})
)
+2 -9
View File
@@ -35,7 +35,6 @@ from signedjson.sign import sign_json
import simplejson as json
import logging
import random
import sys
import urllib
import urlparse
@@ -56,9 +55,6 @@ incoming_responses_counter = metrics.register_counter(
)
MAX_RETRIES = 10
class MatrixFederationEndpointFactory(object):
def __init__(self, hs):
self.tls_server_context_factory = hs.tls_server_context_factory
@@ -123,7 +119,7 @@ class MatrixFederationHttpClient(object):
# XXX: Would be much nicer to retry only at the transaction-layer
# (once we have reliable transactions in place)
retries_left = MAX_RETRIES
retries_left = 5
http_url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "")
@@ -184,10 +180,7 @@ class MatrixFederationHttpClient(object):
)
if retries_left and not timeout:
delay = 4 ** (MAX_RETRIES + 1 - retries_left)
delay = max(delay, 60)
delay *= random.uniform(0.8, 1.4)
yield sleep(delay)
yield sleep(2 ** (5 - retries_left))
retries_left -= 1
else:
raise
+10 -50
View File
@@ -14,8 +14,6 @@
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor, ObservableDeferred
@@ -271,8 +269,8 @@ class Notifier(object):
logger.exception("Failed to notify listener")
@defer.inlineCallbacks
def wait_for_events(self, user, timeout, callback, room_ids=None,
from_token=StreamToken("s0", "0", "0", "0", "0")):
def wait_for_events(self, user, rooms, timeout, callback,
from_token=StreamToken("s0", "0", "0", "0")):
"""Wait until the callback returns a non empty response or the
timeout fires.
"""
@@ -281,12 +279,11 @@ class Notifier(object):
if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user)
current_token = yield self.event_sources.get_current_token()
if room_ids is None:
rooms = yield self.store.get_rooms_for_user(user)
room_ids = [room.room_id for room in rooms]
rooms = yield self.store.get_rooms_for_user(user)
rooms = [room.room_id for room in rooms]
user_stream = _NotifierUserStream(
user=user,
rooms=room_ids,
rooms=rooms,
appservice=appservice,
current_token=current_token,
time_now_ms=self.clock.time_msec(),
@@ -331,9 +328,8 @@ class Notifier(object):
defer.returnValue(result)
@defer.inlineCallbacks
def get_events_for(self, user, pagination_config, timeout,
only_room_events=False,
is_guest=False, guest_room_id=None):
def get_events_for(self, user, rooms, pagination_config, timeout,
only_room_events=False):
""" For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning.
@@ -346,16 +342,6 @@ class Notifier(object):
limit = pagination_config.limit
room_ids = []
if is_guest:
if guest_room_id:
if not self._is_world_readable(guest_room_id):
raise AuthError(403, "Guest access not allowed")
room_ids = [guest_room_id]
else:
rooms = yield self.store.get_rooms_for_user(user.to_string())
room_ids = [room.room_id for room in rooms]
@defer.inlineCallbacks
def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token):
@@ -363,7 +349,6 @@ class Notifier(object):
events = []
end_token = from_token
for name, source in self.event_sources.sources.items():
keyname = "%s_key" % name
before_id = getattr(before_token, keyname)
@@ -372,23 +357,9 @@ class Notifier(object):
continue
if only_room_events and name != "room":
continue
new_events, new_key = yield source.get_new_events(
user=user,
from_key=getattr(from_token, keyname),
limit=limit,
is_guest=is_guest,
room_ids=room_ids,
new_events, new_key = yield source.get_new_events_for_user(
user, getattr(from_token, keyname), limit,
)
if name == "room":
room_member_handler = self.hs.get_handlers().room_member_handler
new_events = yield room_member_handler._filter_events_for_client(
user.to_string(),
new_events,
is_guest=is_guest,
require_all_visible_for_guests=False
)
events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key)
@@ -398,7 +369,7 @@ class Notifier(object):
defer.returnValue(None)
result = yield self.wait_for_events(
user, timeout, check_for_updates, room_ids=room_ids, from_token=from_token
user, rooms, timeout, check_for_updates, from_token=from_token
)
if result is None:
@@ -406,17 +377,6 @@ class Notifier(object):
defer.returnValue(result)
@defer.inlineCallbacks
def _is_world_readable(self, room_id):
state = yield self.hs.get_state_handler().get_current_state(
room_id,
EventTypes.RoomHistoryVisibility
)
if state and "history_visibility" in state.content:
defer.returnValue(state.content["history_visibility"] == "world_readable")
else:
defer.returnValue(False)
@log_function
def remove_expired_streams(self):
time_now_ms = self.clock.time_msec()
+1 -1
View File
@@ -186,7 +186,7 @@ class Pusher(object):
if not display_name:
return False
return re.search(
r"\b%s\b" % re.escape(display_name), ev['content']['body'],
"\b%s\b" % re.escape(display_name), ev['content']['body'],
flags=re.IGNORECASE
) is not None
+1 -1
View File
@@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(auth_user)
if not is_admin and target_user != auth_user:
+2 -2
View File
@@ -69,7 +69,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
try:
# try to auth as a user
user, _, _ = yield self.auth.get_user_by_req(request)
user, _ = yield self.auth.get_user_by_req(request)
try:
user_id = user.to_string()
yield dir_handler.create_association(
@@ -116,7 +116,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
# fallback to default user behaviour if they aren't an AS
pass
user, _, _ = yield self.auth.get_user_by_req(request)
user, _ = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(user)
if not is_admin:
+3 -12
View File
@@ -34,15 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
auth_user, _, is_guest = yield self.auth.get_user_by_req(
request,
allow_guest=True
)
room_id = None
if is_guest:
if "room_id" not in request.args:
raise SynapseError(400, "Guest users must specify room_id param")
room_id = request.args["room_id"][0]
auth_user, _ = yield self.auth.get_user_by_req(request)
try:
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request)
@@ -57,8 +49,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
chunk = yield handler.get_stream(
auth_user.to_string(), pagin_config, timeout=timeout,
as_client_event=as_client_event, affect_presence=(not is_guest),
room_id=room_id, is_guest=is_guest
as_client_event=as_client_event
)
except:
logger.exception("Event stream failed")
@@ -80,7 +71,7 @@ class EventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, event_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler
event = yield handler.get_event(auth_user, event_id)
+2 -4
View File
@@ -25,16 +25,14 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
user, _, _ = yield self.auth.get_user_by_req(request)
user, _ = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
handler = self.handlers.message_handler
include_archived = request.args.get("archived", None) == ["true"]
content = yield handler.snapshot_all_rooms(
user_id=user.to_string(),
pagin_config=pagination_config,
as_client_event=as_client_event,
include_archived=include_archived,
as_client_event=as_client_event
)
defer.returnValue((200, content))
+29 -251
View File
@@ -15,22 +15,18 @@
from twisted.internet import defer
from synapse.api.errors import SynapseError, LoginError, Codes
from synapse.http.client import SimpleHttpClient
from synapse.api.errors import SynapseError
from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern
import simplejson as json
import urllib
import urlparse
import logging
from saml2 import BINDING_HTTP_POST
from saml2 import config
from saml2.client import Saml2Client
import xml.etree.ElementTree as ET
logger = logging.getLogger(__name__)
@@ -39,28 +35,16 @@ class LoginRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login$")
PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"
TOKEN_TYPE = "m.login.token"
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.password_enabled = hs.config.password_enabled
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.cas_server_url = hs.config.cas_server_url
self.cas_required_attributes = hs.config.cas_required_attributes
self.servername = hs.config.server_name
def on_GET(self, request):
flows = []
flows = [{"type": LoginRestServlet.PASS_TYPE}]
if self.saml2_enabled:
flows.append({"type": LoginRestServlet.SAML2_TYPE})
if self.cas_enabled:
flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.password_enabled:
flows.append({"type": LoginRestServlet.PASS_TYPE})
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
return (200, {"flows": flows})
def on_OPTIONS(self, request):
@@ -71,9 +55,6 @@ class LoginRestServlet(ClientV1RestServlet):
login_submission = _parse_json(request)
try:
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
if not self.password_enabled:
raise SynapseError(400, "Password login has been disabled.")
result = yield self.do_password_login(login_submission)
defer.returnValue(result)
elif self.saml2_enabled and (login_submission["type"] ==
@@ -86,23 +67,6 @@ class LoginRestServlet(ClientV1RestServlet):
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
}
defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE):
# TODO: get this from the homeserver rather than creating a new one for
# each request
http_client = SimpleHttpClient(self.hs)
uri = "%s/proxyValidate" % (self.cas_server_url,)
args = {
"ticket": login_submission["ticket"],
"service": login_submission["service"]
}
body = yield http_client.get_raw(uri, args)
result = yield self.do_cas_login(body)
defer.returnValue(result)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission)
defer.returnValue(result)
else:
raise SynapseError(400, "Bad login type.")
except KeyError:
@@ -114,8 +78,6 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
login_submission['medium'], login_submission['address']
)
if not user_id:
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
else:
user_id = login_submission['user']
@@ -138,94 +100,35 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result))
class LoginFallbackRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/fallback$")
def on_GET(self, request):
# TODO(kegan): This should be returning some HTML which is capable of
# hitting LoginRestServlet
return (200, {})
class PasswordResetRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/reset")
@defer.inlineCallbacks
def do_token_login(self, login_submission):
token = login_submission['token']
auth_handler = self.handlers.auth_handler
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
@defer.inlineCallbacks
def do_cas_login(self, cas_response_body):
user, attributes = self.parse_cas_response(cas_response_body)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
def on_POST(self, request):
reset_info = _parse_json(request)
try:
email = reset_info["email"]
user_id = reset_info["user_id"]
handler = self.handlers.login_handler
yield handler.reset_password(user_id, email)
# purposefully give no feedback to avoid people hammering different
# combinations.
defer.returnValue((200, {}))
except KeyError:
raise SynapseError(
400,
"Missing keys. Requires 'email' and 'user_id'."
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in attribute tags
# to the full URL of the namespace.
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
# We don't care about namespace here and it will always be encased in
# curly braces, so we remove them.
if "}" in attribute.tag:
attributes[attribute.tag.split("}")[1]] = attribute.text
else:
attributes[attribute.tag] = attribute.text
if user is None or attributes is None:
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
return (user, attributes)
class SAML2RestServlet(ClientV1RestServlet):
@@ -271,127 +174,6 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue((200, {"status": "not_authenticated"}))
# TODO Delete this after all CAS clients switch to token login instead
class CasRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas")
def __init__(self, hs):
super(CasRestServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
def on_GET(self, request):
return (200, {"serverUrl": self.cas_server_url})
class CasRedirectServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas/redirect")
def __init__(self, hs):
super(CasRedirectServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
def on_GET(self, request):
args = request.args
if "redirectUrl" not in args:
return (400, "Redirect URL not specified for CAS auth")
client_redirect_url_param = urllib.urlencode({
"redirectUrl": args["redirectUrl"][0]
})
hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket"
service_param = urllib.urlencode({
"service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
})
request.redirect("%s?%s" % (self.cas_server_url, service_param))
request.finish()
class CasTicketServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas/ticket")
def __init__(self, hs):
super(CasTicketServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
@defer.inlineCallbacks
def on_GET(self, request):
client_redirect_url = request.args["redirectUrl"][0]
http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": request.args["ticket"],
"service": self.cas_service_url
}
body = yield http_client.get_raw(uri, args)
result = yield self.handle_cas_response(request, body, client_redirect_url)
defer.returnValue(result)
@defer.inlineCallbacks
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if not user_exists:
user_id, _ = (
yield self.handlers.registration_handler.register(localpart=user)
)
login_token = auth_handler.generate_short_term_login_token(user_id)
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token)
request.redirect(redirect_url)
request.finish()
def add_login_token_to_redirect_url(self, url, token):
url_parts = list(urlparse.urlparse(url))
query = dict(urlparse.parse_qsl(url_parts[4]))
query.update({"loginToken": token})
url_parts[4] = urllib.urlencode(query)
return urlparse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in attribute tags
# to the full URL of the namespace.
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
# We don't care about namespace here and it will always be encased in
# curly braces, so we remove them.
if "}" in attribute.tag:
attributes[attribute.tag.split("}")[1]] = attribute.text
else:
attributes[attribute.tag] = attribute.text
if user is None or attributes is None:
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
return (user, attributes)
def _parse_json(request):
try:
content = json.loads(request.content.read())
@@ -406,8 +188,4 @@ def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
if hs.config.saml2_enabled:
SAML2RestServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
CasRestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server)
+4 -4
View File
@@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state(
@@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
state = {}
@@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
@@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
+2 -2
View File
@@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request, allow_guest=True)
auth_user, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
try:
@@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
try:
+3 -3
View File
@@ -43,7 +43,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e:
raise SynapseError(400, e.message)
user, _, _ = yield self.auth.get_user_by_req(request)
user, _ = yield self.auth.get_user_by_req(request)
if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
raise SynapseError(400, "rule_id may not contain slashes")
@@ -92,7 +92,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_DELETE(self, request):
spec = _rule_spec_from_path(request.postpath)
user, _, _ = yield self.auth.get_user_by_req(request)
user, _ = yield self.auth.get_user_by_req(request)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
@@ -109,7 +109,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
user, _, _ = yield self.auth.get_user_by_req(request)
user, _ = yield self.auth.get_user_by_req(request)
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
+1 -1
View File
@@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
user, token_id, _ = yield self.auth.get_user_by_req(request)
user, token_id = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
+20 -117
View File
@@ -17,7 +17,7 @@
from twisted.internet import defer
from base import ClientV1RestServlet, client_path_pattern
from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.api.errors import SynapseError, Codes
from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID, RoomAlias
@@ -27,6 +27,7 @@ import simplejson as json
import logging
import urllib
logger = logging.getLogger(__name__)
@@ -61,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
room_config = self.get_room_config(request)
info = yield self.make_room(room_config, auth_user, None)
@@ -124,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key):
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
user, _ = yield self.auth.get_user_by_req(request)
msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data(
@@ -132,7 +133,6 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
room_id=room_id,
event_type=event_type,
state_key=state_key,
is_guest=is_guest,
)
if not data:
@@ -143,7 +143,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
user, token_id, _ = yield self.auth.get_user_by_req(request)
user, token_id = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
@@ -175,7 +175,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_type, txn_id=None):
user, token_id, _ = yield self.auth.get_user_by_req(request, allow_guest=True)
user, token_id = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
msg_handler = self.handlers.message_handler
@@ -220,10 +220,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, room_identifier, txn_id=None):
user, token_id, is_guest = yield self.auth.get_user_by_req(
request,
allow_guest=True
)
user, token_id = yield self.auth.get_user_by_req(request)
# the identifier could be a room alias or a room id. Try one then the
# other if it fails to parse, without swallowing other valid
@@ -245,20 +242,16 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
defer.returnValue((200, ret_dict))
else: # room id
msg_handler = self.handlers.message_handler
content = {"membership": Membership.JOIN}
if is_guest:
content["kind"] = "guest"
yield msg_handler.create_and_send_event(
{
"type": EventTypes.Member,
"content": content,
"content": {"membership": Membership.JOIN},
"room_id": identifier.to_string(),
"sender": user.to_string(),
"state_key": user.to_string(),
},
token_id=token_id,
txn_id=txn_id,
is_guest=is_guest,
)
defer.returnValue((200, {"room_id": identifier.to_string()}))
@@ -296,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
user, _, _ = yield self.auth.get_user_by_req(request)
user, _ = yield self.auth.get_user_by_req(request)
handler = self.handlers.message_handler
events = yield handler.get_state_events(
room_id=room_id,
@@ -326,13 +319,13 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
}))
# TODO: Needs better unit testing
# TODO: Needs unit testing
class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
user, _ = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(
request, default_limit=10,
)
@@ -341,7 +334,6 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
msgs = yield handler.get_messages(
room_id=room_id,
user_id=user.to_string(),
is_guest=is_guest,
pagin_config=pagination_config,
as_client_event=as_client_event
)
@@ -355,13 +347,12 @@ class RoomStateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
user, _ = yield self.auth.get_user_by_req(request)
handler = self.handlers.message_handler
# Get all the current state for this room
events = yield handler.get_state_events(
room_id=room_id,
user_id=user.to_string(),
is_guest=is_guest,
)
defer.returnValue((200, events))
@@ -372,13 +363,12 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
user, _ = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync(
room_id=room_id,
user_id=user.to_string(),
pagin_config=pagination_config,
is_guest=is_guest,
)
defer.returnValue((200, content))
@@ -407,41 +397,6 @@ class RoomTriggerBackfill(ClientV1RestServlet):
defer.returnValue((200, res))
class RoomEventContext(ClientV1RestServlet):
PATTERN = client_path_pattern(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
)
def __init__(self, hs):
super(RoomEventContext, self).__init__(hs)
self.clock = hs.get_clock()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
limit = int(request.args.get("limit", [10])[0])
results = yield self.handlers.room_context_handler.get_event_context(
user, room_id, event_id, limit, is_guest
)
time_now = self.clock.time_msec()
results["events_before"] = [
serialize_event(event, time_now) for event in results["events_before"]
]
results["events_after"] = [
serialize_event(event, time_now) for event in results["events_after"]
]
results["state"] = [
serialize_event(event, time_now) for event in results["state"]
]
logger.info("Responding with %r", results)
defer.returnValue((200, results))
# TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet):
@@ -453,37 +408,16 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action, txn_id=None):
user, token_id, is_guest = yield self.auth.get_user_by_req(
request,
allow_guest=True
)
if is_guest and membership_action not in {Membership.JOIN, Membership.LEAVE}:
raise AuthError(403, "Guest access not allowed")
user, token_id = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
# target user is you unless it is an invite
state_key = user.to_string()
if membership_action == "invite" and self._has_3pid_invite_keys(content):
yield self.handlers.room_member_handler.do_3pid_invite(
room_id,
user,
content["medium"],
content["address"],
content["id_server"],
token_id,
txn_id
)
defer.returnValue((200, {}))
return
elif membership_action in ["invite", "ban", "kick"]:
if "user_id" in content:
state_key = content["user_id"]
else:
if membership_action in ["invite", "ban", "kick"]:
if "user_id" not in content:
raise SynapseError(400, "Missing user_id key.")
state_key = content["user_id"]
# make sure it looks like a user ID; it'll throw if it's invalid.
UserID.from_string(state_key)
@@ -491,32 +425,20 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
membership_action = "leave"
msg_handler = self.handlers.message_handler
content = {"membership": unicode(membership_action)}
if is_guest:
content["kind"] = "guest"
yield msg_handler.create_and_send_event(
{
"type": EventTypes.Member,
"content": content,
"content": {"membership": unicode(membership_action)},
"room_id": room_id,
"sender": user.to_string(),
"state_key": state_key,
},
token_id=token_id,
txn_id=txn_id,
is_guest=is_guest,
)
defer.returnValue((200, {}))
def _has_3pid_invite_keys(self, content):
for key in {"id_server", "medium", "address"}:
if key not in content:
return False
return True
@defer.inlineCallbacks
def on_PUT(self, request, room_id, membership_action, txn_id):
try:
@@ -541,7 +463,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id, txn_id=None):
user, token_id, _ = yield self.auth.get_user_by_req(request)
user, token_id = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
msg_handler = self.handlers.message_handler
@@ -581,7 +503,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
room_id = urllib.unquote(room_id)
target_user = UserID.from_string(urllib.unquote(user_id))
@@ -607,23 +529,6 @@ class RoomTypingRestServlet(ClientV1RestServlet):
defer.returnValue((200, {}))
class SearchRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern(
"/search$"
)
@defer.inlineCallbacks
def on_POST(self, request):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
batch = request.args.get("next_batch", [None])[0]
results = yield self.handlers.search_handler.search(auth_user, content, batch)
defer.returnValue((200, results))
def _parse_json(request):
try:
content = json.loads(request.content.read())
@@ -680,5 +585,3 @@ def register_servlets(hs, http_server):
RoomInitialSyncRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server)
RoomEventContext(hs).register(http_server)
+1 -1
View File
@@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret
-2
View File
@@ -22,7 +22,6 @@ from . import (
receipts,
keys,
tokenrefresh,
tags,
)
from synapse.http.server import JsonResource
@@ -45,4 +44,3 @@ class ClientV2AlphaRestResource(JsonResource):
receipts.register_servlets(hs, client_resource)
keys.register_servlets(hs, client_resource)
tokenrefresh.register_servlets(hs, client_resource)
tags.register_servlets(hs, client_resource)
+3 -3
View File
@@ -55,7 +55,7 @@ class PasswordRestServlet(RestServlet):
if LoginType.PASSWORD in result:
# if using password, they should also be logged in
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
if auth_user.to_string() != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN)
user_id = auth_user.to_string()
@@ -102,7 +102,7 @@ class ThreepidRestServlet(RestServlet):
def on_GET(self, request):
yield run_on_reactor()
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
threepids = yield self.hs.get_datastore().user_get_threepids(
auth_user.to_string()
@@ -120,7 +120,7 @@ class ThreepidRestServlet(RestServlet):
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
threePidCreds = body['threePidCreds']
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
+2 -2
View File
@@ -40,7 +40,7 @@ class GetFilterRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id, filter_id):
target_user = UserID.from_string(user_id)
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
if target_user != auth_user:
raise AuthError(403, "Cannot get filters for other users")
@@ -76,7 +76,7 @@ class CreateFilterRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id)
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
if target_user != auth_user:
raise AuthError(403, "Cannot create filters for other users")
+3 -3
View File
@@ -64,7 +64,7 @@ class KeyUploadServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, device_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string()
# TODO: Check that the device_id matches that in the authentication
# or derive the device_id from the authentication instead.
@@ -109,7 +109,7 @@ class KeyUploadServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, device_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string()
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
@@ -181,7 +181,7 @@ class KeyQueryServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
auth_user_id = auth_user.to_string()
user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else []
+1 -5
View File
@@ -15,7 +15,6 @@
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern
@@ -40,10 +39,7 @@ class ReceiptRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id):
user, _, _ = yield self.auth.get_user_by_req(request)
if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'")
user, _ = yield self.auth.get_user_by_req(request)
yield self.receipts_handler.received_client_receipt(
room_id,
+1 -26
View File
@@ -16,7 +16,7 @@
from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.api.errors import SynapseError, Codes
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_json_dict_from_request
@@ -55,19 +55,6 @@ class RegisterRestServlet(RestServlet):
def on_POST(self, request):
yield run_on_reactor()
kind = "user"
if "kind" in request.args:
kind = request.args["kind"][0]
if kind == "guest":
ret = yield self._do_guest_registration()
defer.returnValue(ret)
return
elif kind != "user":
raise UnrecognizedRequestError(
"Do not understand membership kind: %s" % (kind,)
)
if '/register/email/requestToken' in request.path:
ret = yield self.onEmailTokenRequest(request)
defer.returnValue(ret)
@@ -249,18 +236,6 @@ class RegisterRestServlet(RestServlet):
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
@defer.inlineCallbacks
def _do_guest_registration(self):
if not self.hs.config.allow_guest_access:
defer.returnValue((403, "Guest access is disabled"))
user_id, _ = yield self.registration_handler.register(generate_token=False)
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
defer.returnValue((200, {
"user_id": user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
}))
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)
+82 -239
View File
@@ -20,14 +20,12 @@ from synapse.http.servlet import (
)
from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken
from synapse.events import FrozenEvent
from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_event_id,
)
from synapse.api.filtering import FilterCollection
from synapse.api.filtering import Filter
from ._base import client_v2_pattern
import copy
import logging
logger = logging.getLogger(__name__)
@@ -38,77 +36,99 @@ class SyncRestServlet(RestServlet):
GET parameters::
timeout(int): How long to wait for new events in milliseconds.
limit(int): Maxiumum number of events per room to return.
gap(bool): Create gaps the message history if limit is exceeded to
ensure that the client has the most recent messages. Defaults to
"true".
sort(str,str): tuple of sort key (e.g. "timeline") and direction
(e.g. "asc", "desc"). Defaults to "timeline,asc".
since(batch_token): Batch token when asking for incremental deltas.
set_presence(str): What state the device presence should be set to.
default is "online".
backfill(bool): Should the HS request message history from other
servers. This may take a long time making it unsuitable for clients
expecting a prompt response. Defaults to "true".
filter(filter_id): A filter to apply to the events returned.
filter_*: Filter override parameters.
Response JSON::
{
"next_batch": // batch token for the next /sync
"presence": // presence data for the user.
"rooms": {
"joined": { // Joined rooms being updated.
"${room_id}": { // Id of the room being updated
"next_batch": // batch token for the next /sync
"private_user_data": // private events for this user.
"public_user_data": // public events for all users including the
// public events for this user.
"rooms": [{ // List of rooms with updates.
"room_id": // Id of the room being updated
"limited": // Was the per-room event limit exceeded?
"published": // Is the room published by our HS?
"event_map": // Map of EventID -> event JSON.
"timeline": { // The recent events in the room if gap is "true"
"limited": // Was the per-room event limit exceeded?
// otherwise the next events in the room.
"events": [] // list of EventIDs in the "event_map".
"prev_batch": // back token for getting previous events.
"events": { // The recent events in the room if gap is "true"
// otherwise the next events in the room.
"batch": [] // list of EventIDs in the "event_map".
"prev_batch": // back token for getting previous events.
}
"state": {"events": []} // list of EventIDs updating the
// current state to be what it should
// be at the end of the batch.
"ephemeral": {"events": []} // list of event objects
}
},
"invited": {}, // Invited rooms being updated.
"archived": {} // Archived rooms being updated.
}
"state": [] // list of EventIDs updating the current state to
// be what it should be at the end of the batch.
"ephemeral": []
}]
}
"""
PATTERN = client_v2_pattern("/sync$")
ALLOWED_PRESENCE = set(["online", "offline"])
ALLOWED_SORT = set(["timeline,asc", "timeline,desc"])
ALLOWED_PRESENCE = set(["online", "offline", "idle"])
def __init__(self, hs):
super(SyncRestServlet, self).__init__()
self.auth = hs.get_auth()
self.event_stream_handler = hs.get_handlers().event_stream_handler
self.sync_handler = hs.get_handlers().sync_handler
self.clock = hs.get_clock()
self.filtering = hs.get_filtering()
@defer.inlineCallbacks
def on_GET(self, request):
user, token_id, _ = yield self.auth.get_user_by_req(request)
user, token_id = yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", default=0)
limit = parse_integer(request, "limit", required=True)
gap = parse_boolean(request, "gap", default=True)
sort = parse_string(
request, "sort", default="timeline,asc",
allowed_values=self.ALLOWED_SORT
)
since = parse_string(request, "since")
set_presence = parse_string(
request, "set_presence", default="online",
allowed_values=self.ALLOWED_PRESENCE
)
backfill = parse_boolean(request, "backfill", default=False)
filter_id = parse_string(request, "filter", default=None)
full_state = parse_boolean(request, "full_state", default=False)
logger.info(
"/sync: user=%r, timeout=%r, since=%r,"
" set_presence=%r, filter_id=%r" % (
user, timeout, since, set_presence, filter_id
"/sync: user=%r, timeout=%r, limit=%r, gap=%r, sort=%r, since=%r,"
" set_presence=%r, backfill=%r, filter_id=%r" % (
user, timeout, limit, gap, sort, since, set_presence,
backfill, filter_id
)
)
# TODO(mjark): Load filter and apply overrides.
try:
filter = yield self.filtering.get_user_filter(
user.localpart, filter_id
)
except:
filter = FilterCollection({})
filter = Filter({})
# filter = filter.apply_overrides(http_request)
# if filter.matches(event):
# # stuff
sync_config = SyncConfig(
user=user,
gap=gap,
limit=limit,
sort=sort,
backfill=backfill,
filter=filter,
)
@@ -117,154 +137,43 @@ class SyncRestServlet(RestServlet):
else:
since_token = None
if set_presence == "online":
yield self.event_stream_handler.started_stream(user)
try:
sync_result = yield self.sync_handler.wait_for_sync_for_user(
sync_config, since_token=since_token, timeout=timeout,
full_state=full_state
)
finally:
if set_presence == "online":
self.event_stream_handler.stopped_stream(user)
sync_result = yield self.sync_handler.wait_for_sync_for_user(
sync_config, since_token=since_token, timeout=timeout
)
time_now = self.clock.time_msec()
joined = self.encode_joined(
sync_result.joined, filter, time_now, token_id
)
invited = self.encode_invited(
sync_result.invited, filter, time_now, token_id
)
archived = self.encode_archived(
sync_result.archived, filter, time_now, token_id
)
response_content = {
"presence": self.encode_presence(
sync_result.presence, filter, time_now
"public_user_data": self.encode_user_data(
sync_result.public_user_data, filter, time_now
),
"private_user_data": self.encode_user_data(
sync_result.private_user_data, filter, time_now
),
"rooms": self.encode_rooms(
sync_result.rooms, filter, time_now, token_id
),
"rooms": {
"joined": joined,
"invited": invited,
"archived": archived,
},
"next_batch": sync_result.next_batch.to_string(),
}
defer.returnValue((200, response_content))
def encode_presence(self, events, filter, time_now):
formatted = []
for event in events:
event = copy.deepcopy(event)
event['sender'] = event['content'].pop('user_id')
formatted.append(event)
return {"events": filter.filter_presence(formatted)}
def encode_user_data(self, events, filter, time_now):
return events
def encode_joined(self, rooms, filter, time_now, token_id):
"""
Encode the joined rooms in a sync result
:param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync
results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
:return: the joined rooms list, in our response format
:rtype: dict[str, dict[str, object]]
"""
joined = {}
for room in rooms:
joined[room.room_id] = self.encode_room(
room, filter, time_now, token_id
)
return joined
def encode_invited(self, rooms, filter, time_now, token_id):
"""
Encode the invited rooms in a sync result
:param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of
sync results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
:return: the invited rooms list, in our response format
:rtype: dict[str, dict[str, object]]
"""
invited = {}
for room in rooms:
invite = serialize_event(
room.invite, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_event_id,
)
invited_state = invite.get("unsigned", {}).pop("invite_room_state", [])
invited_state.append(invite)
invited[room.room_id] = {
"invite_state": {"events": invited_state}
}
return invited
def encode_archived(self, rooms, filter, time_now, token_id):
"""
Encode the archived rooms in a sync result
:param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of
sync results for rooms this user is joined to
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
:return: the invited rooms list, in our response format
:rtype: dict[str, dict[str, object]]
"""
joined = {}
for room in rooms:
joined[room.room_id] = self.encode_room(
room, filter, time_now, token_id, joined=False
)
return joined
def encode_rooms(self, rooms, filter, time_now, token_id):
return [
self.encode_room(room, filter, time_now, token_id)
for room in rooms
]
@staticmethod
def encode_room(room, filter, time_now, token_id, joined=True):
"""
:param JoinedSyncResult|ArchivedSyncResult room: sync result for a
single room
:param FilterCollection filter: filters to apply to the results
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
:param joined: True if the user is joined to this room - will mean
we handle ephemeral events
:return: the room, encoded in our response format
:rtype: dict[str, object]
"""
def encode_room(room, filter, time_now, token_id):
event_map = {}
state_dict = room.state
timeline_events = filter.filter_room_timeline(room.timeline.events)
state_dict = SyncRestServlet._rollback_state_for_timeline(
state_dict, timeline_events)
state_events = filter.filter_room_state(state_dict.values())
state_events = filter.filter_room_state(room.state)
recent_events = filter.filter_room_events(room.events)
state_event_ids = []
recent_event_ids = []
for event in state_events:
# TODO(mjark): Respect formatting requirements in the filter.
event_map[event.event_id] = serialize_event(
@@ -273,91 +182,25 @@ class SyncRestServlet(RestServlet):
)
state_event_ids.append(event.event_id)
timeline_event_ids = []
for event in timeline_events:
for event in recent_events:
# TODO(mjark): Respect formatting requirements in the filter.
event_map[event.event_id] = serialize_event(
event, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_event_id,
)
timeline_event_ids.append(event.event_id)
private_user_data = filter.filter_room_private_user_data(
room.private_user_data
)
recent_event_ids.append(event.event_id)
result = {
"room_id": room.room_id,
"event_map": event_map,
"timeline": {
"events": timeline_event_ids,
"prev_batch": room.timeline.prev_batch.to_string(),
"limited": room.timeline.limited,
"events": {
"batch": recent_event_ids,
"prev_batch": room.prev_batch.to_string(),
},
"state": {"events": state_event_ids},
"private_user_data": {"events": private_user_data},
"state": state_event_ids,
"limited": room.limited,
"published": room.published,
"ephemeral": room.ephemeral,
}
if joined:
ephemeral_events = filter.filter_room_ephemeral(room.ephemeral)
result["ephemeral"] = {"events": ephemeral_events}
return result
@staticmethod
def _rollback_state_for_timeline(state, timeline):
"""
Wind the state dictionary backwards, so that it represents the
state at the start of the timeline, rather than at the end.
:param dict[(str, str), synapse.events.EventBase] state: the
state dictionary. Will be updated to the state before the timeline.
:param list[synapse.events.EventBase] timeline: the event timeline
:return: updated state dictionary
"""
logger.debug("Processing state dict %r; timeline %r", state,
[e.get_dict() for e in timeline])
result = state.copy()
for timeline_event in reversed(timeline):
if not timeline_event.is_state():
continue
event_key = (timeline_event.type, timeline_event.state_key)
logger.debug("Considering %s for removal", event_key)
state_event = result.get(event_key)
if (state_event is None or
state_event.event_id != timeline_event.event_id):
# the event in the timeline isn't present in the state
# dictionary.
#
# the most likely cause for this is that there was a fork in
# the event graph, and the state is no longer valid. Really,
# the event shouldn't be in the timeline. We're going to ignore
# it for now, however.
logger.warn("Found state event %r in timeline which doesn't "
"match state dictionary", timeline_event)
continue
prev_event_id = timeline_event.unsigned.get("replaces_state", None)
logger.debug("Replacing %s with %s in state dict",
timeline_event.event_id, prev_event_id)
if prev_event_id is None:
del result[event_key]
else:
result[event_key] = FrozenEvent({
"type": timeline_event.type,
"state_key": timeline_event.state_key,
"content": timeline_event.unsigned['prev_content'],
"sender": timeline_event.unsigned['prev_sender'],
"event_id": prev_event_id,
"room_id": timeline_event.room_id,
})
logger.debug("New value: %r", result.get(event_key))
return result
-106
View File
@@ -1,106 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import client_v2_pattern
from synapse.http.servlet import RestServlet
from synapse.api.errors import AuthError, SynapseError
from twisted.internet import defer
import logging
import simplejson as json
logger = logging.getLogger(__name__)
class TagListServlet(RestServlet):
"""
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
"""
PATTERN = client_v2_pattern(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
)
def __init__(self, hs):
super(TagListServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_GET(self, request, user_id, room_id):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot get tags for other users.")
tags = yield self.store.get_tags_for_room(user_id, room_id)
defer.returnValue((200, {"tags": tags}))
class TagServlet(RestServlet):
"""
PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
"""
PATTERN = client_v2_pattern(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
)
def __init__(self, hs):
super(TagServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def on_PUT(self, request, user_id, room_id, tag):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
try:
content_bytes = request.content.read()
body = json.loads(content_bytes)
except:
raise SynapseError(400, "Invalid tag JSON")
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
yield self.notifier.on_new_event(
"private_user_data_key", max_id, users=[user_id]
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_DELETE(self, request, user_id, room_id, tag):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
yield self.notifier.on_new_event(
"private_user_data_key", max_id, users=[user_id]
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
TagListServlet(hs).register(http_server)
TagServlet(hs).register(http_server)
+1 -1
View File
@@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource):
@defer.inlineCallbacks
def map_request_to_name(self, request):
# auth the user
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
# namespace all file uploads on the user
prefix = base64.urlsafe_b64encode(
+1 -1
View File
@@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource):
@request_handler
@defer.inlineCallbacks
def _async_render_POST(self, request):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user, _ = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length")
+5
View File
@@ -29,6 +29,7 @@ from synapse.state import StateHandler
from synapse.storage import DataStore
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.lockutils import LockManager
from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring
@@ -69,6 +70,7 @@ class BaseHomeServer(object):
'auth',
'rest_servlet_factory',
'state_handler',
'room_lock_manager',
'notifier',
'distributor',
'resource_for_client',
@@ -199,6 +201,9 @@ class HomeServer(BaseHomeServer):
def build_state_handler(self):
return StateHandler(self)
def build_room_lock_manager(self):
return LockManager()
def build_distributor(self):
return Distributor()
+9 -23
View File
@@ -71,7 +71,7 @@ class StateHandler(object):
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
""" Retrieves the current state for the room. This is done by
""" Returns the current state for the room as a list. This is done by
calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts.
@@ -80,8 +80,6 @@ class StateHandler(object):
If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`.
:returns map from (type, state_key) to event
"""
event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
@@ -179,10 +177,9 @@ class StateHandler(object):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
:returns a Deferred tuple of (`state_group`, `state`, `prev_state`).
`state_group` is the name of a state group if one and only one is
involved. `state` is a map from (type, state_key) to event, and
`prev_state` is a list of event ids.
Return format is a tuple: (`state_group`, `state_events`), where the
first is the name of a state group if one and only one is involved,
otherwise `None`.
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
@@ -258,11 +255,6 @@ class StateHandler(object):
return self._resolve_events(state_sets)
def _resolve_events(self, state_sets, event_type=None, state_key=""):
"""
:returns a tuple (new_state, prev_states). new_state is a map
from (type, state_key) to event. prev_states is a list of event_ids.
:rtype: (dict[(str, str), synapse.events.FrozenEvent], list[str])
"""
state = {}
for st in state_sets:
for e in st:
@@ -315,23 +307,19 @@ class StateHandler(object):
We resolve conflicts in the following order:
1. power levels
2. join rules
3. memberships
4. other events.
2. memberships
3. other events.
"""
resolved_state = {}
power_key = (EventTypes.PowerLevels, "")
if power_key in conflicted_state:
events = conflicted_state[power_key]
logger.debug("Resolving conflicted power levels %r", events)
resolved_state[power_key] = self._resolve_auth_events(
events, auth_events)
if power_key in conflicted_state.items():
power_levels = conflicted_state[power_key]
resolved_state[power_key] = self._resolve_auth_events(power_levels)
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = self._resolve_auth_events(
events,
auth_events
@@ -341,7 +329,6 @@ class StateHandler(object):
for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = self._resolve_auth_events(
events,
auth_events
@@ -351,7 +338,6 @@ class StateHandler(object):
for key, events in conflicted_state.items():
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = self._resolve_normal_events(
events, auth_events
)
-50
View File
@@ -1,50 +0,0 @@
<html>
<head>
<title> Login </title>
<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="style.css">
<script src="js/jquery-2.1.3.min.js"></script>
<script src="js/login.js"></script>
</head>
<body onload="matrixLogin.onLoad()">
<center>
<br/>
<h1>Log in with one of the following methods</h1>
<span id="feedback" style="color: #f00"></span>
<br/>
<br/>
<div id="loading">
<img src="spinner.gif" />
</div>
<div id="cas_flow" class="login_flow" style="display:none"
onclick="gotoCas(); return false;">
CAS Authentication: <button id="cas_button" style="margin: 10px">Log in</button>
</div>
<br/>
<form id="password_form" class="login_flow" style="display:none"
onsubmit="matrixLogin.password_login(); return false;">
<div>
Password Authentication:<br/>
<div style="text-align: center">
<input id="user_id" size="32" type="text" placeholder="Matrix ID (e.g. bob)" autocapitalize="off" autocorrect="off" />
<br/>
<input id="password" size="32" type="password" placeholder="Password"/>
<br/>
<button type="submit" style="margin: 10px">Log in</button>
</div>
</div>
</form>
<div id="no_login_types" type="button" class="login_flow" style="display:none">
Log in currently unavailable.
</div>
</center>
</body>
</html>
-153
View File
@@ -1,153 +0,0 @@
window.matrixLogin = {
endpoint: location.origin + "/_matrix/client/api/v1/login",
serverAcceptsPassword: false,
serverAcceptsCas: false
};
var submitPassword = function(user, pwd) {
console.log("Logging in with password...");
var data = {
type: "m.login.password",
user: user,
password: pwd,
};
$.post(matrixLogin.endpoint, JSON.stringify(data), function(response) {
show_login();
matrixLogin.onLogin(response);
}).error(errorFunc);
};
var submitToken = function(loginToken) {
console.log("Logging in with login token...");
var data = {
type: "m.login.token",
token: loginToken
};
$.post(matrixLogin.endpoint, JSON.stringify(data), function(response) {
show_login();
matrixLogin.onLogin(response);
}).error(errorFunc);
};
var errorFunc = function(err) {
show_login();
if (err.responseJSON && err.responseJSON.error) {
setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")");
}
else {
setFeedbackString("Request failed: " + err.status);
}
};
var gotoCas = function() {
var this_page = window.location.origin + window.location.pathname;
var redirect_url = matrixLogin.endpoint + "/cas/redirect?redirectUrl=" + encodeURIComponent(this_page);
window.location.replace(redirect_url);
}
var setFeedbackString = function(text) {
$("#feedback").text(text);
};
var show_login = function() {
$("#loading").hide();
if (matrixLogin.serverAcceptsPassword) {
$("#password_form").show();
}
if (matrixLogin.serverAcceptsCas) {
$("#cas_flow").show();
}
if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsCas) {
$("#no_login_types").show();
}
};
var show_spinner = function() {
$("#password_form").hide();
$("#cas_flow").hide();
$("#no_login_types").hide();
$("#loading").show();
};
var fetch_info = function(cb) {
$.get(matrixLogin.endpoint, function(response) {
var serverAcceptsPassword = false;
var serverAcceptsCas = false;
for (var i=0; i<response.flows.length; i++) {
var flow = response.flows[i];
if ("m.login.cas" === flow.type) {
matrixLogin.serverAcceptsCas = true;
console.log("Server accepts CAS");
}
if ("m.login.password" === flow.type) {
matrixLogin.serverAcceptsPassword = true;
console.log("Server accepts password");
}
}
cb();
}).error(errorFunc);
}
matrixLogin.onLoad = function() {
fetch_info(function() {
if (!try_token()) {
show_login();
}
});
};
matrixLogin.password_login = function() {
var user = $("#user_id").val();
var pwd = $("#password").val();
setFeedbackString("");
show_spinner();
submitPassword(user, pwd);
};
matrixLogin.onLogin = function(response) {
// clobber this function
console.log("onLogin - This function should be replaced to proceed.");
console.log(response);
};
var parseQsFromUrl = function(query) {
var result = {};
query.split("&").forEach(function(part) {
var item = part.split("=");
var key = item[0];
var val = item[1];
if (val) {
val = decodeURIComponent(val);
}
result[key] = val
});
return result;
};
var try_token = function() {
var pos = window.location.href.indexOf("?");
if (pos == -1) {
return false;
}
var qs = parseQsFromUrl(window.location.href.substr(pos+1));
var loginToken = qs.loginToken;
if (!loginToken) {
return false;
}
submitToken(loginToken);
return true;
};
Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.8 KiB

-57
View File
@@ -1,57 +0,0 @@
html {
height: 100%;
}
body {
height: 100%;
font-family: "Myriad Pro", "Myriad", Helvetica, Arial, sans-serif;
font-size: 12pt;
margin: 0px;
}
h1 {
font-size: 20pt;
}
a:link { color: #666; }
a:visited { color: #666; }
a:hover { color: #000; }
a:active { color: #000; }
input {
width: 90%
}
textarea, input {
font-family: inherit;
font-size: inherit;
margin: 5px;
}
.smallPrint {
color: #888;
font-size: 9pt ! important;
font-style: italic ! important;
}
.g-recaptcha div {
margin: auto;
}
.login_flow {
text-align: left;
padding: 10px;
margin-bottom: 40px;
display: inline-block;
-webkit-border-radius: 10px;
-moz-border-radius: 10px;
border-radius: 10px;
-webkit-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
-moz-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
background-color: #f8f8f8;
border: 1px #ccc solid;
}
File diff suppressed because one or more lines are too long
+375 -4
View File
@@ -40,16 +40,24 @@ from .filtering import FilteringStore
from .end_to_end_keys import EndToEndKeyStore
from .receipts import ReceiptsStore
from .search import SearchStore
from .tags import TagsStore
import fnmatch
import imp
import logging
import os
import re
logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 24
dir_path = os.path.abspath(os.path.dirname(__file__))
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
# times give more inserts into the database even for readonly API hits
# 120 seconds == 2 minutes
@@ -71,8 +79,6 @@ class DataStore(RoomMemberStore, RoomStore,
EventsStore,
ReceiptsStore,
EndToEndKeyStore,
SearchStore,
TagsStore,
):
def __init__(self, hs):
@@ -152,6 +158,371 @@ class DataStore(RoomMemberStore, RoomStore,
)
def read_schema(path):
""" Read the named database schema.
Args:
path: Path of the database schema.
Returns:
A string containing the database schema.
"""
with open(path) as schema_file:
return schema_file.read()
class PrepareDatabaseException(Exception):
pass
class UpgradeDatabaseException(PrepareDatabaseException):
pass
def prepare_database(db_conn, database_engine):
"""Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version.
"""
try:
cur = db_conn.cursor()
version_info = _get_or_create_schema_state(cur, database_engine)
if version_info:
user_version, delta_files, upgraded = version_info
_upgrade_existing_database(
cur, user_version, delta_files, upgraded, database_engine
)
else:
_setup_new_database(cur, database_engine)
# cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
cur.close()
db_conn.commit()
except:
db_conn.rollback()
raise
def _setup_new_database(cur, database_engine):
"""Sets up the database by finding a base set of "full schemas" and then
applying any necessary deltas.
The "full_schemas" directory has subdirectories named after versions. This
function searches for the highest version less than or equal to
`SCHEMA_VERSION` and executes all .sql files in that directory.
The function will then apply all deltas for all versions after the base
version.
Example directory structure:
schema/
delta/
...
full_schemas/
3/
test.sql
...
11/
foo.sql
bar.sql
...
In the example foo.sql and bar.sql would be run, and then any delta files
for versions strictly greater than 11.
"""
current_dir = os.path.join(dir_path, "schema", "full_schemas")
directory_entries = os.listdir(current_dir)
valid_dirs = []
pattern = re.compile(r"^\d+(\.sql)?$")
for filename in directory_entries:
match = pattern.match(filename)
abs_path = os.path.join(current_dir, filename)
if match and os.path.isdir(abs_path):
ver = int(match.group(0))
if ver <= SCHEMA_VERSION:
valid_dirs.append((ver, abs_path))
else:
logger.warn("Unexpected entry in 'full_schemas': %s", filename)
if not valid_dirs:
raise PrepareDatabaseException(
"Could not find a suitable base set of full schemas"
)
max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0])
logger.debug("Initialising schema v%d", max_current_ver)
directory_entries = os.listdir(sql_dir)
for filename in fnmatch.filter(directory_entries, "*.sql"):
sql_loc = os.path.join(sql_dir, filename)
logger.debug("Applying schema %s", sql_loc)
executescript(cur, sql_loc)
cur.execute(
database_engine.convert_param_style(
"INSERT INTO schema_version (version, upgraded)"
" VALUES (?,?)"
),
(max_current_ver, False,)
)
_upgrade_existing_database(
cur,
current_version=max_current_ver,
applied_delta_files=[],
upgraded=False,
database_engine=database_engine,
)
def _upgrade_existing_database(cur, current_version, applied_delta_files,
upgraded, database_engine):
"""Upgrades an existing database.
Delta files can either be SQL stored in *.sql files, or python modules
in *.py.
There can be multiple delta files per version. Synapse will keep track of
which delta files have been applied, and will apply any that haven't been
even if there has been no version bump. This is useful for development
where orthogonal schema changes may happen on separate branches.
Different delta files for the same version *must* be orthogonal and give
the same result when applied in any order. No guarantees are made on the
order of execution of these scripts.
This is a no-op of current_version == SCHEMA_VERSION.
Example directory structure:
schema/
delta/
11/
foo.sql
...
12/
foo.sql
bar.py
...
full_schemas/
...
In the example, if current_version is 11, then foo.sql will be run if and
only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in
some arbitrary order.
Args:
cur (Cursor)
current_version (int): The current version of the schema.
applied_delta_files (list): A list of deltas that have already been
applied.
upgraded (bool): Whether the current version was generated by having
applied deltas or from full schema file. If `True` the function
will never apply delta files for the given `current_version`, since
the current_version wasn't generated by applying those delta files.
"""
if current_version > SCHEMA_VERSION:
raise ValueError(
"Cannot use this database as it is too " +
"new for the server to understand"
)
start_ver = current_version
if not upgraded:
start_ver += 1
logger.debug("applied_delta_files: %s", applied_delta_files)
for v in range(start_ver, SCHEMA_VERSION + 1):
logger.debug("Upgrading schema to v%d", v)
delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
try:
directory_entries = os.listdir(delta_dir)
except OSError:
logger.exception("Could not open delta dir for version %d", v)
raise UpgradeDatabaseException(
"Could not open delta dir for version %d" % (v,)
)
directory_entries.sort()
for file_name in directory_entries:
relative_path = os.path.join(str(v), file_name)
logger.debug("Found file: %s", relative_path)
if relative_path in applied_delta_files:
continue
absolute_path = os.path.join(
dir_path, "schema", "delta", relative_path,
)
root_name, ext = os.path.splitext(file_name)
if ext == ".py":
# This is a python upgrade module. We need to import into some
# package and then execute its `run_upgrade` function.
module_name = "synapse.storage.v%d_%s" % (
v, root_name
)
with open(absolute_path) as python_file:
module = imp.load_source(
module_name, absolute_path, python_file
)
logger.debug("Running script %s", relative_path)
module.run_upgrade(cur, database_engine)
elif ext == ".pyc":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
# installers. Silently skip it
pass
elif ext == ".sql":
# A plain old .sql file, just read and execute it
logger.debug("Applying schema %s", relative_path)
executescript(cur, absolute_path)
else:
# Not a valid delta file.
logger.warn(
"Found directory entry that did not end in .py or"
" .sql: %s",
relative_path,
)
continue
# Mark as done.
cur.execute(
database_engine.convert_param_style(
"INSERT INTO applied_schema_deltas (version, file)"
" VALUES (?,?)",
),
(v, relative_path)
)
cur.execute("DELETE FROM schema_version")
cur.execute(
database_engine.convert_param_style(
"INSERT INTO schema_version (version, upgraded)"
" VALUES (?,?)",
),
(v, True)
)
def get_statements(f):
statement_buffer = ""
in_comment = False # If we're in a /* ... */ style comment
for line in f:
line = line.strip()
if in_comment:
# Check if this line contains an end to the comment
comments = line.split("*/", 1)
if len(comments) == 1:
continue
line = comments[1]
in_comment = False
# Remove inline block comments
line = re.sub(r"/\*.*\*/", " ", line)
# Does this line start a comment?
comments = line.split("/*", 1)
if len(comments) > 1:
line = comments[0]
in_comment = True
# Deal with line comments
line = line.split("--", 1)[0]
line = line.split("//", 1)[0]
# Find *all* semicolons. We need to treat first and last entry
# specially.
statements = line.split(";")
# We must prepend statement_buffer to the first statement
first_statement = "%s %s" % (
statement_buffer.strip(),
statements[0].strip()
)
statements[0] = first_statement
# Every entry, except the last, is a full statement
for statement in statements[:-1]:
yield statement.strip()
# The last entry did *not* end in a semicolon, so we store it for the
# next semicolon we find
statement_buffer = statements[-1].strip()
def executescript(txn, schema_path):
with open(schema_path, 'r') as f:
for statement in get_statements(f):
txn.execute(statement)
def _get_or_create_schema_state(txn, database_engine):
# Bluntly try creating the schema_version tables.
schema_path = os.path.join(
dir_path, "schema", "schema_version.sql",
)
executescript(txn, schema_path)
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
current_version = int(row[0]) if row else None
upgraded = bool(row[1]) if row else None
if current_version:
txn.execute(
database_engine.convert_param_style(
"SELECT file FROM applied_schema_deltas WHERE version >= ?"
),
(current_version,)
)
applied_deltas = [d for d, in txn.fetchall()]
return current_version, applied_deltas, upgraded
return None
def prepare_sqlite3_database(db_conn):
"""This function should be called before `prepare_database` on sqlite3
databases.
Since we changed the way we store the current schema version and handle
updates to schemas, we need a way to upgrade from the old method to the
new. This only affects sqlite databases since they were the only ones
supported at the time.
"""
with db_conn:
schema_path = os.path.join(
dir_path, "schema", "schema_version.sql",
)
create_schema = read_schema(schema_path)
db_conn.executescript(create_schema)
c = db_conn.execute("SELECT * FROM schema_version")
rows = c.fetchall()
c.close()
if not rows:
c = db_conn.execute("PRAGMA user_version")
row = c.fetchone()
c.close()
if row and row[0]:
db_conn.execute(
"REPLACE INTO schema_version (version, upgraded)"
" VALUES (?,?)",
(row[0], False)
)
def are_all_users_on_domain(txn, database_engine, domain):
sql = database_engine.convert_param_style(
"SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
+1 -1
View File
@@ -519,7 +519,7 @@ class SQLBaseStore(object):
allow_none=False,
desc="_simple_select_one_onecol"):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
return a single row, returning a single column from it."
Args:
table : string giving the table name
-256
View File
@@ -1,256 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from twisted.internet import defer
import ujson as json
import logging
logger = logging.getLogger(__name__)
class BackgroundUpdatePerformance(object):
"""Tracks the how long a background update is taking to update its items"""
def __init__(self, name):
self.name = name
self.total_item_count = 0
self.total_duration_ms = 0
self.avg_item_count = 0
self.avg_duration_ms = 0
def update(self, item_count, duration_ms):
"""Update the stats after doing an update"""
self.total_item_count += item_count
self.total_duration_ms += duration_ms
# Exponential moving averages for the number of items updated and
# the duration.
self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)
def average_items_per_ms(self):
"""An estimate of how long it takes to do a single update.
Returns:
A duration in ms as a float
"""
if self.total_item_count == 0:
return None
else:
# Use the exponential moving average so that we can adapt to
# changes in how long the update process takes.
return float(self.avg_item_count) / float(self.avg_duration_ms)
def total_items_per_ms(self):
"""An estimate of how long it takes to do a single update.
Returns:
A duration in ms as a float
"""
if self.total_item_count == 0:
return None
else:
return float(self.total_item_count) / float(self.total_duration_ms)
class BackgroundUpdateStore(SQLBaseStore):
""" Background updates are updates to the database that run in the
background. Each update processes a batch of data at once. We attempt to
limit the impact of each update by monitoring how long each batch takes to
process and autotuning the batch size.
"""
MINIMUM_BACKGROUND_BATCH_SIZE = 100
DEFAULT_BACKGROUND_BATCH_SIZE = 100
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, hs):
super(BackgroundUpdateStore, self).__init__(hs)
self._background_update_performance = {}
self._background_update_queue = []
self._background_update_handlers = {}
self._background_update_timer = None
@defer.inlineCallbacks
def start_doing_background_updates(self):
while True:
if self._background_update_timer is not None:
return
sleep = defer.Deferred()
self._background_update_timer = self._clock.call_later(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
)
try:
yield sleep
finally:
self._background_update_timer = None
try:
result = yield self.do_background_update(
self.BACKGROUND_UPDATE_DURATION_MS
)
except:
logger.exception("Error doing update")
if result is None:
logger.info(
"No more background updates to do."
" Unscheduling background update task."
)
return
@defer.inlineCallbacks
def do_background_update(self, desired_duration_ms):
"""Does some amount of work on a background update
Args:
desired_duration_ms(float): How long we want to spend
updating.
Returns:
A deferred that completes once some amount of work is done.
The deferred will have a value of None if there is currently
no more work to do.
"""
if not self._background_update_queue:
updates = yield self._simple_select_list(
"background_updates",
keyvalues=None,
retcols=("update_name",),
)
for update in updates:
self._background_update_queue.append(update['update_name'])
if not self._background_update_queue:
defer.returnValue(None)
update_name = self._background_update_queue.pop(0)
self._background_update_queue.append(update_name)
update_handler = self._background_update_handlers[update_name]
performance = self._background_update_performance.get(update_name)
if performance is None:
performance = BackgroundUpdatePerformance(update_name)
self._background_update_performance[update_name] = performance
items_per_ms = performance.average_items_per_ms()
if items_per_ms is not None:
batch_size = int(desired_duration_ms * items_per_ms)
# Clamp the batch size so that we always make progress
batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
else:
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
progress_json = yield self._simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="progress_json"
)
progress = json.loads(progress_json)
time_start = self._clock.time_msec()
items_updated = yield update_handler(progress, batch_size)
time_stop = self._clock.time_msec()
duration_ms = time_stop - time_start
logger.info(
"Updating %r. Updated %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r)",
update_name, items_updated, duration_ms,
performance.total_items_per_ms(),
performance.average_items_per_ms(),
performance.total_item_count,
)
performance.update(items_updated, duration_ms)
defer.returnValue(len(self._background_update_performance))
def register_background_update_handler(self, update_name, update_handler):
"""Register a handler for doing a background update.
The handler should take two arguments:
* A dict of the current progress
* An integer count of the number of items to update in this batch.
The handler should return a deferred integer count of items updated.
The hander is responsible for updating the progress of the update.
Args:
update_name(str): The name of the update that this code handles.
update_handler(function): The function that does the update.
"""
self._background_update_handlers[update_name] = update_handler
def start_background_update(self, update_name, progress):
"""Starts a background update running.
Args:
update_name: The update to set running.
progress: The initial state of the progress of the update.
Returns:
A deferred that completes once the task has been added to the
queue.
"""
# Clear the background update queue so that we will pick up the new
# task on the next iteration of do_background_update.
self._background_update_queue = []
progress_json = json.dumps(progress)
return self._simple_insert(
"background_updates",
{"update_name": update_name, "progress_json": progress_json}
)
def _end_background_update(self, update_name):
"""Removes a completed background update task from the queue.
Args:
update_name(str): The name of the completed task to remove
Returns:
A deferred that completes once the task is removed.
"""
self._background_update_queue = [
name for name in self._background_update_queue if name != update_name
]
return self._simple_delete_one(
"background_updates", keyvalues={"update_name": update_name}
)
def _background_update_progress_txn(self, txn, update_name, progress):
"""Update the progress of a background update
Args:
txn(cursor): The transaction.
update_name(str): The name of the background update task
progress(dict): The progress of the update.
"""
progress_json = json.dumps(progress)
self._simple_update_one_txn(
txn,
"background_updates",
keyvalues={"update_name": update_name},
updatevalues={"progress_json": progress_json},
)
+1 -1
View File
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.prepare_database import prepare_database
from synapse.storage import prepare_database
from ._base import IncorrectDatabaseSetup
+1 -30
View File
@@ -13,11 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.prepare_database import (
prepare_database, prepare_sqlite3_database
)
import struct
from synapse.storage import prepare_database, prepare_sqlite3_database
class Sqlite3Engine(object):
@@ -34,7 +30,6 @@ class Sqlite3Engine(object):
def on_new_connection(self, db_conn):
self.prepare_database(db_conn)
db_conn.create_function("rank", 1, _rank)
def prepare_database(self, db_conn):
prepare_sqlite3_database(db_conn)
@@ -48,27 +43,3 @@ class Sqlite3Engine(object):
def lock_table(self, txn, table):
return
# Following functions taken from: https://github.com/coleifer/peewee
def _parse_match_info(buf):
bufsize = len(buf)
return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)]
def _rank(raw_match_info):
"""Handle match_info called w/default args 'pcx' - based on the example rank
function http://sqlite.org/fts3.html#appendix_a
"""
match_info = _parse_match_info(raw_match_info)
score = 0.0
p, c = match_info[:2]
for phrase_num in range(p):
phrase_info_idx = 2 + (phrase_num * c * 3)
for col_num in range(c):
col_idx = phrase_info_idx + (col_num * 3)
x1, x2 = match_info[col_idx:col_idx + 2]
if x1 > 0:
score += float(x1) / x2
return score
+2 -10
View File
@@ -307,14 +307,8 @@ class EventsStore(SQLBaseStore):
self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic:
self._store_room_topic_txn(txn, event)
elif event.type == EventTypes.Message:
self._store_room_message_txn(txn, event)
elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event)
elif event.type == EventTypes.RoomHistoryVisibility:
self._store_history_visibility_txn(txn, event)
elif event.type == EventTypes.GuestAccess:
self._store_guest_access_txn(txn, event)
self._store_room_members_txn(
txn,
@@ -831,8 +825,7 @@ class EventsStore(SQLBaseStore):
allow_none=True,
)
if prev:
ev.unsigned["prev_content"] = prev.content
ev.unsigned["prev_sender"] = prev.sender
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
(ev.event_id, check_redacted, get_prev_content), ev
@@ -889,8 +882,7 @@ class EventsStore(SQLBaseStore):
get_prev_content=False,
)
if prev:
ev.unsigned["prev_content"] = prev.content
ev.unsigned["prev_sender"] = prev.sender
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
(ev.event_id, check_redacted, get_prev_content), ev
+2 -2
View File
@@ -34,10 +34,10 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter",
)
defer.returnValue(json.loads(str(def_json).decode("utf-8")))
defer.returnValue(json.loads(def_json))
def add_user_filter(self, user_localpart, user_filter):
def_json = json.dumps(user_filter).encode("utf-8")
def_json = json.dumps(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
# INSERT a new one
-395
View File
@@ -1,395 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fnmatch
import imp
import logging
import os
import re
logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 25
dir_path = os.path.abspath(os.path.dirname(__file__))
def read_schema(path):
""" Read the named database schema.
Args:
path: Path of the database schema.
Returns:
A string containing the database schema.
"""
with open(path) as schema_file:
return schema_file.read()
class PrepareDatabaseException(Exception):
pass
class UpgradeDatabaseException(PrepareDatabaseException):
pass
def prepare_database(db_conn, database_engine):
"""Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version.
"""
try:
cur = db_conn.cursor()
version_info = _get_or_create_schema_state(cur, database_engine)
if version_info:
user_version, delta_files, upgraded = version_info
_upgrade_existing_database(
cur, user_version, delta_files, upgraded, database_engine
)
else:
_setup_new_database(cur, database_engine)
# cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
cur.close()
db_conn.commit()
except:
db_conn.rollback()
raise
def _setup_new_database(cur, database_engine):
"""Sets up the database by finding a base set of "full schemas" and then
applying any necessary deltas.
The "full_schemas" directory has subdirectories named after versions. This
function searches for the highest version less than or equal to
`SCHEMA_VERSION` and executes all .sql files in that directory.
The function will then apply all deltas for all versions after the base
version.
Example directory structure:
schema/
delta/
...
full_schemas/
3/
test.sql
...
11/
foo.sql
bar.sql
...
In the example foo.sql and bar.sql would be run, and then any delta files
for versions strictly greater than 11.
"""
current_dir = os.path.join(dir_path, "schema", "full_schemas")
directory_entries = os.listdir(current_dir)
valid_dirs = []
pattern = re.compile(r"^\d+(\.sql)?$")
for filename in directory_entries:
match = pattern.match(filename)
abs_path = os.path.join(current_dir, filename)
if match and os.path.isdir(abs_path):
ver = int(match.group(0))
if ver <= SCHEMA_VERSION:
valid_dirs.append((ver, abs_path))
else:
logger.warn("Unexpected entry in 'full_schemas': %s", filename)
if not valid_dirs:
raise PrepareDatabaseException(
"Could not find a suitable base set of full schemas"
)
max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0])
logger.debug("Initialising schema v%d", max_current_ver)
directory_entries = os.listdir(sql_dir)
for filename in fnmatch.filter(directory_entries, "*.sql"):
sql_loc = os.path.join(sql_dir, filename)
logger.debug("Applying schema %s", sql_loc)
executescript(cur, sql_loc)
cur.execute(
database_engine.convert_param_style(
"INSERT INTO schema_version (version, upgraded)"
" VALUES (?,?)"
),
(max_current_ver, False,)
)
_upgrade_existing_database(
cur,
current_version=max_current_ver,
applied_delta_files=[],
upgraded=False,
database_engine=database_engine,
)
def _upgrade_existing_database(cur, current_version, applied_delta_files,
upgraded, database_engine):
"""Upgrades an existing database.
Delta files can either be SQL stored in *.sql files, or python modules
in *.py.
There can be multiple delta files per version. Synapse will keep track of
which delta files have been applied, and will apply any that haven't been
even if there has been no version bump. This is useful for development
where orthogonal schema changes may happen on separate branches.
Different delta files for the same version *must* be orthogonal and give
the same result when applied in any order. No guarantees are made on the
order of execution of these scripts.
This is a no-op of current_version == SCHEMA_VERSION.
Example directory structure:
schema/
delta/
11/
foo.sql
...
12/
foo.sql
bar.py
...
full_schemas/
...
In the example, if current_version is 11, then foo.sql will be run if and
only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in
some arbitrary order.
Args:
cur (Cursor)
current_version (int): The current version of the schema.
applied_delta_files (list): A list of deltas that have already been
applied.
upgraded (bool): Whether the current version was generated by having
applied deltas or from full schema file. If `True` the function
will never apply delta files for the given `current_version`, since
the current_version wasn't generated by applying those delta files.
"""
if current_version > SCHEMA_VERSION:
raise ValueError(
"Cannot use this database as it is too " +
"new for the server to understand"
)
start_ver = current_version
if not upgraded:
start_ver += 1
logger.debug("applied_delta_files: %s", applied_delta_files)
for v in range(start_ver, SCHEMA_VERSION + 1):
logger.debug("Upgrading schema to v%d", v)
delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
try:
directory_entries = os.listdir(delta_dir)
except OSError:
logger.exception("Could not open delta dir for version %d", v)
raise UpgradeDatabaseException(
"Could not open delta dir for version %d" % (v,)
)
directory_entries.sort()
for file_name in directory_entries:
relative_path = os.path.join(str(v), file_name)
logger.debug("Found file: %s", relative_path)
if relative_path in applied_delta_files:
continue
absolute_path = os.path.join(
dir_path, "schema", "delta", relative_path,
)
root_name, ext = os.path.splitext(file_name)
if ext == ".py":
# This is a python upgrade module. We need to import into some
# package and then execute its `run_upgrade` function.
module_name = "synapse.storage.v%d_%s" % (
v, root_name
)
with open(absolute_path) as python_file:
module = imp.load_source(
module_name, absolute_path, python_file
)
logger.debug("Running script %s", relative_path)
module.run_upgrade(cur, database_engine)
elif ext == ".pyc":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
# installers. Silently skip it
pass
elif ext == ".sql":
# A plain old .sql file, just read and execute it
logger.debug("Applying schema %s", relative_path)
executescript(cur, absolute_path)
else:
# Not a valid delta file.
logger.warn(
"Found directory entry that did not end in .py or"
" .sql: %s",
relative_path,
)
continue
# Mark as done.
cur.execute(
database_engine.convert_param_style(
"INSERT INTO applied_schema_deltas (version, file)"
" VALUES (?,?)",
),
(v, relative_path)
)
cur.execute("DELETE FROM schema_version")
cur.execute(
database_engine.convert_param_style(
"INSERT INTO schema_version (version, upgraded)"
" VALUES (?,?)",
),
(v, True)
)
def get_statements(f):
statement_buffer = ""
in_comment = False # If we're in a /* ... */ style comment
for line in f:
line = line.strip()
if in_comment:
# Check if this line contains an end to the comment
comments = line.split("*/", 1)
if len(comments) == 1:
continue
line = comments[1]
in_comment = False
# Remove inline block comments
line = re.sub(r"/\*.*\*/", " ", line)
# Does this line start a comment?
comments = line.split("/*", 1)
if len(comments) > 1:
line = comments[0]
in_comment = True
# Deal with line comments
line = line.split("--", 1)[0]
line = line.split("//", 1)[0]
# Find *all* semicolons. We need to treat first and last entry
# specially.
statements = line.split(";")
# We must prepend statement_buffer to the first statement
first_statement = "%s %s" % (
statement_buffer.strip(),
statements[0].strip()
)
statements[0] = first_statement
# Every entry, except the last, is a full statement
for statement in statements[:-1]:
yield statement.strip()
# The last entry did *not* end in a semicolon, so we store it for the
# next semicolon we find
statement_buffer = statements[-1].strip()
def executescript(txn, schema_path):
with open(schema_path, 'r') as f:
for statement in get_statements(f):
txn.execute(statement)
def _get_or_create_schema_state(txn, database_engine):
# Bluntly try creating the schema_version tables.
schema_path = os.path.join(
dir_path, "schema", "schema_version.sql",
)
executescript(txn, schema_path)
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
current_version = int(row[0]) if row else None
upgraded = bool(row[1]) if row else None
if current_version:
txn.execute(
database_engine.convert_param_style(
"SELECT file FROM applied_schema_deltas WHERE version >= ?"
),
(current_version,)
)
applied_deltas = [d for d, in txn.fetchall()]
return current_version, applied_deltas, upgraded
return None
def prepare_sqlite3_database(db_conn):
"""This function should be called before `prepare_database` on sqlite3
databases.
Since we changed the way we store the current schema version and handle
updates to schemas, we need a way to upgrade from the old method to the
new. This only affects sqlite databases since they were the only ones
supported at the time.
"""
with db_conn:
schema_path = os.path.join(
dir_path, "schema", "schema_version.sql",
)
create_schema = read_schema(schema_path)
db_conn.executescript(create_schema)
c = db_conn.execute("SELECT * FROM schema_version")
rows = c.fetchall()
c.close()
if not rows:
c = db_conn.execute("PRAGMA user_version")
row = c.fetchone()
c.close()
if row and row[0]:
db_conn.execute(
"REPLACE INTO schema_version (version, upgraded)"
" VALUES (?,?)",
(row[0], False)
)
+7 -8
View File
@@ -102,14 +102,13 @@ class RegistrationStore(SQLBaseStore):
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
if token:
# it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID
txn.execute(
"INSERT INTO access_tokens(id, user_id, token)"
" VALUES (?,?,?)",
(next_id, user_id, token,)
)
# it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID
txn.execute(
"INSERT INTO access_tokens(id, user_id, token)"
" VALUES (?,?,?)",
(next_id, user_id, token,)
)
def get_user_by_id(self, user_id):
return self._simple_select_one(
+24 -82
View File
@@ -19,7 +19,6 @@ from synapse.api.errors import StoreError
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from .engines import PostgresEngine, Sqlite3Engine
import collections
import logging
@@ -99,39 +98,34 @@ class RoomStore(SQLBaseStore):
"""
def f(txn):
def subquery(table_name, column_name=None):
column_name = column_name or table_name
return (
"SELECT %(table_name)s.event_id as event_id, "
"%(table_name)s.room_id as room_id, %(column_name)s "
"FROM %(table_name)s "
"INNER JOIN current_state_events as c "
"ON c.event_id = %(table_name)s.event_id " % {
"column_name": column_name,
"table_name": table_name,
}
)
topic_subquery = (
"SELECT topics.event_id as event_id, "
"topics.room_id as room_id, topic "
"FROM topics "
"INNER JOIN current_state_events as c "
"ON c.event_id = topics.event_id "
)
name_subquery = (
"SELECT room_names.event_id as event_id, "
"room_names.room_id as room_id, name "
"FROM room_names "
"INNER JOIN current_state_events as c "
"ON c.event_id = room_names.event_id "
)
# We use non printing ascii character US (\x1F) as a separator
sql = (
"SELECT"
" r.room_id,"
" max(n.name),"
" max(t.topic),"
" max(v.history_visibility),"
" max(g.guest_access)"
"SELECT r.room_id, max(n.name), max(t.topic)"
" FROM rooms AS r"
" LEFT JOIN (%(topic)s) AS t ON t.room_id = r.room_id"
" LEFT JOIN (%(name)s) AS n ON n.room_id = r.room_id"
" LEFT JOIN (%(history_visibility)s) AS v ON v.room_id = r.room_id"
" LEFT JOIN (%(guest_access)s) AS g ON g.room_id = r.room_id"
" WHERE r.is_public = ?"
" GROUP BY r.room_id" % {
"topic": subquery("topics", "topic"),
"name": subquery("room_names", "name"),
"history_visibility": subquery("history_visibility"),
"guest_access": subquery("guest_access"),
}
)
" GROUP BY r.room_id"
) % {
"topic": topic_subquery,
"name": name_subquery,
}
txn.execute(sql, (is_public,))
@@ -161,12 +155,10 @@ class RoomStore(SQLBaseStore):
"room_id": r[0],
"name": r[1],
"topic": r[2],
"world_readable": r[3] == "world_readable",
"guest_can_join": r[4] == "can_join",
"aliases": r[5],
"aliases": r[3],
}
for r in rows
if r[5] # We only return rooms that have at least one alias.
if r[3] # We only return rooms that have at least one alias.
]
defer.returnValue(ret)
@@ -183,10 +175,6 @@ class RoomStore(SQLBaseStore):
},
)
self._store_event_search_txn(
txn, event, "content.topic", event.content["topic"]
)
def _store_room_name_txn(self, txn, event):
if hasattr(event, "content") and "name" in event.content:
self._simple_insert_txn(
@@ -199,52 +187,6 @@ class RoomStore(SQLBaseStore):
}
)
self._store_event_search_txn(
txn, event, "content.name", event.content["name"]
)
def _store_room_message_txn(self, txn, event):
if hasattr(event, "content") and "body" in event.content:
self._store_event_search_txn(
txn, event, "content.body", event.content["body"]
)
def _store_history_visibility_txn(self, txn, event):
self._store_content_index_txn(txn, event, "history_visibility")
def _store_guest_access_txn(self, txn, event):
self._store_content_index_txn(txn, event, "guest_access")
def _store_content_index_txn(self, txn, event, key):
if hasattr(event, "content") and key in event.content:
sql = (
"INSERT INTO %(key)s"
" (event_id, room_id, %(key)s)"
" VALUES (?, ?, ?)" % {"key": key}
)
txn.execute(sql, (
event.event_id,
event.room_id,
event.content[key]
))
def _store_event_search_txn(self, txn, event, key, value):
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, vector)"
" VALUES (?,?,?,to_tsvector('english', ?))"
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
txn.execute(sql, (event.event_id, event.room_id, key, value,))
@cachedInlineCallbacks()
def get_room_name_and_aliases(self, room_id):
def f(txn):
-27
View File
@@ -110,33 +110,6 @@ class RoomMemberStore(SQLBaseStore):
membership=membership,
).addCallback(self._get_events)
def get_invites_for_user(self, user_id):
""" Get all the invite events for a user
Args:
user_id (str): The user ID.
Returns:
A deferred list of event objects.
"""
return self.get_rooms_for_user_where_membership_is(
user_id, [Membership.INVITE]
).addCallback(lambda invites: self._get_events([
invites.event_id for invite in invites
]))
def get_leave_and_ban_events_for_user(self, user_id):
""" Get all the leave events for a user
Args:
user_id (str): The user ID.
Returns:
A deferred list of event objects.
"""
return self.get_rooms_for_user_where_membership_is(
user_id, (Membership.LEAVE, Membership.BAN)
).addCallback(lambda leaves: self._get_events([
leave.event_id for leave in leaves
]))
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user
matches one in the membership list.
@@ -1,21 +0,0 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS background_updates(
update_name TEXT NOT NULL, -- The name of the background update.
progress_json TEXT NOT NULL, -- The current progress of the update as JSON.
CONSTRAINT background_updates_uniqueness UNIQUE (update_name)
);
-78
View File
@@ -1,78 +0,0 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.storage.prepare_database import get_statements
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import ujson
logger = logging.getLogger(__name__)
POSTGRES_TABLE = """
CREATE TABLE IF NOT EXISTS event_search (
event_id TEXT,
room_id TEXT,
sender TEXT,
key TEXT,
vector tsvector
);
CREATE INDEX event_search_fts_idx ON event_search USING gin(vector);
CREATE INDEX event_search_ev_idx ON event_search(event_id);
CREATE INDEX event_search_ev_ridx ON event_search(room_id);
"""
SQLITE_TABLE = (
"CREATE VIRTUAL TABLE IF NOT EXISTS event_search"
" USING fts4 ( event_id, room_id, sender, key, value )"
)
def run_upgrade(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
for statement in get_statements(POSTGRES_TABLE.splitlines()):
cur.execute(statement)
elif isinstance(database_engine, Sqlite3Engine):
cur.execute(SQLITE_TABLE)
else:
raise Exception("Unrecognized database engine")
cur.execute("SELECT MIN(stream_ordering) FROM events")
rows = cur.fetchall()
min_stream_id = rows[0][0]
cur.execute("SELECT MAX(stream_ordering) FROM events")
rows = cur.fetchall()
max_stream_id = rows[0][0]
if min_stream_id is not None and max_stream_id is not None:
progress = {
"target_min_stream_id_inclusive": min_stream_id,
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
progress_json = ujson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
" VALUES (?, ?)"
)
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_search", progress_json))
@@ -1,25 +0,0 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This is a manual index of guest_access content of state events,
* so that we can join on them in SELECT statements.
*/
CREATE TABLE IF NOT EXISTS guest_access(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
guest_access TEXT NOT NULL,
UNIQUE (event_id)
);
@@ -1,25 +0,0 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This is a manual index of history_visibility content of state events,
* so that we can join on them in SELECT statements.
*/
CREATE TABLE IF NOT EXISTS history_visibility(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
history_visibility TEXT NOT NULL,
UNIQUE (event_id)
);
-38
View File
@@ -1,38 +0,0 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS room_tags(
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
tag TEXT NOT NULL, -- The name of the tag.
content TEXT NOT NULL, -- The JSON content of the tag.
CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag)
);
CREATE TABLE IF NOT EXISTS room_tags_revisions (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
stream_id BIGINT NOT NULL, -- The current version of the room tags.
CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id)
);
CREATE TABLE IF NOT EXISTS private_user_data_max_stream_id(
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
stream_id BIGINT NOT NULL,
CHECK (Lock='X')
);
INSERT INTO private_user_data_max_stream_id (stream_id) VALUES (0);
-307
View File
@@ -1,307 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .background_updates import BackgroundUpdateStore
from synapse.api.errors import SynapseError
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging
logger = logging.getLogger(__name__)
class SearchStore(BackgroundUpdateStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
def __init__(self, hs):
super(SearchStore, self).__init__(hs)
self.register_background_update_handler(
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
)
@defer.inlineCallbacks
def _background_reindex_search(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
INSERT_CLUMP_SIZE = 1000
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
def reindex_search_txn(txn):
sql = (
"SELECT stream_ordering, event_id FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" AND (%s)"
" ORDER BY stream_ordering DESC"
" LIMIT ?"
) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),)
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
rows = txn.fetchall()
if not rows:
return 0
min_stream_id = rows[-1][0]
event_ids = [row[1] for row in rows]
events = self._get_events_txn(txn, event_ids)
event_search_rows = []
for event in events:
try:
event_id = event.event_id
room_id = event.room_id
content = event.content
if event.type == "m.room.message":
key = "content.body"
value = content["body"]
elif event.type == "m.room.topic":
key = "content.topic"
value = content["topic"]
elif event.type == "m.room.name":
key = "content.name"
value = content["name"]
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
continue
event_search_rows.append((event_id, room_id, key, value))
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, vector)"
" VALUES (?,?,?,to_tsvector('english', ?))"
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
"rows_inserted": rows_inserted + len(event_search_rows)
}
self._background_update_progress_txn(
txn, self.EVENT_SEARCH_UPDATE_NAME, progress
)
return len(event_search_rows)
result = yield self.runInteraction(
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
)
if not result:
yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
defer.returnValue(result)
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
Args:
room_ids (list): List of room ids to search in
search_term (str): Search term to search for
keys (list): List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
Returns:
list of dicts
"""
clauses = []
args = []
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
if len(room_ids) < 500:
clauses.append(
"room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
)
args.extend(room_ids)
local_clauses = []
for key in keys:
local_clauses.append("key = ?")
args.append(key)
clauses.append(
"(%s)" % (" OR ".join(local_clauses),)
)
if isinstance(self.database_engine, PostgresEngine):
sql = (
"SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id"
" FROM plainto_tsquery('english', ?) as query, event_search"
" WHERE vector @@ query"
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
" FROM event_search"
" WHERE value MATCH ?"
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
for clause in clauses:
sql += " AND " + clause
# We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
results = yield self._execute(
"search_msgs", self.cursor_to_dict, sql, *([search_term] + args)
)
results = filter(lambda row: row["room_id"] in room_ids, results)
events = yield self._get_events([r["event_id"] for r in results])
event_map = {
ev.event_id: ev
for ev in events
}
defer.returnValue([
{
"event": event_map[r["event_id"]],
"rank": r["rank"],
}
for r in results
if r["event_id"] in event_map
])
@defer.inlineCallbacks
def search_room(self, room_id, search_term, keys, limit, pagination_token=None):
"""Performs a full text search over events with given keys.
Args:
room_id (str): The room_id to search in
search_term (str): Search term to search for
keys (list): List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
pagination_token (str): A pagination token previously returned
Returns:
list of dicts
"""
clauses = []
args = [search_term, room_id]
local_clauses = []
for key in keys:
local_clauses.append("key = ?")
args.append(key)
clauses.append(
"(%s)" % (" OR ".join(local_clauses),)
)
if pagination_token:
try:
topo, stream = pagination_token.split(",")
topo = int(topo)
stream = int(stream)
except:
raise SynapseError(400, "Invalid pagination token")
clauses.append(
"(topological_ordering < ?"
" OR (topological_ordering = ? AND stream_ordering < ?))"
)
args.extend([topo, topo, stream])
if isinstance(self.database_engine, PostgresEngine):
sql = (
"SELECT ts_rank_cd(vector, query) as rank,"
" topological_ordering, stream_ordering, room_id, event_id"
" FROM plainto_tsquery('english', ?) as query, event_search"
" NATURAL JOIN events"
" WHERE vector @@ query AND room_id = ?"
)
elif isinstance(self.database_engine, Sqlite3Engine):
# We use CROSS JOIN here to ensure we use the right indexes.
# https://sqlite.org/optoverview.html#crossjoin
#
# We want to use the full text search index on event_search to
# extract all possible matches first, then lookup those matches
# in the events table to get the topological ordering. We need
# to use the indexes in this order because sqlite refuses to
# MATCH unless it uses the full text search index
sql = (
"SELECT rank(matchinfo) as rank, room_id, event_id,"
" topological_ordering, stream_ordering"
" FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
" FROM event_search"
" WHERE value MATCH ?"
" )"
" CROSS JOIN events USING (event_id)"
" WHERE room_id = ?"
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
for clause in clauses:
sql += " AND " + clause
# We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database.
sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
args.append(limit)
results = yield self._execute(
"search_rooms", self.cursor_to_dict, sql, *args
)
events = yield self._get_events([r["event_id"] for r in results])
event_map = {
ev.event_id: ev
for ev in events
}
defer.returnValue([
{
"event": event_map[r["event_id"]],
"rank": r["rank"],
"pagination_token": "%s,%s" % (
r["topological_ordering"], r["stream_ordering"]
),
}
for r in results
if r["event_id"] in event_map
])
+6 -19
View File
@@ -54,7 +54,7 @@ class StateStore(SQLBaseStore):
defer.returnValue({})
event_to_groups = yield self._get_state_group_for_events(
event_ids,
room_id, event_ids,
)
groups = set(event_to_groups.values())
@@ -208,12 +208,13 @@ class StateStore(SQLBaseStore):
)
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, types):
def get_state_for_events(self, room_id, event_ids, types):
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event. The state dicts will only have the type/state_keys
that are in the `types` list.
Args:
room_id (str)
event_ids (list)
types (list): List of (type, state_key) tuples which are used to
filter the state fetched. `state_key` may be None, which matches
@@ -224,7 +225,7 @@ class StateStore(SQLBaseStore):
The dicts are mappings from (type, state_key) -> state_events
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,
room_id, event_ids,
)
groups = set(event_to_groups.values())
@@ -237,20 +238,6 @@ class StateStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
def get_state_for_event(self, event_id, types=None):
"""
Get the state dict corresponding to a particular event
:param str event_id: event whose state should be returned
:param list[(str, str)]|None types: List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which
matches any key
:return: a deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id])
@cached(num_args=2, lru=True, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol(
@@ -264,8 +251,8 @@ class StateStore(SQLBaseStore):
)
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
num_args=1)
def _get_state_group_for_events(self, event_ids):
num_args=2)
def _get_state_group_for_events(self, room_id, event_ids):
"""Returns mapping event_id -> state_group
"""
def f(txn):
+10 -174
View File
@@ -23,7 +23,7 @@ paginate bacwards.
This is implemented by keeping two ordering columns: stream_ordering and
topological_ordering. Stream ordering is basically insertion/received order
(except for events from backfill requests). The topological_ordering is a
(except for events from backfill requests). The topolgical_ordering is a
weak ordering of events based on the pdu graph.
This means that we have to have two different types of tokens, depending on
@@ -158,40 +158,14 @@ class StreamStore(SQLBaseStore):
defer.returnValue(results)
@log_function
def get_room_events_stream(
self,
user_id,
from_key,
to_key,
limit=0,
is_guest=False,
room_ids=None
):
room_ids = room_ids or []
room_ids = [r for r in room_ids]
if is_guest:
current_room_membership_sql = (
"SELECT c.room_id FROM history_visibility AS h"
" INNER JOIN current_state_events AS c"
" ON h.event_id = c.event_id"
" WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
",".join(map(lambda _: "?", room_ids))
)
)
current_room_membership_args = room_ids
else:
current_room_membership_sql = (
"SELECT m.room_id FROM room_memberships as m "
" INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id AND c.state_key = m.user_id"
" WHERE m.user_id = ? AND m.membership = 'join'"
)
current_room_membership_args = [user_id]
if room_ids:
current_room_membership_sql += " AND m.room_id in (%s)" % (
",".join(map(lambda _: "?", room_ids))
)
current_room_membership_args = [user_id] + room_ids
def get_room_events_stream(self, user_id, from_key, to_key, room_id,
limit=0):
current_room_membership_sql = (
"SELECT m.room_id FROM room_memberships as m "
" INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id AND c.state_key = m.user_id"
" WHERE m.user_id = ? AND m.membership = 'join'"
)
# We also want to get any membership events about that user, e.g.
# invites or leave notifications.
@@ -200,7 +174,6 @@ class StreamStore(SQLBaseStore):
"INNER JOIN current_state_events as c ON m.event_id = c.event_id "
"WHERE m.user_id = ? "
)
membership_args = [user_id]
if limit:
limit = max(limit, MAX_STREAM_SIZE)
@@ -227,9 +200,7 @@ class StreamStore(SQLBaseStore):
}
def f(txn):
args = ([False] + current_room_membership_args + membership_args +
[from_id.stream, to_id.stream])
txn.execute(sql, args)
txn.execute(sql, (False, user_id, user_id, from_id.stream, to_id.stream,))
rows = self.cursor_to_dict(txn)
@@ -465,138 +436,3 @@ class StreamStore(SQLBaseStore):
internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream))
@defer.inlineCallbacks
def get_events_around(self, room_id, event_id, before_limit, after_limit):
"""Retrieve events and pagination tokens around a given event in a
room.
Args:
room_id (str)
event_id (str)
before_limit (int)
after_limit (int)
Returns:
dict
"""
results = yield self.runInteraction(
"get_events_around", self._get_events_around_txn,
room_id, event_id, before_limit, after_limit
)
events_before = yield self._get_events(
[e for e in results["before"]["event_ids"]],
get_prev_content=True
)
events_after = yield self._get_events(
[e for e in results["after"]["event_ids"]],
get_prev_content=True
)
defer.returnValue({
"events_before": events_before,
"events_after": events_after,
"start": results["before"]["token"],
"end": results["after"]["token"],
})
def _get_events_around_txn(self, txn, room_id, event_id, before_limit, after_limit):
"""Retrieves event_ids and pagination tokens around a given event in a
room.
Args:
room_id (str)
event_id (str)
before_limit (int)
after_limit (int)
Returns:
dict
"""
results = self._simple_select_one_txn(
txn,
"events",
keyvalues={
"event_id": event_id,
"room_id": room_id,
},
retcols=["stream_ordering", "topological_ordering"],
)
stream_ordering = results["stream_ordering"]
topological_ordering = results["topological_ordering"]
query_before = (
"SELECT topological_ordering, stream_ordering, event_id FROM events"
" WHERE room_id = ? AND (topological_ordering < ?"
" OR (topological_ordering = ? AND stream_ordering < ?))"
" ORDER BY topological_ordering DESC, stream_ordering DESC"
" LIMIT ?"
)
query_after = (
"SELECT topological_ordering, stream_ordering, event_id FROM events"
" WHERE room_id = ? AND (topological_ordering > ?"
" OR (topological_ordering = ? AND stream_ordering > ?))"
" ORDER BY topological_ordering ASC, stream_ordering ASC"
" LIMIT ?"
)
txn.execute(
query_before,
(
room_id, topological_ordering, topological_ordering,
stream_ordering, before_limit,
)
)
rows = self.cursor_to_dict(txn)
events_before = [r["event_id"] for r in rows]
if rows:
start_token = str(RoomStreamToken(
rows[0]["topological_ordering"],
rows[0]["stream_ordering"] - 1,
))
else:
start_token = str(RoomStreamToken(
topological_ordering,
stream_ordering - 1,
))
txn.execute(
query_after,
(
room_id, topological_ordering, topological_ordering,
stream_ordering, after_limit,
)
)
rows = self.cursor_to_dict(txn)
events_after = [r["event_id"] for r in rows]
if rows:
end_token = str(RoomStreamToken(
rows[-1]["topological_ordering"],
rows[-1]["stream_ordering"],
))
else:
end_token = str(RoomStreamToken(
topological_ordering,
stream_ordering,
))
return {
"before": {
"event_ids": events_before,
"token": start_token,
},
"after": {
"event_ids": events_after,
"token": end_token,
},
}
-216
View File
@@ -1,216 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
from .util.id_generators import StreamIdGenerator
import ujson as json
import logging
logger = logging.getLogger(__name__)
class TagsStore(SQLBaseStore):
def __init__(self, hs):
super(TagsStore, self).__init__(hs)
self._private_user_data_id_gen = StreamIdGenerator(
"private_user_data_max_stream_id", "stream_id"
)
def get_max_private_user_data_stream_id(self):
"""Get the current max stream id for the private user data stream
Returns:
A deferred int.
"""
return self._private_user_data_id_gen.get_max_token(self)
@cached()
def get_tags_for_user(self, user_id):
"""Get all the tags for a user.
Args:
user_id(str): The user to get the tags for.
Returns:
A deferred dict mapping from room_id strings to lists of tag
strings.
"""
deferred = self._simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
@deferred.addCallback
def tags_by_room(rows):
tags_by_room = {}
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = json.loads(row["content"])
return tags_by_room
return deferred
@defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):
"""Get all the tags for the rooms where the tags have changed since the
given version
Args:
user_id(str): The user to get the tags for.
stream_id(int): The earliest update to get for the user.
Returns:
A deferred dict mapping from room_id strings to lists of tag
strings for all the rooms that changed since the stream_id token.
"""
def get_updated_tags_txn(txn):
sql = (
"SELECT room_id from room_tags_revisions"
" WHERE user_id = ? AND stream_id > ?"
)
txn.execute(sql, (user_id, stream_id))
room_ids = [row[0] for row in txn.fetchall()]
return room_ids
room_ids = yield self.runInteraction(
"get_updated_tags", get_updated_tags_txn
)
results = {}
if room_ids:
tags_by_room = yield self.get_tags_for_user(user_id)
for room_id in room_ids:
results[room_id] = tags_by_room.get(room_id, {})
defer.returnValue(results)
def get_tags_for_room(self, user_id, room_id):
"""Get all the tags for the given room
Args:
user_id(str): The user to get tags for
room_id(str): The room to get tags for
Returns:
A deferred list of string tags.
"""
return self._simple_select_list(
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
desc="get_tags_for_room",
).addCallback(lambda rows: {
row["tag"]: json.loads(row["content"]) for row in rows
})
@defer.inlineCallbacks
def add_tag_to_room(self, user_id, room_id, tag, content):
"""Add a tag to a room for a user.
Args:
user_id(str): The user to add a tag for.
room_id(str): The room to add a tag for.
tag(str): The tag name to add.
content(dict): A json object to associate with the tag.
Returns:
A deferred that completes once the tag has been added.
"""
content_json = json.dumps(content)
def add_tag_txn(txn, next_id):
self._simple_upsert_txn(
txn,
table="room_tags",
keyvalues={
"user_id": user_id,
"room_id": room_id,
"tag": tag,
},
values={
"content": content_json,
}
)
self._update_revision_txn(txn, user_id, room_id, next_id)
with (yield self._private_user_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
result = yield self._private_user_data_id_gen.get_max_token(self)
defer.returnValue(result)
@defer.inlineCallbacks
def remove_tag_from_room(self, user_id, room_id, tag):
"""Remove a tag from a room for a user.
Returns:
A deferred that completes once the tag has been removed
"""
def remove_tag_txn(txn, next_id):
sql = (
"DELETE FROM room_tags "
" WHERE user_id = ? AND room_id = ? AND tag = ?"
)
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
with (yield self._private_user_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
result = yield self._private_user_data_id_gen.get_max_token(self)
defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id):
"""Update the latest revision of the tags for the given user and room.
Args:
txn: The database cursor
user_id(str): The ID of the user.
room_id(str): The ID of the room.
next_id(int): The the revision to advance to.
"""
update_max_id_sql = (
"UPDATE private_user_data_max_stream_id"
" SET stream_id = ?"
" WHERE stream_id < ?"
)
txn.execute(update_max_id_sql, (next_id, next_id))
update_sql = (
"UPDATE room_tags_revisions"
" SET stream_id = ?"
" WHERE user_id = ?"
" AND room_id = ?"
)
txn.execute(update_sql, (next_id, user_id, room_id))
if txn.rowcount == 0:
insert_sql = (
"INSERT INTO room_tags_revisions (user_id, room_id, stream_id)"
" VALUES (?, ?, ?)"
)
try:
txn.execute(insert_sql, (user_id, room_id, next_id))
except self.database_engine.module.IntegrityError:
# Ignore insertion errors. It doesn't matter if the row wasn't
# inserted because if two updates happend concurrently the one
# with the higher stream_id will not be reported to a client
# unless the previous update has completed. It doesn't matter
# which stream_id ends up in the table, as long as it is higher
# than the id that the client has.
pass

Some files were not shown because too many files have changed in this diff Show More