Compare commits

65 Commits

Author SHA1 Message Date
Erik Johnston bdba57edf1 Support pagination for tokens without chunk part 2018-06-05 11:03:33 +01:00
Erik Johnston 9eaf69a386 Merge pull request #3315 from matrix-org/erikj/chunk_pag_1
Implement pagination using chunks
2018-06-01 15:17:58 +01:00
Erik Johnston c33810d9cc Remove spurious conditional 2018-06-01 11:55:08 +01:00
Erik Johnston 58aadd3dd4 Remove spurious break 2018-06-01 11:54:24 +01:00
Erik Johnston e7bb34b72a Use *row 2018-06-01 11:53:43 +01:00
Erik Johnston 9e7cf48461 Reuse stream_ordering attribute instead of order
The internal metadata "order" attribute was only used in one place,
which was equivalent to using the stream ordering anyway.
2018-06-01 11:51:11 +01:00
Erik Johnston 5bf4fa0fc4 Don't drop topo ordering when there is no chunk_id 2018-06-01 11:43:03 +01:00
Erik Johnston 80a877e9d9 Comment on stream vs topological vs depth ordering in schema 2018-06-01 11:31:16 +01:00
Erik Johnston 47b36e9a02 Update docs for RoomStreamToken 2018-06-01 11:19:57 +01:00
Erik Johnston b671e57759 Implement pagination using chunks 2018-05-31 11:27:31 +01:00
Erik Johnston bf599cdba1 Use calculated topological ordering when persisting events 2018-05-31 10:18:40 +01:00
Erik Johnston 6188512b18 Add chunk ID to pagination token 2018-05-31 10:04:33 +01:00
Erik Johnston 867132f28c Merge pull request #3240 from matrix-org/erikj/events_chunks
Compute new chunks for new events
2018-05-31 09:37:52 +01:00
Erik Johnston 384731330d Rename func to _insert_into_chunk_txn 2018-05-30 11:51:03 +01:00
Erik Johnston 9e1d3f119a Remove unnecessary COALESCE 2018-05-30 11:45:58 +01:00
Erik Johnston f687d8fae2 Comments 2018-05-30 11:45:41 +01:00
Erik Johnston ecd4931ab2 Just iterate once rather than create a new set 2018-05-30 11:35:02 +01:00
Erik Johnston 1cdd0d3b0d Remove redundant conditions 2018-05-30 11:33:57 +01:00
Erik Johnston 1810cc3f7e Remove unnecessary set 2018-05-30 11:32:27 +01:00
Erik Johnston 6c1d13a15a Correctly loop over events_and_contexts 2018-05-30 11:30:33 +01:00
Erik Johnston 13dbcafb9b Compute new chunks for new events
We also calculate a consistent topological ordering within a chunk, but
it isn't used yet.
2018-05-25 10:54:23 +01:00
Erik Johnston bcc9e7f777 Merge branch 'develop' of github.com:matrix-org/synapse into erikj/room_chunks 2018-05-25 10:53:43 +01:00
Amber Brown 9c36c150e7 Merge pull request #3283 from NotAFile/py3-state
py3-ize state.py
2018-05-24 14:23:33 -05:00
Amber Brown cc1349c06a Merge pull request #3279 from NotAFile/py3-more-iteritems
more six iteritems
2018-05-24 14:23:13 -05:00
Amber Brown 5b788aba90 Merge pull request #3280 from NotAFile/py3-more-misc
More Misc. py3 fixes
2018-05-24 14:22:59 -05:00
Adrian Tschira 0e61705661 py3-ize state.py 2018-05-24 20:59:00 +02:00
Adrian Tschira 17a70cf6e9 Misc. py3 fixes
Signed-off-by: Adrian Tschira <nota@notafile.com>
2018-05-24 20:20:33 +02:00
Adrian Tschira 6c16a4ec1b more iteritems 2018-05-24 20:19:06 +02:00
Amber Brown 7ea07c7305 Merge pull request #3278 from NotAFile/py3-storage-base
Py3 storage/_base.py
2018-05-24 13:08:09 -05:00
Amber Brown 1f69693347 Merge pull request #3244 from NotAFile/py3-six-4
replace some iteritems with six
2018-05-24 13:04:07 -05:00
Amber Brown c4fb15a06c Merge pull request #3246 from NotAFile/py3-repr-string
use repr, not str
2018-05-24 13:00:20 -05:00
Amber Brown 36501068d8 Merge pull request #3247 from NotAFile/py3-misc
Misc Python3 fixes
2018-05-24 12:58:37 -05:00
Amber Brown 2aff6eab6d Merge pull request #3245 from NotAFile/batch-iter
Add batch_iter to utils
2018-05-24 12:54:12 -05:00
Adrian Tschira 095292304f Py3 storage/_base.py
Signed-off-by: Adrian Tschira <nota@notafile.com>
2018-05-24 18:24:12 +02:00
David Baker ecc4b88bd1 Merge pull request #3277 from matrix-org/dbkr/remove_from_user_dir
Remove users from user directory on deactivate
2018-05-24 16:12:12 +01:00
Erik Johnston 46345187cc Merge pull request #3243 from NotAFile/py3-six-3
Replace some more comparisons with six
2018-05-24 16:08:57 +01:00
Neil Johnson 037c6db85d Merge branch 'master' into develop 2018-05-24 16:03:44 +01:00
David Baker 7a1af504d7 Remove users from user directory on deactivate 2018-05-24 15:59:58 +01:00
Erik Johnston f72d5a44d5 Merge pull request #3261 from matrix-org/erikj/pagination_fixes
Fix federation backfill bugs
2018-05-24 14:52:03 +01:00
Erik Johnston 68399fc4de Merge pull request #3267 from matrix-org/erikj/iter_filter
Use iter* methods for _filter_events_for_server
2018-05-24 14:44:57 +01:00
Richard van der Hoff 8c98281b8d Merge branch 'release-v0.30.0' into develop 2018-05-24 10:33:12 +01:00
Richard van der Hoff 96f07cebda Merge pull request #3268 from matrix-org/rav/privacy_policy_docs
Docs on consent bits
2018-05-23 16:23:05 +01:00
Erik Johnston 6e11803ed3 Merge branch 'develop' of github.com:matrix-org/synapse into erikj/room_chunks 2018-05-23 10:54:14 +01:00
Erik Johnston 5aaa3189d5 s/values/itervalues/ 2018-05-23 10:13:05 +01:00
Erik Johnston 0a4bca4134 Use iter* methods for _filter_events_for_server 2018-05-23 10:04:23 +01:00
Erik Johnston e85b5a0ff7 Use iter* methods 2018-05-22 19:02:48 +01:00
Erik Johnston 586b66b197 Fix that states is a dict of dicts 2018-05-22 19:02:36 +01:00
Erik Johnston cb2a2ad791 get_domains_from_state returns list of tuples 2018-05-22 16:23:39 +01:00
Adrian Tschira 933bf2dd35 replace some iteritems with six
Signed-off-by: Adrian Tschira <nota@notafile.com>
2018-05-19 17:59:26 +02:00
Adrian Tschira d9fe2b2d9d Replace some more comparisons with six
plus a bonus b"" string I missed last time

Signed-off-by: Adrian Tschira <nota@notafile.com>
2018-05-19 17:56:31 +02:00
Adrian Tschira 45b55e23d3 Add batch_iter to utils
There's a frequent idiom I noticed where an iterable is split up into a
number of chunks/batches. Unfortunately that method does not work with
iterators like dict.keys() in python3. This implementation works with
iterators.

Signed-off-by: Adrian Tschira <nota@notafile.com>
2018-05-19 17:48:30 +02:00
Adrian Tschira dcc235b47d use stand-in value if maxint is not available
Signed-off-by: Adrian Tschira <nota@notafile.com>
2018-05-19 17:35:44 +02:00
Adrian Tschira 73cbdef5f7 fix py3 intern and remove unnecessary py3 encode
Signed-off-by: Adrian Tschira <nota@notafile.com>
2018-05-19 17:35:31 +02:00
Adrian Tschira aafb0f6b0d py3-ize url preview 2018-05-19 17:35:20 +02:00
Adrian Tschira b932b4ea25 use repr, not str
Signed-off-by: Adrian Tschira <nota@notafile.com>
2018-05-19 17:28:42 +02:00
Erik Johnston 0a325e5385 Merge pull request #3226 from matrix-org/erikj/chunk_base
Begin adding implementing room chunks
2018-05-18 13:54:34 +01:00
Erik Johnston b725e128f8 Comments 2018-05-18 13:43:01 +01:00
Erik Johnston 0504d809fd More comments 2018-05-17 17:08:36 +01:00
Erik Johnston 12fd6d7688 Document case of unconnected chunks 2018-05-17 16:07:20 +01:00
Erik Johnston a638649254 Make insert_* functions internal and reorder funcs
This makes it clearer what the public interface is vs what subclasses
need to implement.
2018-05-17 15:10:23 +01:00
Erik Johnston d4e4a7344f Increase range of rebalance interval
This both simplifies the code, and ensures that the target node is
roughly in the center of the range rather than at an end.
2018-05-17 15:09:31 +01:00
Erik Johnston c771c124d5 Improve documentation and comments 2018-05-17 15:09:10 +01:00
Erik Johnston 3369354b56 Add note about index in changelog 2018-05-17 14:00:54 +01:00
Erik Johnston 3b505a80dc Merge branch 'develop' of github.com:matrix-org/synapse into erikj/chunk_base 2018-05-17 14:00:41 +01:00
Erik Johnston 943f1029d6 Begin adding implementing room chunks
This commit adds the necessary tables and columns, as well as an
implementation of an online topological sorting algorithm to maintain an
absolute ordering of the room chunks.
2018-05-17 12:05:22 +01:00
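The "Add batch_iter to utils" commit (45b55e23d3) in the list above describes splitting an iterable into batches in a way that also works with iterators such as `dict.keys()` on Python 3. The helper's own diff is not part of this excerpt, so the following is only a minimal sketch of the idea, assuming a generator built on `itertools.islice`. The name follows the commit title; the signature and body are assumptions.

```python
from itertools import islice


def batch_iter(iterable, size):
    """Yield tuples of up to `size` items from any iterable (hypothetical sketch)."""
    if size < 1:
        raise ValueError("size must be at least 1")
    # Use one shared iterator so each islice call resumes where the last stopped.
    source = iter(iterable)
    while True:
        batch = tuple(islice(source, size))
        if not batch:
            return
        yield batch


# Slicing (keys[i:i + size]) fails on a Python 3 dict view; this does not.
rooms = {"!a:example.org": 1, "!b:example.org": 2, "!c:example.org": 3}
batches = list(batch_iter(rooms.keys(), 2))
```

The last batch may be shorter than `size`, and an empty input yields no batches at all.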
46 changed files with 1631 additions and 221 deletions
+8 -1
@@ -1,3 +1,11 @@
+Changes in <unreleased>
+=======================
+This release adds an index to the events table. This means that on first
+startup there will be an increased amount of IO until the index is created, and
+an increase in disk usage.
Changes in synapse v0.30.0 (2018-05-24)
==========================================
@@ -53,7 +61,6 @@ Bug Fixes:
* Fix error in handling receipts (PR #3235)
* Stop the transaction cache caching failures (PR #3255)
Changes in synapse v0.29.1 (2018-05-17)
==========================================
Changes:
+3 -1
@@ -20,6 +20,8 @@ from frozendict import frozendict
import re
+from six import string_types
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
@@ -277,7 +279,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
if only_event_fields:
if (not isinstance(only_event_fields, list) or
-not all(isinstance(f, basestring) for f in only_event_fields)):
+not all(isinstance(f, string_types) for f in only_event_fields)):
raise TypeError("only_event_fields must be a list of strings")
d = only_fields(d, only_event_fields)
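The comment in the first hunk of this file describes splitting on "." but not on an escaped "\.". A minimal sketch of that negative lookbehind follows, assuming a pattern of the form `(?<!\\)\.`. The constant name is illustrative and does not appear in this diff.

```python
import re

# Hypothetical name; the pattern matches a "." that is NOT preceded by a backslash.
SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")

# An unescaped dot separates fields; an escaped dot stays inside its field.
plain = SPLIT_FIELD_REGEX.split("content.body")
escaped = SPLIT_FIELD_REGEX.split(r"content.m\.relates_to")
```

Note that the backslash itself is left in the resulting field, so a caller would still need to unescape it.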
+4 -2
@@ -17,6 +17,8 @@ from synapse.types import EventID, RoomID, UserID
from synapse.api.errors import SynapseError
from synapse.api.constants import EventTypes, Membership
+from six import string_types
class EventValidator(object):
@@ -49,7 +51,7 @@ class EventValidator(object):
strings.append("state_key")
for s in strings:
-if not isinstance(getattr(event, s), basestring):
+if not isinstance(getattr(event, s), string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,))
if event.type == EventTypes.Member:
@@ -88,5 +90,5 @@ class EventValidator(object):
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
-if not isinstance(d[s], basestring):
+if not isinstance(d[s], string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,))
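The `basestring` to `string_types` replacements in this file rely on `six` exposing a single name for "any text string" on both Python versions. A minimal sketch of the pattern follows. The `is_text` helper is hypothetical, and the `ImportError` fallback is an assumption added only so the snippet also runs on a Python 3 interpreter without `six` installed.

```python
try:
    from six import string_types  # (basestring,) on Python 2, (str,) on Python 3
except ImportError:
    string_types = (str,)  # assumed fallback for Python 3 without six


def is_text(value):
    """Return True if `value` is a text string on either Python version."""
    # The bare name `basestring` raises NameError on Python 3, so go through six.
    return isinstance(value, string_types)


checks = [is_text("state_key"), is_text(400), is_text(None)]
```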
+3 -1
@@ -20,6 +20,8 @@ from synapse.api.errors import SynapseError
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from twisted.internet import defer
+from six import string_types
logger = logging.getLogger(__name__)
@@ -431,7 +433,7 @@ class GroupsServerHandler(object):
"long_description"):
if keyname in content:
value = content[keyname]
-if not isinstance(value, basestring):
+if not isinstance(value, string_types):
raise SynapseError(400, "%r value is not a string" % (keyname,))
profile[keyname] = value
+4
@@ -30,6 +30,7 @@ class DeactivateAccountHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
+self.user_directory_handler = hs.get_user_directory_handler()
# Flag that indicates whether the process to part users from rooms is running
self._user_parter_running = False
@@ -65,6 +66,9 @@ class DeactivateAccountHandler(BaseHandler):
# removal from all the rooms they're a member of)
yield self.store.add_user_pending_deactivation(user_id)
+# delete from user directory
+yield self.user_directory_handler.handle_user_deactivated(user_id)
# Now start the process that goes through that list and
# parts users from rooms (if it isn't already running)
self._start_user_parting()
+8 -6
@@ -26,6 +26,8 @@ from ._base import BaseHandler
import logging
+from six import itervalues, iteritems
logger = logging.getLogger(__name__)
@@ -318,7 +320,7 @@ class DeviceHandler(BaseHandler):
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
if room_id not in room_ids:
-for key, event_id in current_state_ids.iteritems():
+for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
@@ -338,7 +340,7 @@ class DeviceHandler(BaseHandler):
# special-case for an empty prev state: include all members
# in the changed list
if not event_ids:
-for key, event_id in current_state_ids.iteritems():
+for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
@@ -354,10 +356,10 @@ class DeviceHandler(BaseHandler):
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
-for state_dict in prev_state_ids.itervalues():
+for state_dict in itervalues(prev_state_ids):
member_event = state_dict.get((EventTypes.Member, user_id), None)
if not member_event or member_event != current_member_id:
-for key, event_id in current_state_ids.iteritems():
+for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
@@ -367,14 +369,14 @@ class DeviceHandler(BaseHandler):
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
-for key, event_id in current_state_ids.iteritems():
+for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
# check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so.
-for state_dict in prev_state_ids.itervalues():
+for state_dict in itervalues(prev_state_ids):
prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id:
if state_key != user_id:
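The hunks in this file replace the Python 2-only `dict.iteritems()` and `dict.itervalues()` methods with the `six` helpers of the same name, which dispatch to `iteritems()` on Python 2 and to `items()` on Python 3. A hedged sketch of the loop shape used here follows, with made-up state data and a hypothetical `member_state_keys` helper. The fallback definitions are an assumption for interpreters without `six`.

```python
try:
    from six import iteritems, itervalues
except ImportError:
    # Assumed Python 3 fallback mirroring what six does there.
    def iteritems(d):
        return iter(d.items())

    def itervalues(d):
        return iter(d.values())


def member_state_keys(current_state_ids):
    """Collect the state keys (user IDs) of all m.room.member entries."""
    users = []
    for key, event_id in iteritems(current_state_ids):
        etype, state_key = key
        if etype != "m.room.member":
            continue
        users.append(state_key)
    return users


# Illustrative data only: a map from (type, state_key) to event ID.
state = {
    ("m.room.member", "@alice:example.org"): "$1",
    ("m.room.name", ""): "$2",
    ("m.room.member", "@bob:example.org"): "$3",
}
members = sorted(member_state_keys(state))
event_ids = sorted(itervalues(state))
```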
+7 -6
@@ -19,6 +19,7 @@ import logging
from canonicaljson import encode_canonical_json
from twisted.internet import defer
+from six import iteritems
from synapse.api.errors import (
SynapseError, CodeMessageException, FederationDeniedError,
@@ -92,7 +93,7 @@ class E2eKeysHandler(object):
remote_queries_not_in_cache = {}
if remote_queries:
query_list = []
-for user_id, device_ids in remote_queries.iteritems():
+for user_id, device_ids in iteritems(remote_queries):
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
else:
@@ -103,9 +104,9 @@ class E2eKeysHandler(object):
query_list
)
)
-for user_id, devices in remote_results.iteritems():
+for user_id, devices in iteritems(remote_results):
user_devices = results.setdefault(user_id, {})
-for device_id, device in devices.iteritems():
+for device_id, device in iteritems(devices):
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
@@ -250,9 +251,9 @@ class E2eKeysHandler(object):
"Claimed one-time-keys: %s",
",".join((
"%s for %s:%s" % (key_id, user_id, device_id)
-for user_id, user_keys in json_result.iteritems()
-for device_id, device_keys in user_keys.iteritems()
-for key_id, _ in device_keys.iteritems()
+for user_id, user_keys in iteritems(json_result)
+for device_id, device_keys in iteritems(user_keys)
+for key_id, _ in iteritems(device_keys)
)),
)
+32 -18
@@ -24,6 +24,7 @@ from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
import six
from six.moves import http_client
+from six import iteritems
from twisted.internet import defer
from unpaddedbase64 import decode_base64
@@ -479,18 +480,18 @@ class FederationHandler(BaseHandler):
# to get all state ids that we're interested in.
event_map = yield self.store.get_events([
e_id
-for key_to_eid in event_to_state_ids.values()
-for key, e_id in key_to_eid.items()
+for key_to_eid in event_to_state_ids.itervalues()
+for key, e_id in key_to_eid.iteritems()
if key[0] != EventTypes.Member or check_match(key[1])
])
event_to_state = {
e_id: {
key: event_map[inner_e_id]
-for key, inner_e_id in key_to_eid.items()
+for key, inner_e_id in key_to_eid.iteritems()
if inner_e_id in event_map
}
-for e_id, key_to_eid in event_to_state_ids.items()
+for e_id, key_to_eid in event_to_state_ids.iteritems()
}
def redact_disallowed(event, state):
@@ -505,7 +506,7 @@ class FederationHandler(BaseHandler):
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
# into the room.
-for ev in state.values():
+for ev in state.itervalues():
if ev.type != EventTypes.Member:
continue
try:
@@ -751,9 +752,19 @@ class FederationHandler(BaseHandler):
curr_state = yield self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
+"""Get joined domains from state
+Args:
+state (dict[tuple, FrozenEvent]): State map from type/state
+key to event.
+Returns:
+list[tuple[str, int]]: Returns a list of servers with the
+lowest depth of their joins. Sorted by lowest depth first.
+"""
joined_users = [
(state_key, int(event.depth))
-for (e_type, state_key), event in state.items()
+for (e_type, state_key), event in state.iteritems()
if e_type == EventTypes.Member
and event.membership == Membership.JOIN
]
@@ -770,7 +781,7 @@ class FederationHandler(BaseHandler):
except Exception:
pass
-return sorted(joined_domains.items(), key=lambda d: d[1])
+return sorted(joined_domains.iteritems(), key=lambda d: d[1])
curr_domains = get_domains_from_state(curr_state)
@@ -787,7 +798,7 @@ class FederationHandler(BaseHandler):
yield self.backfill(
dom, room_id,
limit=100,
-extremities=[e for e in extremities.keys()]
+extremities=extremities,
)
# If this succeeded then we probably already have the
# appropriate stuff.
@@ -833,7 +844,7 @@ class FederationHandler(BaseHandler):
tried_domains = set(likely_domains)
tried_domains.add(self.server_name)
-event_ids = list(extremities.keys())
+event_ids = list(extremities.iterkeys())
logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = logcontext.preserve_fn(
@@ -843,31 +854,34 @@ class FederationHandler(BaseHandler):
[resolve(room_id, [e]) for e in event_ids],
consumeErrors=True,
))
+# dict[str, dict[tuple, str]], a map from event_id to state map of
+# event_ids.
states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events(
-[e_id for ids in states.values() for e_id in ids],
+[e_id for ids in states.itervalues() for e_id in ids.itervalues()],
get_prev_content=False
)
states = {
key: {
k: state_map[e_id]
-for k, e_id in state_dict.items()
+for k, e_id in state_dict.iteritems()
if e_id in state_map
-} for key, state_dict in states.items()
+} for key, state_dict in states.iteritems()
}
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
success = yield try_backfill([
-dom for dom in likely_domains
+dom for dom, _ in likely_domains
if dom not in tried_domains
])
if success:
defer.returnValue(True)
-tried_domains.update(likely_domains)
+tried_domains.update(dom for dom, _ in likely_domains)
defer.returnValue(False)
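The docstring added to `get_domains_from_state` states that it returns a list of `(domain, depth)` tuples sorted by lowest depth, which is why the hunk above unpacks `dom, _` before comparing against `tried_domains`. A simplified, self-contained sketch of that behavior follows. It takes plain `(user_id, depth)` pairs instead of event objects and uses a naive `split(":", 1)` in place of `get_domain_from_id`; both are assumptions for illustration.

```python
def get_domains_from_state(joined_users):
    """Return [(domain, min_depth)] sorted by lowest depth first.

    `joined_users` is a list of (user_id, depth) pairs for joined members.
    """
    joined_domains = {}
    for user_id, depth in joined_users:
        # Naive stand-in for get_domain_from_id: text after the first ":".
        domain = user_id.split(":", 1)[1]
        old_depth = joined_domains.get(domain)
        if old_depth is None or depth < old_depth:
            joined_domains[domain] = depth
    return sorted(joined_domains.items(), key=lambda d: d[1])


likely_domains = get_domains_from_state([
    ("@a:one.example", 7),
    ("@b:two.example", 3),
    ("@c:one.example", 5),
])

tried_domains = {"two.example"}
# Unpack the tuples: a whole (domain, depth) pair never equals a domain string,
# so without unpacking nothing would be filtered out.
to_try = [dom for dom, _ in likely_domains if dom not in tried_domains]
tried_domains.update(dom for dom, _ in likely_domains)
```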
@@ -1375,7 +1389,7 @@ class FederationHandler(BaseHandler):
)
if state_groups:
-_, state = state_groups.items().pop()
+_, state = list(iteritems(state_groups)).pop()
results = {
(e.type, e.state_key): e for e in state
}
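The `state_groups.items().pop()` call in this hunk only works on Python 2, where `items()` returns a list. On Python 3 it returns a view object with no `pop` method, hence the `list(...)` wrapper in the new line. A small sketch with made-up data, using the built-in `items()` so that it runs without `six`:

```python
state_groups = {42: ["$create", "$member"]}  # illustrative data only

# On Python 3 a dict view has no pop(); materialise it as a list first.
view_has_pop = hasattr(state_groups.items(), "pop")
_, state = list(state_groups.items()).pop()
```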
@@ -2021,7 +2035,7 @@ class FederationHandler(BaseHandler):
this will not be included in the current_state in the context.
"""
state_updates = {
-k: a.event_id for k, a in auth_events.iteritems()
+k: a.event_id for k, a in iteritems(auth_events)
if k != event_key
}
context.current_state_ids = dict(context.current_state_ids)
@@ -2031,7 +2045,7 @@ class FederationHandler(BaseHandler):
context.delta_ids.update(state_updates)
context.prev_state_ids = dict(context.prev_state_ids)
context.prev_state_ids.update({
-k: a.event_id for k, a in auth_events.iteritems()
+k: a.event_id for k, a in iteritems(auth_events)
})
context.state_group = yield self.store.store_state_group(
event.event_id,
@@ -2083,7 +2097,7 @@ class FederationHandler(BaseHandler):
def get_next(it, opt=None):
try:
-return it.next()
+return next(it)
except Exception:
return opt
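The last hunk in this file swaps the Python 2-only `it.next()` method for the built-in `next(it)`, which works on both versions. A sketch of the helper in isolation follows. The built-in also accepts a default, so `next(it, opt)` would be a shorter alternative for the exhausted-iterator case; that is a suggestion, not something this diff does.

```python
def get_next(it, opt=None):
    """Return the next item from iterator `it`, or `opt` if that fails."""
    try:
        return next(it)  # it.next() exists only on Python 2 iterators
    except Exception:
        return opt


events = iter(["$first"])
results = [get_next(events), get_next(events), get_next(events, opt="done")]
```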
+2 -1
@@ -15,6 +15,7 @@
# limitations under the License.
from twisted.internet import defer
+from six import iteritems
from synapse.api.errors import SynapseError
from synapse.types import get_domain_from_id
@@ -449,7 +450,7 @@ class GroupsLocalHandler(object):
results = {}
failed_results = []
-for destination, dest_user_ids in destinations.iteritems():
+for destination, dest_user_ids in iteritems(destinations):
try:
r = yield self.transport_client.bulk_get_publicised_groups(
destination, list(dest_user_ids),
+6 -5
@@ -19,6 +19,7 @@ import sys
from canonicaljson import encode_canonical_json
import six
+from six import string_types, itervalues, iteritems
from twisted.internet import defer, reactor
from twisted.internet.defer import succeed
from twisted.python.failure import Failure
@@ -234,7 +235,7 @@ class MessageHandler(BaseHandler):
room_id, max_topo
)
-events, next_key = yield self.store.paginate_room_events(
+events, next_key, extremities = yield self.store.paginate_room_events(
room_id=room_id,
from_key=source_config.from_key,
to_key=source_config.to_key,
@@ -402,7 +403,7 @@ class MessageHandler(BaseHandler):
"avatar_url": profile.avatar_url,
"display_name": profile.display_name,
}
-for user_id, profile in users_with_profile.iteritems()
+for user_id, profile in iteritems(users_with_profile)
})
@@ -667,7 +668,7 @@ class EventCreationHandler(object):
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
-if not isinstance(spam_error, basestring):
+if not isinstance(spam_error, string_types):
spam_error = "Spam is not permitted here"
raise SynapseError(
403, spam_error, Codes.FORBIDDEN
@@ -881,7 +882,7 @@ class EventCreationHandler(object):
state_to_include_ids = [
e_id
-for k, e_id in context.current_state_ids.iteritems()
+for k, e_id in iteritems(context.current_state_ids)
if k[0] in self.hs.config.room_invite_state_types
or k == (EventTypes.Member, event.sender)
]
@@ -895,7 +896,7 @@ class EventCreationHandler(object):
"content": e.content,
"sender": e.sender,
}
-for e in state_to_include.itervalues()
+for e in itervalues(state_to_include)
]
invitee = UserID.from_string(event.state_key)
+8 -7
@@ -25,6 +25,8 @@ The methods that define policy are:
from twisted.internet import defer, reactor
from contextlib import contextmanager
+from six import itervalues, iteritems
from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState
from synapse.storage.presence import UserPresenceState
@@ -40,7 +42,6 @@ import synapse.metrics
import logging
logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
@@ -530,7 +531,7 @@ class PresenceHandler(object):
prev_state.copy_and_replace(
last_user_sync_ts=time_now_ms,
)
-for prev_state in prev_states.itervalues()
+for prev_state in itervalues(prev_states)
])
self.external_process_last_updated_ms.pop(process_id, None)
@@ -553,14 +554,14 @@ class PresenceHandler(object):
for user_id in user_ids
}
-missing = [user_id for user_id, state in states.iteritems() if not state]
+missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
# There are things not in our in memory cache. Lets pull them out of
# the database.
res = yield self.store.get_presence_for_users(missing)
states.update(res)
-missing = [user_id for user_id, state in states.iteritems() if not state]
+missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
new = {
user_id: UserPresenceState.default(user_id)
@@ -1048,7 +1049,7 @@ class PresenceEventSource(object):
defer.returnValue((updates.values(), max_token))
else:
defer.returnValue(([
-s for s in updates.itervalues()
+s for s in itervalues(updates)
if s.state != PresenceState.OFFLINE
], max_token))
@@ -1305,11 +1306,11 @@ def get_interested_remotes(store, states, state_handler):
# hosts in those rooms.
room_ids_to_states, users_to_states = yield get_interested_parties(store, states)
-for room_id, states in room_ids_to_states.iteritems():
+for room_id, states in iteritems(room_ids_to_states):
hosts = yield state_handler.get_current_hosts_in_room(room_id)
hosts_and_states.append((hosts, states))
-for user_id, states in users_to_states.iteritems():
+for user_id, states in iteritems(users_to_states):
host = get_domain_from_id(user_id)
hosts_and_states.append(([host], states))
+3 -2
@@ -514,7 +514,8 @@ class RoomEventSource(object):
events = list(room_events)
events.extend(e for evs, _ in room_to_events.values() for e in evs)
-events.sort(key=lambda e: e.internal_metadata.order)
+# Order by the stream ordering of the events.
+events.sort(key=lambda e: e.internal_metadata.stream_ordering)
if limit:
events[:] = events[:limit]
@@ -534,7 +535,7 @@ class RoomEventSource(object):
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
-events, next_key = yield self.store.paginate_room_events(
+events, next_key, _ = yield self.store.paginate_room_events(
room_id=key,
from_key=config.from_key,
to_key=config.to_key,
+8 -6
@@ -28,6 +28,8 @@ import collections
import logging
import itertools
+from six import itervalues, iteritems
logger = logging.getLogger(__name__)
@@ -275,7 +277,7 @@ class SyncHandler(object):
# result returned by the event source is poor form (it might cache
# the object)
room_id = event["room_id"]
-event_copy = {k: v for (k, v) in event.iteritems()
+event_copy = {k: v for (k, v) in iteritems(event)
if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
@@ -294,7 +296,7 @@ class SyncHandler(object):
for event in receipts:
room_id = event["room_id"]
# exclude room id, as above
-event_copy = {k: v for (k, v) in event.iteritems()
+event_copy = {k: v for (k, v) in iteritems(event)
if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
@@ -325,7 +327,7 @@ class SyncHandler(object):
current_state_ids = frozenset()
if any(e.is_state() for e in recents):
current_state_ids = yield self.state.get_current_state_ids(room_id)
-current_state_ids = frozenset(current_state_ids.itervalues())
+current_state_ids = frozenset(itervalues(current_state_ids))
recents = yield filter_events_for_client(
self.store,
@@ -382,7 +384,7 @@ class SyncHandler(object):
current_state_ids = frozenset()
if any(e.is_state() for e in loaded_recents):
current_state_ids = yield self.state.get_current_state_ids(room_id)
-current_state_ids = frozenset(current_state_ids.itervalues())
+current_state_ids = frozenset(itervalues(current_state_ids))
loaded_recents = yield filter_events_for_client(
self.store,
@@ -984,7 +986,7 @@ class SyncHandler(object):
if since_token:
for joined_sync in sync_result_builder.joined:
it = itertools.chain(
-joined_sync.timeline.events, joined_sync.state.itervalues()
+joined_sync.timeline.events, itervalues(joined_sync.state)
)
for event in it:
if event.type == EventTypes.Member:
@@ -1062,7 +1064,7 @@ class SyncHandler(object):
newly_left_rooms = []
room_entries = []
invited = []
-for room_id, events in mem_change_events_by_room_id.iteritems():
+for room_id, events in iteritems(mem_change_events_by_room_id):
non_joins = [e for e in events if e.membership != Membership.JOIN]
has_join = len(non_joins) != len(events)
+9 -1
@@ -22,6 +22,7 @@ from synapse.util.metrics import Measure
from synapse.util.async import sleep
from synapse.types import get_localpart_from_id
+from six import iteritems
logger = logging.getLogger(__name__)
@@ -122,6 +123,13 @@ class UserDirectoryHandler(object):
user_id, profile.display_name, profile.avatar_url, None,
)
+@defer.inlineCallbacks
+def handle_user_deactivated(self, user_id):
+"""Called when a user ID is deactivated
+"""
+yield self.store.remove_from_user_dir(user_id)
+yield self.store.remove_from_user_in_public_room(user_id)
@defer.inlineCallbacks
def _unsafe_process(self):
# If self.pos is None then means we haven't fetched it from DB
@@ -403,7 +411,7 @@ class UserDirectoryHandler(object):
if change:
users_with_profile = yield self.state.get_current_user_in_room(room_id)
-for user_id, profile in users_with_profile.iteritems():
+for user_id, profile in iteritems(users_with_profile):
yield self._handle_new_user(room_id, user_id, profile)
else:
users = yield self.store.get_users_in_public_due_to_room(room_id)
+5 -3
@@ -42,6 +42,8 @@ import random
import sys
import urllib
from six.moves.urllib import parse as urlparse
+from six import string_types
logger = logging.getLogger(__name__)
outbound_logger = logging.getLogger("synapse.http.outbound")
@@ -553,7 +555,7 @@ class MatrixFederationHttpClient(object):
encoded_args = {}
for k, vs in args.items():
-if isinstance(vs, basestring):
+if isinstance(vs, string_types):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
@@ -668,7 +670,7 @@ def check_content_type_is_json(headers):
RuntimeError if the
"""
-c_type = headers.getRawHeaders("Content-Type")
+c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None:
raise RuntimeError(
"No Content-Type header"
@@ -685,7 +687,7 @@ def check_content_type_is_json(headers):
def encode_query_args(args):
encoded_args = {}
for k, vs in args.items():
-if isinstance(vs, basestring):
+if isinstance(vs, string_types):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
+1 -1
@@ -56,7 +56,7 @@ class SynapseRequest(Request):
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
-return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
+return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
self.__class__.__name__,
id(self),
self.method,
+2 -1
@@ -15,6 +15,7 @@
import os
+from six import iteritems
TICKS_PER_SEC = 100
BYTES_PER_PAGE = 4096
@@ -55,7 +56,7 @@ def update_resource_metrics():
# line is PID (command) more stats go here ...
raw_stats = line.split(") ", 1)[1].split(" ")
-for (name, index) in STAT_FIELDS.iteritems():
+for (name, index) in iteritems(STAT_FIELDS):
# subtract 3 from the index, because proc(5) is 1-based, and
# we've lost the first two fields in PID and COMMAND above
stats[name] = int(raw_stats[index - 3])
+7 -6
@@ -30,6 +30,7 @@ from synapse.state import POWER_KEY
from collections import namedtuple
+from six import itervalues, iteritems
logger = logging.getLogger(__name__)
@@ -126,7 +127,7 @@ class BulkPushRuleEvaluator(object):
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
-(e.type, e.state_key): e for e in auth_events.itervalues()
+(e.type, e.state_key): e for e in itervalues(auth_events)
}
sender_level = get_user_power_level(event.sender, auth_events)
@@ -160,7 +161,7 @@ class BulkPushRuleEvaluator(object):
condition_cache = {}
-for uid, rules in rules_by_user.iteritems():
+for uid, rules in iteritems(rules_by_user):
if event.sender == uid:
continue
@@ -406,7 +407,7 @@ class RulesForRoom(object):
# If the event is a join event then it will be in current state events
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
-for event_id in member_event_ids.itervalues():
+for event_id in itervalues(member_event_ids):
if event_id == event.event_id:
members[event_id] = (event.state_key, event.membership)
@@ -414,7 +415,7 @@ class RulesForRoom(object):
logger.debug("Found members %r: %r", self.room_id, members.values())
interested_in_user_ids = set(
-user_id for user_id, membership in members.itervalues()
+user_id for user_id, membership in itervalues(members)
if membership == Membership.JOIN
)
@@ -426,7 +427,7 @@ class RulesForRoom(object):
)
user_ids = set(
-uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher
+uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher
)
logger.debug("With pushers: %r", user_ids)
@@ -447,7 +448,7 @@ class RulesForRoom(object):
)
ret_rules_by_user.update(
-item for item in rules_by_user.iteritems() if item[0] is not None
+item for item in iteritems(rules_by_user) if item[0] is not None
)
self.update_cache(sequence, members, ret_rules_by_user, state_group)
+3 -1
@@ -21,6 +21,8 @@ from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
+from six import string_types
logger = logging.getLogger(__name__)
@@ -238,7 +240,7 @@ def _flatten_dict(d, prefix=[], result=None):
if result is None:
result = {}
for key, value in d.items():
-if isinstance(value, basestring):
+if isinstance(value, string_types):
result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"):
_flatten_dict(value, prefix=(prefix + [key]), result=result)
+5 -4
@@ -68,6 +68,7 @@ import synapse.metrics
import struct
import fcntl
from six import iterkeys, iteritems
metrics = synapse.metrics.get_metrics_for(__name__)
@@ -392,7 +393,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
for stream in self.streamer.streams_by_name.iterkeys():
for stream in iterkeys(self.streamer.streams_by_name):
self.subscribe_to_stream(stream, token)
else:
self.subscribe_to_stream(stream_name, token)
@@ -498,7 +499,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
BaseReplicationStreamProtocol.connectionMade(self)
# Once we've connected subscribe to the necessary streams
for stream_name, token in self.handler.get_streams_to_replicate().iteritems():
for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
self.replicate(stream_name, token)
# Tell the server if we have any users currently syncing (should only
@@ -633,7 +634,7 @@ metrics.register_callback(
lambda: {
(k[0], p.name, p.conn_id): count
for p in connected_connections
for k, count in p.inbound_commands_counter.counts.iteritems()
for k, count in iteritems(p.inbound_commands_counter.counts)
},
labels=["command", "name", "conn_id"],
)
@@ -643,7 +644,7 @@ metrics.register_callback(
lambda: {
(k[0], p.name, p.conn_id): count
for p in connected_connections
for k, count in p.outbound_commands_counter.counts.iteritems()
for k, count in iteritems(p.outbound_commands_counter.counts)
},
labels=["command", "name", "conn_id"],
)
+2 -1
@@ -26,6 +26,7 @@ from synapse.util.metrics import Measure, measure_func
import logging
import synapse.metrics
from six import itervalues
metrics = synapse.metrics.get_metrics_for(__name__)
stream_updates_counter = metrics.register_counter(
@@ -80,7 +81,7 @@ class ReplicationStreamer(object):
# We only support federation stream if federation sending has been
# disabled on the master.
self.streams = [
stream(hs) for stream in STREAMS_MAP.itervalues()
stream(hs) for stream in itervalues(STREAMS_MAP)
if stream != FederationStream or not hs.config.send_federation
]
+5 -3
@@ -23,6 +23,8 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns
from six import string_types
import logging
logger = logging.getLogger(__name__)
@@ -71,7 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
if "status_msg" in content:
state["status_msg"] = content.pop("status_msg")
if not isinstance(state["status_msg"], basestring):
if not isinstance(state["status_msg"], string_types):
raise SynapseError(400, "status_msg must be a string.")
if content:
@@ -129,7 +131,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if "invite" in content:
for u in content["invite"]:
if not isinstance(u, basestring):
if not isinstance(u, string_types):
raise SynapseError(400, "Bad invite value.")
if len(u) == 0:
continue
@@ -140,7 +142,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if "drop" in content:
for u in content["drop"]:
if not isinstance(u, basestring):
if not isinstance(u, string_types):
raise SynapseError(400, "Bad drop value.")
if len(u) == 0:
continue
+2 -1
@@ -48,6 +48,7 @@ import shutil
import cgi
import logging
from six.moves.urllib import parse as urlparse
from six import iteritems
logger = logging.getLogger(__name__)
@@ -603,7 +604,7 @@ class MediaRepository(object):
thumbnails[(t_width, t_height, r_type)] = r_method
# Now we generate the thumbnails for each dimension, store it
for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail
if t_method == "crop":
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -24,7 +24,9 @@ import shutil
import sys
import traceback
import simplejson as json
import urlparse
from six.moves import urllib_parse as urlparse
from six import string_types
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
@@ -590,8 +592,8 @@ def _iterate_over_text(tree, *tags_to_ignore):
# to be returned.
elements = iter([tree])
while True:
el = elements.next()
if isinstance(el, basestring):
el = next(elements)
if isinstance(el, string_types):
yield el
elif el is not None and el.tag not in tags_to_ignore:
# el.text is the text before the first child, so we can immediately
+25 -23
@@ -32,6 +32,8 @@ from frozendict import frozendict
import logging
import hashlib
from six import iteritems, itervalues
logger = logging.getLogger(__name__)
@@ -132,7 +134,7 @@ class StateHandler(object):
state_map = yield self.store.get_events(state.values(), get_prev_content=False)
state = {
key: state_map[e_id] for key, e_id in state.iteritems() if e_id in state_map
key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
}
defer.returnValue(state)
@@ -338,7 +340,7 @@ class StateHandler(object):
)
if len(state_groups_ids) == 1:
name, state_list = state_groups_ids.items().pop()
name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@@ -378,7 +380,7 @@ class StateHandler(object):
new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.iteritems()
key: state_map[ev_id] for key, ev_id in iteritems(new_state)
}
return new_state
@@ -458,15 +460,15 @@ class StateResolutionHandler(object):
# build a map from state key to the event_ids which set that state.
# dict[(str, str), set[str]]
state = {}
for st in state_groups_ids.itervalues():
for key, e_id in st.iteritems():
for st in itervalues(state_groups_ids):
for key, e_id in iteritems(st):
state.setdefault(key, set()).add(e_id)
# build a map from state key to the event_ids which set that state,
# including only those where there are state keys in conflict.
conflicted_state = {
k: list(v)
for k, v in state.iteritems()
for k, v in iteritems(state)
if len(v) > 1
}
@@ -474,13 +476,13 @@ class StateResolutionHandler(object):
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory(
state_groups_ids.values(),
list(state_groups_ids.values()),
event_map=event_map,
state_map_factory=state_map_factory,
)
else:
new_state = {
key: e_ids.pop() for key, e_ids in state.iteritems()
key: e_ids.pop() for key, e_ids in iteritems(state)
}
with Measure(self.clock, "state.create_group_ids"):
@@ -489,8 +491,8 @@ class StateResolutionHandler(object):
# which will be used as a cache key for future resolutions, but
# not get persisted.
state_group = None
new_state_event_ids = frozenset(new_state.itervalues())
for sg, events in state_groups_ids.iteritems():
new_state_event_ids = frozenset(itervalues(new_state))
for sg, events in iteritems(state_groups_ids):
if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg
break
@@ -501,11 +503,11 @@ class StateResolutionHandler(object):
prev_group = None
delta_ids = None
for old_group, old_ids in state_groups_ids.iteritems():
for old_group, old_ids in iteritems(state_groups_ids):
if not set(new_state) - set(old_ids):
n_delta_ids = {
k: v
for k, v in new_state.iteritems()
for k, v in iteritems(new_state)
if old_ids.get(k) != v
}
if not delta_ids or len(n_delta_ids) < len(delta_ids):
@@ -527,7 +529,7 @@ class StateResolutionHandler(object):
def _ordered_events(events):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
return sorted(events, key=key_func)
@@ -584,7 +586,7 @@ def _seperate(state_sets):
conflicted_state = {}
for state_set in state_sets[1:]:
for key, value in state_set.iteritems():
for key, value in iteritems(state_set):
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
@@ -640,7 +642,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
needed_events = set(
event_id
for event_ids in conflicted_state.itervalues()
for event_ids in itervalues(conflicted_state)
for event_id in event_ids
)
if event_map is not None:
@@ -662,7 +664,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
unconflicted_state, conflicted_state, state_map
)
new_needed_events = set(auth_events.itervalues())
new_needed_events = set(itervalues(auth_events))
new_needed_events -= needed_events
if event_map is not None:
new_needed_events -= set(event_map.iterkeys())
@@ -679,7 +681,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
for event_ids in conflicted_state.itervalues():
for event_ids in itervalues(conflicted_state):
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
@@ -694,7 +696,7 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
state_map):
conflicted_state = {}
for key, event_ids in conflicted_state_ds.iteritems():
for key, event_ids in iteritems(conflicted_state_ds):
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
@@ -703,7 +705,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
auth_events = {
key: state_map[ev_id]
for key, ev_id in auth_event_ids.iteritems()
for key, ev_id in iteritems(auth_event_ids)
if ev_id in state_map
}
@@ -716,7 +718,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
raise
new_state = unconflicted_state_ids
for key, event in resolved_state.iteritems():
for key, event in iteritems(resolved_state):
new_state[key] = event.event_id
return new_state
@@ -741,7 +743,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
for key, events in conflicted_state.iteritems():
for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(
@@ -751,7 +753,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
for key, events in conflicted_state.iteritems():
for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(
@@ -761,7 +763,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
for key, events in conflicted_state.iteritems():
for key, events in iteritems(conflicted_state):
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(
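Note on the `_ordered_events` hunk above: `hashlib.sha1` only accepts bytes on Python 3, which is why the event ID is encoded before hashing. The following is a minimal sketch of that ordering key, assuming a hypothetical `Event` tuple that carries only the two fields the key uses; it is not the real Synapse event class.

```python
import hashlib
from collections import namedtuple

# Hypothetical stand-in for an event: only the fields the sort key reads.
Event = namedtuple("Event", ["event_id", "depth"])


def ordered_events(events):
    """Sort by descending depth, breaking ties on the SHA-1 of the event ID."""
    def key_func(e):
        # hashlib requires bytes on Python 3, so encode the text ID first.
        digest = hashlib.sha1(e.event_id.encode("utf-8")).hexdigest()
        return -int(e.depth), digest

    return sorted(events, key=key_func)


events = [
    Event("$a:example.org", 3),
    Event("$b:example.org", 5),
    Event("$c:example.org", 3),
]
result = ordered_events(events)
```

The deepest event sorts first; events of equal depth fall in a stable, hash-determined order that does not depend on input order.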
+1
@@ -131,6 +131,7 @@ class DataStore(RoomMemberStore, RoomStore,
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id",
)
self._chunk_id_gen = IdGenerator(db_conn, "events", "chunk_id")
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator(
+29 -18
@@ -27,9 +27,17 @@ import sys
import time
import threading
from six import itervalues, iterkeys, iteritems
from six.moves import intern, range
logger = logging.getLogger(__name__)
try:
MAX_TXN_ID = sys.maxint - 1
except AttributeError:
# python 3 does not have a maximum int value
MAX_TXN_ID = 2**63 - 1
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME")
@@ -137,7 +145,7 @@ class PerformanceCounters(object):
def interval(self, interval_duration, limit=3):
counters = []
for name, (count, cum_time) in self.current_counters.iteritems():
for name, (count, cum_time) in iteritems(self.current_counters):
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append((
(cum_time - prev_time) / interval_duration,
@@ -222,7 +230,7 @@ class SQLBaseStore(object):
# We don't really need these to be unique, so let's stop it from
# growing really large.
self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
name = "%s-%x" % (desc, txn_id, )
@@ -543,7 +551,7 @@ class SQLBaseStore(object):
", ".join("%s = ?" % (k,) for k in values),
" AND ".join("%s = ?" % (k,) for k in keyvalues)
)
sqlargs = values.values() + keyvalues.values()
sqlargs = list(values.values()) + list(keyvalues.values())
txn.execute(sql, sqlargs)
if txn.rowcount > 0:
@@ -561,7 +569,7 @@ class SQLBaseStore(object):
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues)
)
txn.execute(sql, allvalues.values())
txn.execute(sql, list(allvalues.values()))
# successfully inserted
return True
@@ -629,8 +637,8 @@ class SQLBaseStore(object):
}
if keyvalues:
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
txn.execute(sql, keyvalues.values())
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
txn.execute(sql, list(keyvalues.values()))
else:
txn.execute(sql)
@@ -694,7 +702,7 @@ class SQLBaseStore(object):
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
txn.execute(sql, keyvalues.values())
txn.execute(sql, list(keyvalues.values()))
else:
sql = "SELECT %s FROM %s" % (
", ".join(retcols),
@@ -725,9 +733,12 @@ class SQLBaseStore(object):
if not iterable:
defer.returnValue(results)
# iterables can not be sliced, so convert it to a list first
it_list = list(iterable)
chunks = [
iterable[i:i + batch_size]
for i in xrange(0, len(iterable), batch_size)
it_list[i:i + batch_size]
for i in range(0, len(it_list), batch_size)
]
for chunk in chunks:
rows = yield self.runInteraction(
@@ -767,7 +778,7 @@ class SQLBaseStore(object):
)
values.extend(iterable)
for key, value in keyvalues.iteritems():
for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -790,7 +801,7 @@ class SQLBaseStore(object):
@staticmethod
def _simple_update_txn(txn, table, keyvalues, updatevalues):
if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
else:
where = ""
@@ -802,7 +813,7 @@ class SQLBaseStore(object):
txn.execute(
update_sql,
updatevalues.values() + keyvalues.values()
list(updatevalues.values()) + list(keyvalues.values())
)
return txn.rowcount
@@ -850,7 +861,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues)
)
txn.execute(select_sql, keyvalues.values())
txn.execute(select_sql, list(keyvalues.values()))
row = txn.fetchone()
if not row:
@@ -888,7 +899,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
txn.execute(sql, keyvalues.values())
txn.execute(sql, list(keyvalues.values()))
if txn.rowcount == 0:
raise StoreError(404, "No row found")
if txn.rowcount > 1:
@@ -906,7 +917,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
return txn.execute(sql, keyvalues.values())
return txn.execute(sql, list(keyvalues.values()))
def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
return self.runInteraction(
@@ -938,7 +949,7 @@ class SQLBaseStore(object):
)
values.extend(iterable)
for key, value in keyvalues.iteritems():
for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -978,7 +989,7 @@ class SQLBaseStore(object):
txn.close()
if cache:
min_val = min(cache.itervalues())
min_val = min(itervalues(cache))
else:
min_val = max_value
@@ -1093,7 +1104,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues),
" ? ASC LIMIT ? OFFSET ?"
)
txn.execute(sql, keyvalues.values() + pagevalues)
txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
else:
sql = "SELECT %s FROM %s ORDER BY %s" % (
", ".join(retcols),
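Note on the hunks above: most of these changes follow from three Python 3 differences. `dict.values()` returns a view that cannot be concatenated with `+`, `sys.maxint` no longer exists, and arbitrary iterables cannot be sliced. A small standalone sketch of the same patterns, using made-up sample data and a hypothetical `batches` helper:

```python
import sys

values = {"name": "alice"}
keyvalues = {"user_id": "@alice:example.org"}

# On Python 3, values() is a view object; wrap each side in list() before
# concatenating to build the SQL argument list.
sqlargs = list(values.values()) + list(keyvalues.values())

# sys.maxint is Python 2 only, so fall back to a fixed 63-bit bound.
try:
    MAX_TXN_ID = sys.maxint - 1
except AttributeError:
    MAX_TXN_ID = 2**63 - 1


def batches(iterable, batch_size):
    """Split any iterable into lists of at most batch_size items."""
    # Sets, generators and dict views cannot be sliced, so copy to a list.
    it_list = list(iterable)
    return [
        it_list[i:i + batch_size]
        for i in range(0, len(it_list), batch_size)
    ]
```

Copying to a list costs one extra pass over the input, which seems a reasonable trade for accepting any iterable.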
+319
@@ -0,0 +1,319 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import logging
from synapse.storage._base import SQLBaseStore
from synapse.util.katriel_bodlaender import OrderedListStore
from synapse.util.metrics import Measure
import synapse.metrics
metrics = synapse.metrics.get_metrics_for(__name__)
rebalance_counter = metrics.register_counter("rebalances")
logger = logging.getLogger(__name__)
class ChunkDBOrderedListStore(OrderedListStore):
"""Used as the list store for room chunks, efficiently maintaining them in
topological order on updates.
A room chunk is a connected portion of the room events DAG. Chunks are
constructed so that they have the additional property that for all events in
the chunk, either all of their prev_events are in that chunk or none of them
are. This ensures that no event that is subsequently received needs to be
inserted into the middle of a chunk, since it cannot both reference an event
in the chunk and be referenced by an event in the chunk (assuming no
cycles).
As such the set of chunks in a room inherits a DAG, i.e. if an event in one
chunk references an event in a second chunk, then we say that the first
chunk references the second, thus forming a DAG. (This means that chunks
start off disconnected until an event is received that connects the two
chunks.)
We can therefore end up with multiple chunks in a room when the server
misses some events, e.g. due to the server being offline for a time.
The server may only have a subset of all events in a room, in which case
it's possible for the server to have chunks that are unconnected from each
other. The ordering between unconnected chunks is arbitrary.
The class is designed for use inside transactions and so takes a
transaction object in the constructor. This means that it needs to be
re-instantiated in each transaction, so all state needs to be stored
in the database.
Internally the ordering is implemented using floats, and the average is
taken when a node is inserted between other nodes. To avoid precision
errors we try to keep a minimum difference between successive orderings;
whenever the difference is too small we attempt to rebalance. See
the `_rebalance` function for implementation details.
Note that OrderedListStore orders nodes such that the source of an edge
comes before the target. This is counterintuitive when edges represent
causality, so for the purposes of the ordering algorithm we invert the edge
directions, i.e. if chunk A has a prev chunk of B then we say that the
edge is from B to A. This ensures that newer chunks get inserted at the
end (rather than the start).
Note: Calls to `add_node` and `add_edge` cannot overlap for the same room,
and so callers should perform some form of per-room locking when using
this class.
Args:
txn
room_id (str)
clock
rebalance_digits (int): When a rebalance is triggered we rebalance
in a range around the node, where the bounds are rounded to this
number of digits.
min_difference (int): A rebalance is triggered when the difference
between two successive orderings is less than the reciprocal of
this.
"""
def __init__(self,
txn, room_id, clock,
rebalance_digits=3,
min_difference=1000000):
self.txn = txn
self.room_id = room_id
self.clock = clock
self.rebalance_digits = rebalance_digits
self.min_difference = 1. / min_difference
def is_before(self, a, b):
"""Implements OrderedListStore"""
return self._get_order(a) < self._get_order(b)
def get_prev(self, node_id):
"""Implements OrderedListStore"""
order = self._get_order(node_id)
sql = """
SELECT chunk_id FROM chunk_linearized
WHERE ordering < ? AND room_id = ?
ORDER BY ordering DESC
LIMIT 1
"""
self.txn.execute(sql, (order, self.room_id,))
row = self.txn.fetchone()
if row:
return row[0]
return None
def get_next(self, node_id):
"""Implements OrderedListStore"""
order = self._get_order(node_id)
sql = """
SELECT chunk_id FROM chunk_linearized
WHERE ordering > ? AND room_id = ?
ORDER BY ordering ASC
LIMIT 1
"""
self.txn.execute(sql, (order, self.room_id,))
row = self.txn.fetchone()
if row:
return row[0]
return None
def _insert_before(self, node_id, target_id):
"""Implements OrderedListStore"""
rebalance = False # Set to true if we need to trigger a rebalance
if target_id:
target_order = self._get_order(target_id)
before_id = self.get_prev(target_id)
if before_id:
before_order = self._get_order(before_id)
new_order = (target_order + before_order) / 2.
rebalance = math.fabs(target_order - before_order) < self.min_difference
else:
new_order = math.floor(target_order) - 1
else:
# If target_id is None then we insert at the end.
self.txn.execute("""
SELECT COALESCE(MAX(ordering), 0) + 1
FROM chunk_linearized
WHERE room_id = ?
""", (self.room_id,))
new_order, = self.txn.fetchone()
self._insert(node_id, new_order)
if rebalance:
self._rebalance(node_id)
def _insert_after(self, node_id, target_id):
"""Implements OrderedListStore"""
rebalance = False # Set to true if we need to trigger a rebalance
if target_id:
target_order = self._get_order(target_id)
after_id = self.get_next(target_id)
if after_id:
after_order = self._get_order(after_id)
new_order = (target_order + after_order) / 2.
rebalance = math.fabs(target_order - after_order) < self.min_difference
else:
new_order = math.ceil(target_order) + 1
else:
# If target_id is None then we insert at the start.
self.txn.execute("""
SELECT COALESCE(MIN(ordering), 0) - 1
FROM chunk_linearized
WHERE room_id = ?
""", (self.room_id,))
new_order, = self.txn.fetchone()
self._insert(node_id, new_order)
if rebalance:
self._rebalance(node_id)
def get_nodes_with_edges_to(self, node_id):
"""Implements OrderedListStore"""
# Note that we use the inverse relation here
sql = """
SELECT l.ordering, l.chunk_id FROM chunk_graph AS g
INNER JOIN chunk_linearized AS l ON g.prev_id = l.chunk_id
WHERE g.chunk_id = ?
"""
self.txn.execute(sql, (node_id,))
return self.txn.fetchall()
def get_nodes_with_edges_from(self, node_id):
"""Implements OrderedListStore"""
# Note that we use the inverse relation here
sql = """
SELECT l.ordering, l.chunk_id FROM chunk_graph AS g
INNER JOIN chunk_linearized AS l ON g.chunk_id = l.chunk_id
WHERE g.prev_id = ?
"""
self.txn.execute(sql, (node_id,))
return self.txn.fetchall()
def _delete_ordering(self, node_id):
"""Implements OrderedListStore"""
SQLBaseStore._simple_delete_txn(
self.txn,
table="chunk_linearized",
keyvalues={"chunk_id": node_id},
)
def _add_edge_to_graph(self, source_id, target_id):
"""Implements OrderedListStore"""
# Note that we use the inverse relation
SQLBaseStore._simple_insert_txn(
self.txn,
table="chunk_graph",
values={"chunk_id": target_id, "prev_id": source_id}
)
def _insert(self, node_id, order):
"""Inserts the node with the given ordering.
"""
SQLBaseStore._simple_insert_txn(
self.txn,
table="chunk_linearized",
values={
"chunk_id": node_id,
"room_id": self.room_id,
"ordering": order,
}
)
def _get_order(self, node_id):
"""Get the ordering of the given node.
"""
return SQLBaseStore._simple_select_one_onecol_txn(
self.txn,
table="chunk_linearized",
keyvalues={"chunk_id": node_id},
retcol="ordering"
)
def _rebalance(self, node_id):
"""Rebalances the list around the given node to ensure that the
ordering floats don't get too small.
This works by finding a range that includes the given node, and
recalculating the ordering floats such that they're equidistant in
that range.
"""
logger.info("Rebalancing room %s, chunk %s", self.room_id, node_id)
with Measure(self.clock, "chunk_rebalance"):
# We pick the interval to try and minimise the number of decimal
# places, i.e. we round to the nearest float with `rebalance_digits`
# digits and centre the interval on that value
order = self._get_order(node_id)
a = round(order, self.rebalance_digits)
min_order = a - 10 ** -self.rebalance_digits
max_order = a + 10 ** -self.rebalance_digits
# Now we get all the nodes in the range. We add the minimum difference
# to the bounds to ensure that we don't accidentally move a node to be
# within the minimum difference of a node outside the range.
sql = """
SELECT chunk_id FROM chunk_linearized
WHERE ordering >= ? AND ordering <= ? AND room_id = ?
"""
self.txn.execute(sql, (
min_order - self.min_difference,
max_order + self.min_difference,
self.room_id,
))
chunk_ids = [c for c, in self.txn]
sql = """
UPDATE chunk_linearized
SET ordering = ?
WHERE chunk_id = ?
"""
step = (max_order - min_order) / len(chunk_ids)
self.txn.executemany(
sql,
(
((idx * step + min_order), chunk_id)
for idx, chunk_id in enumerate(chunk_ids)
)
)
rebalance_counter.inc()
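Note on the file above: the float-midpoint ordering and the rebalance step can be tried without a database. The following is a rough in-memory sketch under stated assumptions. `FloatOrderedList` and its methods are hypothetical names, a dict stands in for the `chunk_linearized` table, and the nodes in the rebalance range are sorted explicitly before being re-spaced (an assumption about the intended behaviour, since the SELECT in `_rebalance` has no ORDER BY).

```python
import math


class FloatOrderedList(object):
    """In-memory stand-in for chunk_linearized (hypothetical)."""

    def __init__(self, rebalance_digits=3, min_difference=1000000):
        self.orders = {}  # node_id -> float ordering
        self.rebalance_digits = rebalance_digits
        self.min_difference = 1.0 / min_difference
        self.rebalances = 0

    def sorted_ids(self):
        return sorted(self.orders, key=self.orders.get)

    def append(self, node_id):
        # Mirrors COALESCE(MAX(ordering), 0) + 1.
        self.orders[node_id] = max(self.orders.values(), default=0) + 1

    def insert_after(self, node_id, target_id):
        target_order = self.orders[target_id]
        later = [o for o in self.orders.values() if o > target_order]
        rebalance = False
        if later:
            after_order = min(later)
            # Take the midpoint; flag a rebalance if the gap was too small.
            new_order = (target_order + after_order) / 2.0
            rebalance = abs(after_order - target_order) < self.min_difference
        else:
            new_order = math.ceil(target_order) + 1
        self.orders[node_id] = new_order
        if rebalance:
            self._rebalance(node_id)

    def _rebalance(self, node_id):
        # Centre a small interval on the rounded ordering of the node.
        centre = round(self.orders[node_id], self.rebalance_digits)
        width = 10 ** -self.rebalance_digits
        min_order, max_order = centre - width, centre + width
        low = min_order - self.min_difference
        high = max_order + self.min_difference
        # Sorted so that re-spacing keeps the existing relative order.
        ids = [n for n in self.sorted_ids() if low <= self.orders[n] <= high]
        step = (max_order - min_order) / len(ids)
        for idx, n in enumerate(ids):
            self.orders[n] = idx * step + min_order
        self.rebalances += 1
```

Repeatedly inserting after the same node halves the gap each time, so the rebalance path is reached after a few dozen inserts with the default thresholds.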
+4 -2
@@ -22,6 +22,8 @@ from . import background_updates
from synapse.util.caches import CACHE_SIZE_FACTOR
from six import iteritems
logger = logging.getLogger(__name__)
@@ -99,7 +101,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
def _update_client_ips_batch_txn(self, txn, to_update):
self.database_engine.lock_table(txn, "user_ips")
for entry in to_update.iteritems():
for entry in iteritems(to_update):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
self._simple_upsert_txn(
@@ -231,5 +233,5 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"user_agent": user_agent,
"last_seen": last_seen,
}
for (access_token, ip), (user_agent, last_seen) in results.iteritems()
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
))
+5 -4
@@ -21,6 +21,7 @@ from synapse.api.errors import StoreError
from ._base import SQLBaseStore, Cache
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
from six import itervalues, iteritems
logger = logging.getLogger(__name__)
@@ -360,7 +361,7 @@ class DeviceStore(SQLBaseStore):
return (now_stream_id, [])
if len(query_map) >= 20:
now_stream_id = max(stream_id for stream_id in query_map.itervalues())
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True
@@ -373,13 +374,13 @@ class DeviceStore(SQLBaseStore):
"""
results = []
for user_id, user_devices in devices.iteritems():
for user_id, user_devices in iteritems(devices):
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
for device_id, device in user_devices.iteritems():
for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
@@ -483,7 +484,7 @@ class DeviceStore(SQLBaseStore):
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.iteritems():
for device_id, device in iteritems(user_devices):
result = {
"device_id": device_id,
}
+4 -2
@@ -21,6 +21,8 @@ import simplejson as json
from ._base import SQLBaseStore
from six import iteritems
class EndToEndKeyStore(SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
@@ -81,8 +83,8 @@ class EndToEndKeyStore(SQLBaseStore):
query_list, include_all_devices,
)
for user_id, device_keys in results.iteritems():
for device_id, device_info in device_keys.iteritems():
for user_id, device_keys in iteritems(results):
for device_id, device_info in iteritems(device_keys):
device_info["keys"] = json.loads(device_info.pop("key_json"))
defer.returnValue(results)
+3 -1
@@ -22,6 +22,8 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks
import logging
import simplejson as json
from six import iteritems
logger = logging.getLogger(__name__)
@@ -420,7 +422,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.executemany(sql, (
_gen_entry(user_id, actions)
for user_id, actions in user_id_actions.iteritems()
for user_id, actions in iteritems(user_id_actions)
))
return self.runInteraction(
+211 -16
@@ -23,6 +23,7 @@ import simplejson as json
from twisted.internet import defer
from synapse.storage.events_worker import EventsWorkerStore
from synapse.storage.chunk_ordered_table import ChunkDBOrderedListStore
from synapse.util.async import ObservableDeferred
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import (
@@ -232,6 +233,15 @@ class EventsStore(EventsWorkerStore):
psql_only=True,
)
self.register_background_index_update(
"events_chunk_index",
index_name="events_chunk_index",
table="events",
columns=["room_id", "chunk_id", "topological_ordering", "stream_ordering"],
unique=True,
psql_only=True,
)
self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler()
@@ -1010,13 +1020,20 @@ class EventsStore(EventsWorkerStore):
}
)
sql = (
"UPDATE events SET outlier = ?"
" WHERE event_id = ?"
chunk_id, topo = self._insert_into_chunk_txn(
txn, event.room_id, event.event_id,
[eid for eid, _ in event.prev_events],
)
txn.execute(
sql,
(False, event.event_id,)
self._simple_update_txn(
txn,
table="events",
keyvalues={"event_id": event.event_id},
updatevalues={
"outlier": False,
"chunk_id": chunk_id,
"topological_ordering": topo,
},
)
# Update the event_backward_extremities table now that this
@@ -1099,13 +1116,22 @@ class EventsStore(EventsWorkerStore):
],
)
self._simple_insert_many_txn(
txn,
table="events",
values=[
{
for event, _ in events_and_contexts:
if event.internal_metadata.is_outlier():
chunk_id, topo = None, 0
else:
chunk_id, topo = self._insert_into_chunk_txn(
txn, event.room_id, event.event_id,
[eid for eid, _ in event.prev_events],
)
self._simple_insert_txn(
txn,
table="events",
values={
"stream_ordering": event.internal_metadata.stream_ordering,
"topological_ordering": event.depth,
"chunk_id": chunk_id,
"topological_ordering": topo,
"depth": event.depth,
"event_id": event.event_id,
"room_id": event.room_id,
@@ -1120,10 +1146,8 @@ class EventsStore(EventsWorkerStore):
"url" in event.content
and isinstance(event.content["url"], basestring)
),
}
for event, _ in events_and_contexts
],
)
},
)
def _store_rejected_events_txn(self, txn, events_and_contexts):
"""Add rows to the 'rejections' table for received events which were
@@ -1335,6 +1359,177 @@ class EventsStore(EventsWorkerStore):
(event.event_id, event.redacts)
)
def _insert_into_chunk_txn(self, txn, room_id, event_id, prev_event_ids):
"""Computes the chunk ID and topological ordering for an event and
handles updating the chunk_graph table.
Args:
txn
room_id (str)
event_id (str)
prev_event_ids (list[str])
Returns:
tuple[int, int]: Returns the chunk_id, topological_ordering for
the event
"""
# We calculate the chunk for an event using the following rules:
#
# 1. If all prev events have the same chunk ID then use that chunk ID
# 2. If we have none of the prev events but do have events pointing to
# the event, then we use their chunk ID if:
# - They're all in the same chunk, and
# - All their prev events match the events being inserted
# 3. Otherwise, create a new chunk and use that
# Set of chunks that the event refers to. Includes None if there were
# prev events that we don't have (or don't have a chunk for)
prev_chunk_ids = set()
for eid in prev_event_ids:
chunk_id = self._simple_select_one_onecol_txn(
txn,
table="events",
keyvalues={"event_id": eid},
retcol="chunk_id",
allow_none=True,
)
prev_chunk_ids.add(chunk_id)
forward_events = self._simple_select_onecol_txn(
txn,
table="event_edges",
keyvalues={
"prev_event_id": event_id,
"is_state": False,
},
retcol="event_id",
)
# Set of chunks that refer to this event.
forward_chunk_ids = set()
# All the prev_events of events in `forward_events`.
# Note that this will include the current event_id.
sibling_events = set()
for eid in forward_events:
chunk_id = self._simple_select_one_onecol_txn(
txn,
table="events",
keyvalues={"event_id": eid},
retcol="chunk_id",
allow_none=True,
)
if chunk_id is not None:
# chunk_id can be None if it's an outlier
forward_chunk_ids.add(chunk_id)
pes = self._simple_select_onecol_txn(
txn,
table="event_edges",
keyvalues={
"event_id": eid,
"is_state": False,
},
retcol="prev_event_id",
)
sibling_events.update(pes)
table = ChunkDBOrderedListStore(
txn, room_id, self.clock,
)
# If there is only one previous chunk (and that isn't None), then this
# satisfies condition one.
if len(prev_chunk_ids) == 1 and None not in prev_chunk_ids:
chunk_id = list(prev_chunk_ids)[0]
# This event is being inserted at the end of the chunk
new_topo = self._simple_select_one_onecol_txn(
txn,
table="events",
keyvalues={
"room_id": room_id,
"chunk_id": chunk_id,
},
retcol="MAX(topological_ordering)",
)
new_topo += 1
# If there is only one forward chunk and only one sibling event (which
# would be the given event), then this satisfies condition two.
elif len(forward_chunk_ids) == 1 and len(sibling_events) == 1:
chunk_id = list(forward_chunk_ids)[0]
# This event is being inserted at the start of the chunk
new_topo = self._simple_select_one_onecol_txn(
txn,
table="events",
keyvalues={
"room_id": room_id,
"chunk_id": chunk_id,
},
retcol="MIN(topological_ordering)",
)
new_topo -= 1
else:
chunk_id = self._chunk_id_gen.get_next()
new_topo = 0
# We've generated a new chunk, so we have to tell the
# ChunkDBOrderedListStore about that.
table.add_node(chunk_id)
# We need to now update the database with any new edges between chunks
current_prev_ids = self._simple_select_onecol_txn(
txn,
table="chunk_graph",
keyvalues={
"chunk_id": chunk_id,
},
retcol="prev_id",
)
current_forward_ids = self._simple_select_onecol_txn(
txn,
table="chunk_graph",
keyvalues={
"prev_id": chunk_id,
},
retcol="chunk_id",
)
for pid in prev_chunk_ids:
if pid is not None and pid not in current_prev_ids and pid != chunk_id:
# Note that the edge direction is the reverse of what you might
# expect. See ChunkDBOrderedListStore for more details.
table.add_edge(pid, chunk_id)
for fid in forward_chunk_ids:
# Note that the edge direction is the reverse of what you might
# expect. See ChunkDBOrderedListStore for more details.
if fid not in current_forward_ids and fid != chunk_id:
table.add_edge(chunk_id, fid)
# We now need to update the backwards extremities for the chunks.
txn.executemany("""
INSERT INTO chunk_backwards_extremities (chunk_id, event_id)
SELECT ?, ? WHERE ? NOT IN (SELECT event_id FROM events)
""", [(chunk_id, eid, eid) for eid in prev_event_ids])
self._simple_delete_txn(
txn,
table="chunk_backwards_extremities",
keyvalues={"event_id": event_id},
)
return chunk_id, new_topo
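The three chunk-selection rules described in the comments of `_insert_into_chunk_txn` can be sketched as a pure function with no database access. This is only an illustration: `choose_chunk` and its `next_chunk_id` parameter are hypothetical names that do not appear in the diff, and the string labels stand in for the `MAX(...) + 1`, `MIN(...) - 1` and `0` topological values computed above.

```python
def choose_chunk(prev_chunk_ids, forward_chunk_ids, sibling_events, next_chunk_id):
    """Pick a chunk for a new event (hypothetical helper, not part of the diff).

    Args:
        prev_chunk_ids (set): chunk IDs of the prev events; contains None if
            a prev event is missing or has no chunk.
        forward_chunk_ids (set): chunk IDs of events that point at this event.
        sibling_events (set): all prev events of those forward events.
        next_chunk_id (int): the ID to use if a new chunk is needed.

    Returns:
        tuple[int, str]: the chunk ID and where the event goes in it.
    """
    # Rule 1: every prev event is known and they all share one chunk, so the
    # event is appended to the end of that chunk.
    if len(prev_chunk_ids) == 1 and None not in prev_chunk_ids:
        return next(iter(prev_chunk_ids)), "end"

    # Rule 2: the events pointing at us share one chunk and reference only
    # this event, so the event is prepended to the start of that chunk.
    if len(forward_chunk_ids) == 1 and len(sibling_events) == 1:
        return next(iter(forward_chunk_ids)), "start"

    # Rule 3: anything else gets a brand new chunk.
    return next_chunk_id, "new"


print(choose_chunk({5}, set(), set(), 9))
print(choose_chunk({None}, {7}, {"$this_event"}, 9))
print(choose_chunk({1, 2}, set(), set(), 9))
```

Keeping the decision separate from the lookups makes it easy to see that rule 1 always wins when it applies, even if rule 2 would also match.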
@defer.inlineCallbacks
def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
+1 -1
@@ -337,7 +337,7 @@ class EventsWorkerStore(SQLBaseStore):
def _fetch_event_rows(self, txn, events):
rows = []
N = 200
for i in range(1 + len(events) / N):
for i in range(1 + len(events) // N):
evs = events[i * N:(i + 1) * N]
if not evs:
break
+1 -1
@@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter",
)
defer.returnValue(json.loads(str(def_json).decode("utf-8")))
defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter)
+1 -1
@@ -92,7 +92,7 @@ class KeyStore(SQLBaseStore):
if verify_key_bytes:
defer.returnValue(decode_verify_key_bytes(
key_id, str(verify_key_bytes)
key_id, bytes(verify_key_bytes)
))
@defer.inlineCallbacks
+6 -4
@@ -30,6 +30,8 @@ from synapse.types import get_domain_from_id
import logging
import simplejson as json
from six import itervalues, iteritems
logger = logging.getLogger(__name__)
@@ -272,7 +274,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
users_in_room = {}
member_event_ids = [
e_id
for key, e_id in current_state_ids.iteritems()
for key, e_id in iteritems(current_state_ids)
if key[0] == EventTypes.Member
]
@@ -289,7 +291,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
users_in_room = dict(prev_res)
member_event_ids = [
e_id
for key, e_id in context.delta_ids.iteritems()
for key, e_id in iteritems(context.delta_ids)
if key[0] == EventTypes.Member
]
for etype, state_key in context.delta_ids:
@@ -741,7 +743,7 @@ class _JoinedHostsCache(object):
if state_entry.state_group == self.state_group:
pass
elif state_entry.prev_group == self.state_group:
for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
if typ != EventTypes.Member:
continue
@@ -771,7 +773,7 @@ class _JoinedHostsCache(object):
self.state_group = state_entry.state_group
else:
self.state_group = object()
self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
defer.returnValue(frozenset(self.hosts_to_joined_users))
def __len__(self):
@@ -0,0 +1,49 @@
/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE events ADD COLUMN chunk_id BIGINT;
INSERT INTO background_updates (update_name, progress_json) VALUES
('events_chunk_index', '{}');
-- Stores how chunks of graph relate to each other
CREATE TABLE chunk_graph (
chunk_id BIGINT NOT NULL,
prev_id BIGINT NOT NULL
);
CREATE UNIQUE INDEX chunk_graph_id ON chunk_graph (chunk_id, prev_id);
CREATE INDEX chunk_graph_prev_id ON chunk_graph (prev_id);
-- The extremities in each chunk. Note that these are pointing to events that
-- we don't have, rather than boundary between chunks.
CREATE TABLE chunk_backwards_extremities (
chunk_id BIGINT NOT NULL,
event_id TEXT NOT NULL
);
CREATE INDEX chunk_backwards_extremities_id ON chunk_backwards_extremities(chunk_id, event_id);
CREATE INDEX chunk_backwards_extremities_event_id ON chunk_backwards_extremities(event_id);
-- Maintains an absolute ordering of chunks. Gets updated when we see new
-- edges between chunks.
CREATE TABLE chunk_linearized (
chunk_id BIGINT NOT NULL,
room_id TEXT NOT NULL,
ordering DOUBLE PRECISION NOT NULL
);
CREATE UNIQUE INDEX chunk_linearized_id ON chunk_linearized (chunk_id);
CREATE INDEX chunk_linearized_ordering ON chunk_linearized (room_id, ordering);
@@ -14,7 +14,12 @@
*/
CREATE TABLE IF NOT EXISTS events(
-- Defines an ordering used to stream new events to clients. Events
-- fetched via backfill have negative values.
stream_ordering INTEGER PRIMARY KEY,
-- Defines a topological ordering of events within a chunk
-- (The concept of a chunk was added in later schemas, this used to
-- be set to the same value as the `depth` field in an event)
topological_ordering BIGINT NOT NULL,
event_id TEXT NOT NULL,
type TEXT NOT NULL,
+192 -54
@@ -41,6 +41,7 @@ from synapse.storage.events import EventsWorkerStore
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.storage.chunk_ordered_table import ChunkDBOrderedListStore
from synapse.storage.engines import PostgresEngine
import abc
@@ -62,24 +63,25 @@ _TOPOLOGICAL_TOKEN = "topological"
# Used as return values for pagination APIs
_EventDictReturn = namedtuple("_EventDictReturn", (
"event_id", "topological_ordering", "stream_ordering",
"event_id", "chunk_id", "topological_ordering", "stream_ordering",
))
def lower_bound(token, engine, inclusive=False):
inclusive = "=" if inclusive else ""
if token.topological is None:
if token.chunk is None:
return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
else:
if isinstance(engine, PostgresEngine):
# Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
# use the latter form when running against postgres.
return "((%d,%d) <%s (%s,%s))" % (
token.topological, token.stream, inclusive,
return "(chunk_id = %d AND (%d,%d) <%s (%s,%s))" % (
token.chunk, token.topological, token.stream, inclusive,
"topological_ordering", "stream_ordering",
)
return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
return "(chunk_id = %d AND (%d < %s OR (%d = %s AND %d <%s %s)))" % (
token.chunk,
token.topological, "topological_ordering",
token.topological, "topological_ordering",
token.stream, inclusive, "stream_ordering",
@@ -88,18 +90,19 @@ def lower_bound(token, engine, inclusive=False):
def upper_bound(token, engine, inclusive=True):
inclusive = "=" if inclusive else ""
if token.topological is None:
if token.chunk is None:
return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
else:
if isinstance(engine, PostgresEngine):
# Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
# as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
# use the latter form when running against postgres.
return "((%d,%d) >%s (%s,%s))" % (
token.topological, token.stream, inclusive,
return "(chunk_id = %d AND (%d,%d) >%s (%s,%s))" % (
token.chunk, token.topological, token.stream, inclusive,
"topological_ordering", "stream_ordering",
)
return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
return "(chunk_id = %d AND (%d > %s OR (%d = %s AND %d >%s %s)))" % (
token.chunk,
token.topological, "topological_ordering",
token.topological, "topological_ordering",
token.stream, inclusive, "stream_ordering",
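The two SQL forms produced by `lower_bound` and `upper_bound` are meant to be equivalent: a row-value comparison on `(topological_ordering, stream_ordering)` versus the expanded `OR` form. Python tuples compare lexicographically in the same way (for non-NULL values), so the equivalence can be checked with a small sketch. The helper name `expanded_less_than` is an assumption made for this example only.

```python
from itertools import product


def expanded_less_than(a, b, x, y):
    """The expanded form used for non-Postgres engines: a < x OR (a = x AND b < y)."""
    return a < x or (a == x and b < y)


# Compare the expanded form against tuple comparison over a small grid of
# values, including negatives (both orderings can be negative).
values = range(-2, 3)
mismatches = [
    (a, b, x, y)
    for a, b, x, y in product(values, repeat=4)
    if ((a, b) < (x, y)) != expanded_less_than(a, b, x, y)
]
print("mismatches:", len(mismatches))
```

Note that the new `chunk_id = %d AND ...` prefix restricts either form to a single chunk, so the comparison only ever orders events within that chunk.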
@@ -275,7 +278,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
) % (order,)
txn.execute(sql, (room_id, from_id, to_id, limit))
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
rows = [_EventDictReturn(row[0], None, None, row[1]) for row in txn]
return rows
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
@@ -325,7 +328,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
txn.execute(sql, (user_id, from_id, to_id,))
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
rows = [_EventDictReturn(row[0], None, None, row[1]) for row in txn]
return rows
@@ -392,7 +395,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token = RoomStreamToken.parse(end_token)
rows, token = yield self.runInteraction(
rows, token, _ = yield self.runInteraction(
"get_recent_event_ids_for_room", self._paginate_room_events_txn,
room_id, from_token=end_token, limit=limit,
)
@@ -437,15 +440,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
`room_id` causes it to return the current room specific topological
token.
"""
token = yield self.get_room_max_stream_ordering()
if room_id is None:
defer.returnValue("s%d" % (token,))
token = yield self.get_room_max_stream_ordering()
defer.returnValue(str(RoomStreamToken(None, None, token)))
else:
topo = yield self.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn,
token = yield self.runInteraction(
"get_room_events_max_id", self._get_topological_token_for_room_txn,
room_id,
)
defer.returnValue("t%d-%d" % (topo, token))
if not token:
raise Exception("Server not in room")
defer.returnValue(str(token))
def get_stream_token_for_event(self, event_id):
"""The stream token for an event
@@ -460,7 +465,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
table="events",
keyvalues={"event_id": event_id},
retcol="stream_ordering",
).addCallback(lambda row: "s%d" % (row,))
).addCallback(lambda row: str(RoomStreamToken(None, None, row)))
def get_topological_token_for_event(self, event_id):
"""The stream token for an event
@@ -469,16 +474,34 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Raises:
StoreError if the event wasn't in the database.
Returns:
A deferred "t%d-%d" topological token.
A deferred topological token.
"""
return self._simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
retcols=("stream_ordering", "topological_ordering", "chunk_id"),
desc="get_topological_token_for_event",
).addCallback(lambda row: "t%d-%d" % (
row["topological_ordering"], row["stream_ordering"],)
)
).addCallback(lambda row: str(RoomStreamToken(
row["chunk_id"],
row["topological_ordering"],
row["stream_ordering"],
)))
def _get_topological_token_for_room_txn(self, txn, room_id):
sql = """
SELECT chunk_id, topological_ordering, stream_ordering
FROM events
NATURAL JOIN event_forward_extremities
WHERE room_id = ?
ORDER BY stream_ordering DESC
LIMIT 1
"""
txn.execute(sql, (room_id,))
row = txn.fetchone()
if row:
c, t, s = row
return RoomStreamToken(c, t, s)
return None
def get_max_topological_token(self, room_id, stream_key):
sql = (
@@ -515,18 +538,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
null topological_ordering.
"""
for event, row in zip(events, rows):
chunk = row.chunk_id
topo = row.topological_ordering
stream = row.stream_ordering
if topo_order and row.topological_ordering:
topo = row.topological_ordering
else:
topo = None
internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream))
internal.order = (
int(topo) if topo else 0,
int(stream),
)
internal.stream_ordering = stream
if topo_order:
internal.before = str(RoomStreamToken(chunk, topo, stream - 1))
internal.after = str(RoomStreamToken(chunk, topo, stream))
else:
internal.before = str(RoomStreamToken(None, None, stream - 1))
internal.after = str(RoomStreamToken(None, None, stream))
@defer.inlineCallbacks
def get_events_around(self, room_id, event_id, before_limit, after_limit):
@@ -586,27 +611,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"event_id": event_id,
"room_id": room_id,
},
retcols=["stream_ordering", "topological_ordering"],
retcols=["stream_ordering", "topological_ordering", "chunk_id"],
)
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
results["topological_ordering"] - 1,
results["stream_ordering"],
results["chunk_id"],
results["topological_ordering"],
results["stream_ordering"] - 1,
)
after_token = RoomStreamToken(
results["chunk_id"],
results["topological_ordering"],
results["stream_ordering"],
)
rows, start_token = self._paginate_room_events_txn(
rows, start_token, _ = self._paginate_room_events_txn(
txn, room_id, before_token, direction='b', limit=before_limit,
)
events_before = [r.event_id for r in rows]
rows, end_token = self._paginate_room_events_txn(
rows, end_token, _ = self._paginate_room_events_txn(
txn, room_id, after_token, direction='f', limit=after_limit,
)
events_after = [r.event_id for r in rows]
@@ -689,12 +716,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
those that match the filter.
Returns:
Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
as a list of _EventDictReturn and a token that points to the end
of the result set.
Deferred[tuple[list[_EventDictReturn], str, list[int]]]: Returns
the results as a list of _EventDictReturn, a token that points to
the end of the result set, and a list of chunks iterated over.
"""
assert int(limit) >= 0
limit = int(limit) # Sometimes we are passed a string from somewhere
assert limit >= 0
# There are two modes of fetching events: by stream order or by
# topological order. This is determined by whether the from_token is a
# stream or topological token. If stream then we can simply do a select
# ordered by stream_ordering column. If topological, then we need to
# fetch events from one chunk at a time until we hit the limit.
# For backwards compatibility we need to check if the token has a
# topological part but no chunk part. If that's the case we can use the
# stream part to generate an appropriate topological token.
if from_token.chunk is None and from_token.topological is not None:
res = self._simple_select_one_txn(
txn,
table="events",
keyvalues={
"stream_ordering": from_token.stream,
},
retcols=(
"chunk_id",
"topological_ordering",
"stream_ordering",
),
allow_none=True,
)
if res and res["chunk_id"] is not None:
from_token = RoomStreamToken(
res["chunk_id"],
res["topological_ordering"],
res["stream_ordering"],
)
# Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence
@@ -725,10 +783,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
bounds += " AND " + filter_clause
args.extend(filter_args)
args.append(int(limit))
args.append(limit)
sql = (
"SELECT event_id, topological_ordering, stream_ordering"
"SELECT event_id, chunk_id, topological_ordering, stream_ordering"
" FROM events"
" WHERE outlier = ? AND room_id = ? AND %(bounds)s"
" ORDER BY topological_ordering %(order)s,"
@@ -740,9 +798,64 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(sql, args)
rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
rows = [_EventDictReturn(*row) for row in txn]
# If we are paginating topologically and we haven't hit the limit on
# number of events then we need to fetch events from the previous or
# next chunk.
iterated_chunks = []
chunk_id = None
if rows:
chunk_id = rows[-1].chunk_id
iterated_chunks = [r.chunk_id for r in rows]
elif from_token.chunk:
chunk_id = from_token.chunk
iterated_chunks = [chunk_id]
table = ChunkDBOrderedListStore(
txn, room_id, self.clock,
)
if filter_clause:
filter_clause = "AND " + filter_clause
sql = (
"SELECT event_id, chunk_id, topological_ordering, stream_ordering"
" FROM events"
" WHERE outlier = ? AND room_id = ? %(filter_clause)s"
" ORDER BY topological_ordering %(order)s,"
" stream_ordering %(order)s LIMIT ?"
) % {
"filter_clause": filter_clause,
"order": order,
}
args = [False, room_id] + filter_args + [limit]
while chunk_id and (limit <= 0 or len(rows) < limit):
if chunk_id not in iterated_chunks:
iterated_chunks.append(chunk_id)
if direction == 'b':
chunk_id = table.get_prev(chunk_id)
else:
chunk_id = table.get_next(chunk_id)
if chunk_id is None:
break
txn.execute(sql, args)
new_rows = [_EventDictReturn(*row) for row in txn]
rows.extend(new_rows)
# We may have inserted more rows than necessary in the loop above
rows = rows[:limit]
if rows:
chunk = rows[-1].chunk_id
topo = rows[-1].topological_ordering
toke = rows[-1].stream_ordering
if direction == 'b':
@@ -752,12 +865,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# when we are going backwards so we subtract one from the
# stream part.
toke -= 1
next_token = RoomStreamToken(topo, toke)
next_token = RoomStreamToken(chunk, topo, toke)
else:
# TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token
return rows, str(next_token),
return rows, str(next_token), iterated_chunks,
@defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None,
@@ -777,18 +890,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
those that match the filter.
Returns:
tuple[list[dict], str]: Returns the results as a list of dicts and
a token that points to the end of the result set. The dicts have
the keys "event_id", "topological_ordering" and "stream_orderign".
tuple[list[dict], str, list[str]]: Returns the results as a list of
dicts, a token that points to the end of the result set, and a list
of backwards extremities. The dicts have the keys "event_id",
"topological_ordering" and "stream_ordering".
"""
from_key = RoomStreamToken.parse(from_key)
if to_key:
to_key = RoomStreamToken.parse(to_key)
rows, token = yield self.runInteraction(
"paginate_room_events", self._paginate_room_events_txn,
room_id, from_key, to_key, direction, limit, event_filter,
def _do_paginate_room_events(txn):
rows, token, chunks = self._paginate_room_events_txn(
txn, room_id, from_key, to_key, direction, limit, event_filter,
)
# We now fetch the extremities by fetching the extremities for
# each chunk we iterated over.
extremities = []
seen = set()
for chunk_id in chunks:
if chunk_id in seen:
continue
seen.add(chunk_id)
event_ids = self._simple_select_onecol_txn(
txn,
table="chunk_backwards_extremities",
keyvalues={"chunk_id": chunk_id},
retcol="event_id"
)
extremities.extend(e for e in event_ids if e not in extremities)
return rows, token, extremities
rows, token, extremities = yield self.runInteraction(
"paginate_room_events", _do_paginate_room_events,
)
events = yield self._get_events(
@@ -798,7 +936,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._set_before_and_after(events, rows)
defer.returnValue((events, token))
defer.returnValue((events, token, extremities))
class StreamStore(StreamWorkerStore):
+29 -12
@@ -306,7 +306,7 @@ StreamToken.START = StreamToken(
)
class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
class RoomStreamToken(namedtuple("_StreamToken", ("chunk", "topological", "stream"))):
"""Tokens are positions between events. The token "s1" comes after event 1.
s0 s1
@@ -319,14 +319,18 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
When traversing the live event stream events are ordered by when they
arrived at the homeserver.
When traversing historic events the events are ordered by their depth in
the event graph "topological_ordering" and then by when they arrived at the
homeserver "stream_ordering".
When traversing historic events the events are ordered by the topological
ordering of the room graph. This is done using event chunks and the
`topological_ordering` column.
Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after.
Live tokens start with an 's' and include the stream_ordering of the event
it comes after. Historic tokens start with a 'c' and include the chunk ID,
topological ordering and stream ordering of the event it comes after.
(In previous versions, when chunks were not implemented, the historic tokens
started with 't' and included the topological and stream ordering. These
tokens can be roughly converted to the new format by looking up the chunk
and topological ordering of the event with the same stream ordering).
"""
__slots__ = []
@@ -334,10 +338,19 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
def parse(cls, string):
try:
if string[0] == 's':
return cls(topological=None, stream=int(string[1:]))
if string[0] == 't':
return cls(chunk=None, topological=None, stream=int(string[1:]))
if string[0] == 't': # For backwards compat with older tokens.
parts = string[1:].split('-', 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
return cls(chunk=None, topological=int(parts[0]), stream=int(parts[1]))
if string[0] == 'c':
# We use '~' as both stream ordering and topological ordering
# can be negative, so we can't use '-'
parts = string[1:].split('~', 2)
return cls(
chunk=int(parts[0]),
topological=int(parts[1]),
stream=int(parts[2]),
)
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@@ -346,12 +359,16 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
def parse_stream_token(cls, string):
try:
if string[0] == 's':
return cls(topological=None, stream=int(string[1:]))
return cls(chunk=None, topological=None, stream=int(string[1:]))
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
def __str__(self):
if self.chunk is not None:
# We use '~' as both stream ordering and topological ordering
# can be negative, so we can't use '-'
return "c%d~%d~%d" % (self.chunk, self.topological, self.stream)
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
else:
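To see why the new token format uses `~` as a separator, here is a standalone sketch of the three token shapes handled by `RoomStreamToken.parse` and `__str__`. It does not import Synapse: `Token`, `parse_token` and `format_token` are stand-in names for this example, and it raises `ValueError` where the real code raises `SynapseError`.

```python
from collections import namedtuple

# Stand-in for RoomStreamToken, with the same three fields as in the diff.
Token = namedtuple("Token", ("chunk", "topological", "stream"))


def parse_token(string):
    """Parse "s<stream>", legacy "t<topo>-<stream>" or "c<chunk>~<topo>~<stream>"."""
    if string.startswith("s"):
        return Token(None, None, int(string[1:]))
    if string.startswith("t"):
        topo, stream = string[1:].split("-", 1)
        return Token(None, int(topo), int(stream))
    if string.startswith("c"):
        chunk, topo, stream = string[1:].split("~", 2)
        return Token(int(chunk), int(topo), int(stream))
    raise ValueError("Invalid token %r" % (string,))


def format_token(token):
    """Inverse of parse_token, mirroring the order of checks in __str__."""
    if token.chunk is not None:
        return "c%d~%d~%d" % (token.chunk, token.topological, token.stream)
    if token.topological is not None:
        return "t%d-%d" % (token.topological, token.stream)
    return "s%d" % (token.stream,)


# Negative orderings survive a round trip through the '~' format.
token = Token(chunk=3, topological=-2, stream=-17)
assert parse_token(format_token(token)) == token

# The legacy '-' format cannot represent a negative topological ordering:
# "t-5-3" splits into "" and "5-3", which fails to parse.
try:
    parse_token("t-5-3")
except ValueError:
    print("legacy format rejects a negative topological part")
```

Since prepending an event to a chunk assigns `MIN(topological_ordering) - 1`, negative topological values are expected under the chunk scheme, which is what rules out `-` as a separator.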
+18
@@ -20,6 +20,8 @@ from twisted.internet import defer, reactor, task
import time
import logging
from itertools import islice
logger = logging.getLogger(__name__)
@@ -79,3 +81,19 @@ class Clock(object):
except Exception:
if not ignore_errs:
raise
def batch_iter(iterable, size):
"""batch an iterable up into tuples with a maximum size
Args:
iterable (iterable): the iterable to slice
size (int): the maximum batch size
Returns:
an iterator over the chunks
"""
# make sure we can deal with iterables like lists too
sourceiter = iter(iterable)
# call islice until it returns an empty tuple
return iter(lambda: tuple(islice(sourceiter, size)), ())
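A short usage sketch for `batch_iter`, with the function copied from the diff above so the example runs on its own. The two-argument form of `iter` keeps calling the lambda until it returns the sentinel `()`, which is what `islice` yields once the source iterator is exhausted.

```python
from itertools import islice


def batch_iter(iterable, size):
    """Batch an iterable up into tuples with a maximum size (as in the diff)."""
    sourceiter = iter(iterable)
    return iter(lambda: tuple(islice(sourceiter, size)), ())


# Seven items in batches of three: the final batch is shorter.
batches = list(batch_iter(range(7), 3))
print(batches)  # [(0, 1, 2), (3, 4, 5), (6,)]

# An empty input produces no batches at all.
print(list(batch_iter([], 3)))  # []
```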
+6 -1
@@ -16,6 +16,9 @@
import synapse.metrics
import os
from six.moves import intern
import six
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
@@ -66,7 +69,9 @@ def intern_string(string):
return None
try:
string = string.encode("ascii")
if six.PY2:
string = string.encode("ascii")
return intern(string)
except UnicodeEncodeError:
return string
+337
@@ -0,0 +1,337 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains an implementation of the Katriel-Bodlaender algorithm,
which is used to do online topological ordering of graphs.
Note that the ordering derived from the graph is such that the source node of
an edge comes before the target node of the edge, i.e. a graph of A -> B -> C
would produce the ordering [A, B, C].
This ordering is therefore opposite to what one might expect when considering
the room DAG, as newer messages would be added to the start rather than the
end.
***The ChunkDBOrderedListStore therefore inverts the direction of edges***
See:
A tight analysis of the Katriel-Bodlaender algorithm for online topological
ordering
Hsiao-Fei Liu and Kun-Mao Chao
https://www.sciencedirect.com/science/article/pii/S0304397507006573
and:
Online Topological Ordering
Irit Katriel and Hans L. Bodlaender
http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.78.7933
"""
from abc import ABCMeta, abstractmethod
class OrderedListStore(object):
"""An abstract base class that is used to store a graph and maintain a
topologically consistent total ordering.
Internally this uses the Katriel-Bodlaender algorithm, which requires the
store expose an interface for the total ordering that supports:
- Insertion of the node into the ordering either immediately before or
after another node.
- Deletion of the node from the ordering
- Comparing the relative ordering of two arbitrary nodes
- Get the node immediately before or after a given node in the ordering
It also needs to be able to interact with the graph in the following ways:
- Query the number of edges from a node in the graph
- Query the number of edges into a node in the graph
- Add an edge to the graph
Users of subclasses should call `add_node` and `add_edge` whenever editing
the graph. The total ordering exposed will remain constant until the next
call to one of these methods.
Note: Calls to `add_node` and `add_edge` cannot overlap, and so callers
should perform some form of locking.
"""
__metaclass__ = ABCMeta
def add_node(self, node_id):
"""Adds a node to the graph.
Args:
node_id (str)
"""
self._insert_before(node_id, None)
def add_edge(self, source, target):
"""Adds a new edge to the graph and updates the ordering.
See module level docs.
Note that both the source and target nodes must have been inserted into
the store (at an arbitrary position) already.
Args:
source (str): The source node of the new edge
target (str): The target node of the new edge
"""
# The following is the Katriel-Bodlaender algorithm.
to_s = []
from_t = []
to_s_neighbours = []
from_t_neighbours = []
to_s_indegree = 0
from_t_outdegree = 0
s = source
t = target
while s and t and not self.is_before(s, t):
m_s = to_s_indegree
m_t = from_t_outdegree
# These functions return a tuple where the first term is a float
# that can be used to order the list of neighbours.
# These are valid until the next write
pe_s = self.get_nodes_with_edges_to(s)
fe_t = self.get_nodes_with_edges_from(t)
l_s = len(pe_s)
l_t = len(fe_t)
if m_s + l_s <= m_t + l_t:
to_s.append(s)
to_s_neighbours.extend(pe_s)
to_s_indegree += l_s
if to_s_neighbours:
to_s_neighbours.sort()
_, s = to_s_neighbours.pop()
else:
s = None
if m_s + l_s >= m_t + l_t:
from_t.append(t)
from_t_neighbours.extend(fe_t)
from_t_outdegree += l_t
if from_t_neighbours:
from_t_neighbours.sort(reverse=True)
_, t = from_t_neighbours.pop()
else:
t = None
if s is None:
s = self.get_prev(target)
if t is None:
t = self.get_next(source)
while to_s:
s1 = to_s.pop()
self._delete_ordering(s1)
self._insert_after(s1, s)
s = s1
while from_t:
t1 = from_t.pop()
self._delete_ordering(t1)
self._insert_before(t1, t)
t = t1
self._add_edge_to_graph(source, target)
@abstractmethod
def is_before(self, first_node, second_node):
"""Returns whether the first node is before the second node.
Args:
first_node (str)
second_node (str)
Returns:
bool: True if first_node is before second_node
"""
pass
@abstractmethod
def get_prev(self, node_id):
"""Gets the node immediately before the given node in the topological
ordering.
Args:
node_id (str)
Returns:
str|None: A node ID or None if no preceding node exists
"""
pass
@abstractmethod
def get_next(self, node_id):
"""Gets the node immediately after the given node in the topological
ordering.
Args:
node_id (str)
Returns:
str|None: A node ID or None if no succeeding node exists
"""
pass
@abstractmethod
def get_nodes_with_edges_to(self, node_id):
"""Get all nodes with edges to the given node
Args:
node_id (str)
Returns:
list[tuple[float, str]]: Returns a list of tuples, each holding an
ordering term and the node ID. The ordering term can be used to sort
the returned list.
The ordering is valid until the next call to `add_edge`.
"""
pass
@abstractmethod
def get_nodes_with_edges_from(self, node_id):
"""Get all nodes with edges from the given node
Args:
node_id (str)
Returns:
list[tuple[float, str]]: Returns a list of tuples, each holding an
ordering term and the node ID. The ordering term can be used to sort
the returned list.
The ordering is valid until the next call to `add_edge`.
"""
pass

    @abstractmethod
    def _insert_before(self, node_id, target_id):
        """Inserts node immediately before target node.

        If target_id is None then the node is inserted at the end of the list

        Args:
            node_id (str)
            target_id (str|None)
        """
        pass

    @abstractmethod
    def _insert_after(self, node_id, target_id):
        """Inserts node immediately after target node.

        If target_id is None then the node is inserted at the start of the list

        Args:
            node_id (str)
            target_id (str|None)
        """
        pass

    @abstractmethod
    def _delete_ordering(self, node_id):
        """Deletes the given node from the ordered list (but not the graph).

        Used when we want to reinsert it into a different position

        Args:
            node_id (str)
        """
        pass

    @abstractmethod
    def _add_edge_to_graph(self, source_id, target_id):
        """Adds an edge to the graph from source to target.

        Does not update ordering.

        Args:
            source_id (str)
            target_id (str)
        """
        pass


class InMemoryOrderedListStore(OrderedListStore):
    """An in memory OrderedListStore
    """

    def __init__(self):
        # The ordered list of nodes
        self.list = []

        # Map from node to set of nodes that it references
        self.edges_from = {}

        # Map from node to set of nodes that it is referenced by
        self.edges_to = {}

    def is_before(self, first_node, second_node):
        return self.list.index(first_node) < self.list.index(second_node)

    def get_prev(self, node_id):
        idx = self.list.index(node_id) - 1
        if idx >= 0:
            return self.list[idx]
        else:
            return None

    def get_next(self, node_id):
        idx = self.list.index(node_id) + 1
        if idx < len(self.list):
            return self.list[idx]
        else:
            return None

    def _insert_before(self, node_id, target_id):
        if target_id is not None:
            idx = self.list.index(target_id)
            self.list.insert(idx, node_id)
        else:
            self.list.append(node_id)

    def _insert_after(self, node_id, target_id):
        if target_id is not None:
            idx = self.list.index(target_id) + 1
            self.list.insert(idx, node_id)
        else:
            self.list.insert(0, node_id)

    def _delete_ordering(self, node_id):
        self.list.remove(node_id)

    def get_nodes_with_edges_to(self, node_id):
        to_nodes = self.edges_to.get(node_id, [])
        return [(self.list.index(nid), nid) for nid in to_nodes]

    def get_nodes_with_edges_from(self, node_id):
        from_nodes = self.edges_from.get(node_id, [])
        return [(self.list.index(nid), nid) for nid in from_nodes]

    def _add_edge_to_graph(self, source_id, target_id):
        self.edges_from.setdefault(source_id, set()).add(target_id)
        self.edges_to.setdefault(target_id, set()).add(source_id)
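
The `None` handling in `_insert_before` and `_insert_after` is easy to misread, so here is a minimal stand-alone sketch of the same list semantics. The free functions `insert_before` and `insert_after` are hypothetical helpers written for illustration; they mirror the in-memory implementation above but are not part of the change itself.

```python
# Hypothetical helpers mirroring InMemoryOrderedListStore's list handling.
# They operate on a plain Python list instead of a store object.


def insert_before(order, node_id, target_id):
    """Insert node_id immediately before target_id.

    A target_id of None means "no target": the node goes at the end.
    """
    if target_id is not None:
        order.insert(order.index(target_id), node_id)
    else:
        order.append(node_id)


def insert_after(order, node_id, target_id):
    """Insert node_id immediately after target_id.

    A target_id of None means "no target": the node goes at the start.
    """
    if target_id is not None:
        order.insert(order.index(target_id) + 1, node_id)
    else:
        order.insert(0, node_id)


order = ["A"]
insert_after(order, "B", "A")    # order is now ["A", "B"]
insert_before(order, "C", "A")   # order is now ["C", "A", "B"]
insert_before(order, "D", None)  # no target: appended at the end
insert_after(order, "E", None)   # no target: inserted at the start
```

Both helpers use `list.index`, so each call is linear in the length of the list. That seems acceptable for an in-memory reference implementation used in tests, but it is an observation about this sketch rather than a statement about the database-backed store.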
@@ -0,0 +1,185 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer

import random
import tests.unittest
import tests.utils

from synapse.storage.chunk_ordered_table import ChunkDBOrderedListStore


class ChunkLinearizerStoreTestCase(tests.unittest.TestCase):
    """Tests to ensure that the ordering and rebalancing functions of
    ChunkDBOrderedListStore work as expected.
    """

    def __init__(self, *args, **kwargs):
        super(ChunkLinearizerStoreTestCase, self).__init__(*args, **kwargs)

    @defer.inlineCallbacks
    def setUp(self):
        hs = yield tests.utils.setup_test_homeserver()
        self.store = hs.get_datastore()
        self.clock = hs.get_clock()

    @defer.inlineCallbacks
    def test_simple_insert_fetch(self):
        room_id = "foo_room1"

        def test_txn(txn):
            table = ChunkDBOrderedListStore(
                txn, room_id, self.clock, 1, 100,
            )

            table.add_node("A")
            table._insert_after("B", "A")
            table._insert_before("C", "A")

            sql = """
                SELECT chunk_id FROM chunk_linearized
                WHERE room_id = ?
                ORDER BY ordering ASC
            """
            txn.execute(sql, (room_id,))

            ordered = [r for r, in txn]

            self.assertEqual(["C", "A", "B"], ordered)

        yield self.store.runInteraction("test", test_txn)

    @defer.inlineCallbacks
    def test_many_insert_fetch(self):
        room_id = "foo_room2"

        def test_txn(txn):
            table = ChunkDBOrderedListStore(
                txn, room_id, self.clock, 1, 20,
            )

            nodes = [(i, "node_%d" % (i,)) for i in xrange(1, 1000)]
            expected = [n for _, n in nodes]

            already_inserted = []

            random.shuffle(nodes)
            while nodes:
                i, node_id = nodes.pop()
                if not already_inserted:
                    table.add_node(node_id)
                else:
                    # Find the first already-inserted node that sorts after
                    # the new one (or the last node if there is none).
                    for j, target_id in already_inserted:
                        if j > i:
                            break

                    if j < i:
                        table._insert_after(node_id, target_id)
                    else:
                        table._insert_before(node_id, target_id)

                already_inserted.append((i, node_id))
                already_inserted.sort()

            sql = """
                SELECT chunk_id FROM chunk_linearized
                WHERE room_id = ?
                ORDER BY ordering ASC
            """
            txn.execute(sql, (room_id,))

            ordered = [r for r, in txn]

            self.assertEqual(expected, ordered)

        yield self.store.runInteraction("test", test_txn)

    @defer.inlineCallbacks
    def test_prepend_and_append(self):
        room_id = "foo_room3"

        def test_txn(txn):
            table = ChunkDBOrderedListStore(
                txn, room_id, self.clock, 1, 20,
            )

            table.add_node("a")

            expected = ["a"]

            for i in xrange(1, 1000):
                node_id = "node_id_before_%d" % i
                table._insert_before(node_id, expected[0])
                expected.insert(0, node_id)

            for i in xrange(1, 1000):
                node_id = "node_id_after_%d" % i
                table._insert_after(node_id, expected[-1])
                expected.append(node_id)

            sql = """
                SELECT chunk_id FROM chunk_linearized
                WHERE room_id = ?
                ORDER BY ordering ASC
            """
            txn.execute(sql, (room_id,))

            ordered = [r for r, in txn]

            self.assertEqual(expected, ordered)

        yield self.store.runInteraction("test", test_txn)

    @defer.inlineCallbacks
    def test_worst_case(self):
        room_id = "foo_room3"

        def test_txn(txn):
            table = ChunkDBOrderedListStore(
                txn, room_id, self.clock, 1, 100,
            )

            table.add_node("a")

            prev_node = "a"

            expected_prefix = ["a"]
            expected_suffix = []

            for i in xrange(1, 100):
                node_id = "node_id_%d" % i
                if i % 2 == 0:
                    table._insert_before(node_id, prev_node)
                    expected_prefix.append(node_id)
                else:
                    table._insert_after(node_id, prev_node)
                    expected_suffix.append(node_id)
                prev_node = node_id

            sql = """
                SELECT chunk_id FROM chunk_linearized
                WHERE room_id = ?
                ORDER BY ordering ASC
            """
            txn.execute(sql, (room_id,))

            ordered = [r for r, in txn]

            expected = expected_prefix + list(reversed(expected_suffix))

            self.assertEqual(expected, ordered)

        yield self.store.runInteraction("test", test_txn)
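
It may not be obvious why `test_worst_case` expects `expected_prefix + reversed(expected_suffix)`. The sketch below replays the same zig-zag insertion pattern on a plain Python list, with no database involved, to show that every insert lands between the previous two nodes. The function name `zigzag` and the small default size are assumptions made for illustration only.

```python
def zigzag(count):
    """Replay the alternating insert pattern from test_worst_case on a list.

    Returns (order, expected), where order is the list after all inserts and
    expected is built the same way the test builds it.
    """
    order = ["a"]
    prev_node = "a"
    expected_prefix = ["a"]
    expected_suffix = []

    for i in range(1, count):
        node_id = "node_id_%d" % i
        idx = order.index(prev_node)
        if i % 2 == 0:
            # Even: insert immediately before the previously inserted node.
            order.insert(idx, node_id)
            expected_prefix.append(node_id)
        else:
            # Odd: insert immediately after the previously inserted node.
            order.insert(idx + 1, node_id)
            expected_suffix.append(node_id)
        prev_node = node_id

    expected = expected_prefix + list(reversed(expected_suffix))
    return order, expected


order, expected = zigzag(6)
```

Because each new node is squeezed into the gap next to the last one, all inserts target the same shrinking region of the ordering. That is presumably why the test is named "worst case" for the rebalancing logic, though the source does not say so explicitly.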
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from synapse.util.katriel_bodlaender import InMemoryOrderedListStore

from tests import unittest


class KatrielBodlaenderTests(unittest.TestCase):
    def test_simple_graph(self):
        store = InMemoryOrderedListStore()

        nodes = [
            "node_1",
            "node_2",
            "node_3",
            "node_4",
        ]

        for node in nodes:
            store.add_node(node)

        store.add_edge("node_2", "node_3")
        store.add_edge("node_1", "node_2")
        store.add_edge("node_3", "node_4")

        self.assertEqual(nodes, store.list)

    def test_reverse_graph(self):
        store = InMemoryOrderedListStore()

        nodes = [
            "node_1",
            "node_2",
            "node_3",
            "node_4",
        ]

        for node in nodes:
            store.add_node(node)

        store.add_edge("node_3", "node_2")
        store.add_edge("node_2", "node_1")
        store.add_edge("node_4", "node_3")

        self.assertEqual(list(reversed(nodes)), store.list)
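
Both tests above assert one exact list, which works because a chain of four nodes has only one valid ordering. For graphs with several valid orderings, a property check may be more robust. The sketch below assumes, based on what the two tests expect, that an edge from source to target means the source must appear before the target. `is_topological` is a hypothetical helper, not something defined in this change.

```python
def is_topological(order, edges):
    """Return True if every edge's source appears before its target.

    order is a list of node IDs; edges is an iterable of (source, target).
    """
    position = {node: idx for idx, node in enumerate(order)}
    return all(position[src] < position[dst] for src, dst in edges)


nodes = ["node_1", "node_2", "node_3", "node_4"]

# The edge sets used by test_simple_graph and test_reverse_graph.
forward_edges = [
    ("node_2", "node_3"),
    ("node_1", "node_2"),
    ("node_3", "node_4"),
]
reverse_edges = [
    ("node_3", "node_2"),
    ("node_2", "node_1"),
    ("node_4", "node_3"),
]

forward_ok = is_topological(nodes, forward_edges)
reverse_ok = is_topological(list(reversed(nodes)), reverse_edges)
mismatch = is_topological(nodes, reverse_edges)
```

A test could then call `is_topological(store.list, edges)` after a series of `add_edge` calls, which would keep passing even if the algorithm picked a different but equally valid ordering.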