1
0

Compare commits

...

20 Commits

Author SHA1 Message Date
Erik Johnston
75bf48b905 Update tracer to give more information 2015-03-13 11:38:57 +00:00
Erik Johnston
d1ae594ae5 Add a utility class that can be used to generate a twisted deferred aware call graph 2015-03-12 16:52:02 +00:00
Erik Johnston
abaf47bbb6 Fix bug in logging. 2015-03-10 10:28:29 +00:00
Erik Johnston
045afd6b61 in_thread takes no arguments 2015-03-10 10:19:03 +00:00
Erik Johnston
98b867f7b7 Fix bug in logging. 2015-03-10 10:16:09 +00:00
Erik Johnston
e84fe3599b Merge pull request #105 from matrix-org/erikj-perf
Add a Twisted Service to synapse.app.homeserver
2015-03-10 10:02:26 +00:00
Erik Johnston
c37eceeb9e Split out the 'run' from 'setup' 2015-03-10 09:58:33 +00:00
Erik Johnston
b8a6692657 Add documentation. When starting via twistd respect soft_file_limit config option. 2015-03-10 09:39:42 +00:00
Erik Johnston
019422ebba Merge pull request #101 from matrix-org/neaten-federation-servlets
Neaten federation servlets
2015-03-09 17:39:06 +00:00
Erik Johnston
9fccb0df08 Merge pull request #104 from matrix-org/get_joined_rooms_for_user
Get joined rooms for user
2015-03-09 17:22:29 +00:00
Erik Johnston
6d74e46621 Fix tests 2015-03-09 17:01:11 +00:00
Erik Johnston
8e28db5cc9 Change room handlers get_rooms_for_user to get_joined_rooms_for_user. This uses the a storage api that is cached. 2015-03-09 16:43:09 +00:00
Erik Johnston
f31e65ca8b Merge branch 'develop' of github.com:matrix-org/synapse into erikj-perf 2015-03-09 13:29:41 +00:00
Paul "LeoNerd" Evans
d79d91a4a7 Appease pep8 2015-03-05 20:55:40 +00:00
Paul "LeoNerd" Evans
5eab2549ab Append a $ on PATH at registration time, meaning each PATH attribute doesn't need it 2015-03-05 20:36:05 +00:00
Paul "LeoNerd" Evans
7644cb79b2 Slightly neater(?) arrangement of authentication wrapper for HTTP servlet methods 2015-03-05 20:33:16 +00:00
Paul "LeoNerd" Evans
ba8ac996f9 Remove the dead 'rate_limit_origin' method from TransportLayerServer 2015-03-05 19:43:17 +00:00
Paul "LeoNerd" Evans
a901ed16b5 Move federation API responding code out of weird mix of lambdas into Servlet-style methods on instances 2015-03-05 19:10:57 +00:00
Erik Johnston
7f058c5ff7 Merge branch 'develop' of github.com:matrix-org/synapse into erikj-perf
Conflicts:
	synapse/app/homeserver.py
2015-01-22 13:35:34 +00:00
Erik Johnston
82be4457de Add twisted Service interface 2015-01-07 13:46:37 +00:00
15 changed files with 749 additions and 224 deletions

213
scripts/graph_tracer.py Normal file
View File

@@ -0,0 +1,213 @@
import fileinput
import pydot
import sys
import itertools
import json
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
next(b, None)
return itertools.izip(a, b)
nodes = {}
edges = set()
graph = pydot.Dot(graph_name="call_graph", graph_type="digraph")
names = {}
starts = {}
ends = {}
deferreds = set()
deferreds_map = {}
deferred_edges = set()
root_id = None
for line in fileinput.input():
line = line.strip()
try:
if " calls " in line:
start, end = line.split(" calls ")
start, end = start.strip(), end.strip()
edges.add((start, end))
# print start, end
if " named " in line:
node_id, name = line.split(" named ")
names[node_id.strip()] = name.strip()
if name.strip() == "synapse.rest.client.v1.room.RoomSendEventRestServlet.on_PUT":
root_id = node_id
if " in " in line:
node_id, d = line.split(" in ")
deferreds_map[node_id.strip()] = d.strip()
if " is deferred" in line:
node_id, _ = line.split(" is deferred")
deferreds.add(node_id)
if " start " in line:
node_id, ms = line.split(" start ")
starts[node_id.strip()] = int(ms.strip())
if " end " in line:
node_id, ms = line.split(" end ")
ends[node_id.strip()] = int(ms.strip())
if " waits on " in line:
start, end = line.split(" waits on ")
start, end = start.strip(), end.strip()
deferred_edges.add((start, end))
# print start, end
except Exception as e:
sys.stderr.write("failed %s to parse '%s'\n" % (e.message, line))
if not root_id:
sys.stderr.write("Could not find root")
sys.exit(1)
# deferreds_root = set(deferreds.values())
# for parent, child in deferred_edges:
# deferreds_root.discard(child)
#
# deferred_tree = {
# d: {}
# for d in deferreds_root
# }
#
# def populate(root, tree):
# for leaf in deferred_edges.get(root, []):
# populate(leaf, tree.setdefault(leaf, {}))
#
#
# for d in deferreds_root:
# tree = deferred_tree.setdefault(d, {})
# populate(d, tree)
# print deferred_edges
# print root_id
def is_in_deferred(d):
while True:
if d == root_id:
return True
for start, end in deferred_edges:
if d == end:
d = start
break
else:
return False
def walk_graph(d):
res = [d]
while d != root_id:
for start, end in edges:
if d == end:
d = start
res.append(d)
break
else:
return res
return res
def make_tree_el(node_id):
return {
"id": node_id,
"name": names[node_id],
"children": [],
"start": starts[node_id],
"end": ends[node_id],
"size": ends[node_id] - starts[node_id],
}
tree = make_tree_el(root_id)
tree_index = {
root_id: tree,
}
viz_out = {
"nodes": [],
"edges": [],
}
for node_id, name in names.items():
# if times.get(node_id, 100) < 5:
# continue
walk = walk_graph(node_id)
# print walk
if root_id not in walk:
continue
if node_id in deferreds:
if not is_in_deferred(node_id):
continue
elif node_id in deferreds_map:
if not is_in_deferred(deferreds_map[node_id]):
continue
walk_names = [
names[w].split("synapse.", 1)[1] for w in walk
]
for child, parent in reversed(list(pairwise(walk))):
if parent in tree_index and child not in tree_index:
el = make_tree_el(child)
tree_index[parent]["children"].append(el)
tree_index[child] = el
# print "-".join(reversed(["end"] + walk_names)) + ", " + str(ends[node_id] - starts[node_id])
# print "%d,%s,%s,%s" % (len(walk), walk_names[0], starts[node_id], ends[node_id])
viz_out["nodes"].append({
"id": node_id,
"label": names[node_id].split("synapse.", 1)[1],
"value": ends[node_id] - starts[node_id],
"level": len(walk),
})
node = pydot.Node(node_id, label=name)
# if node_id in deferreds:
# clusters[deferreds[node_id]].add_node(node)
# elif node_id in clusters:
# clusters[node_id].add_node(node)
# else:
# graph.add_node(node)
graph.add_node(node)
nodes[node_id] = node
# print node_id
# for el in tree_index.values():
# el["children"].sort(key=lambda e: e["start"])
#
# print json.dumps(tree)
for parent, child in edges:
if child not in nodes:
# sys.stderr.write(child + " not a node\n")
continue
if parent not in nodes:
# sys.stderr.write(parent + " not a node\n")
continue
viz_out["edges"].append({
"from": parent,
"to": child,
"value": ends[child] - starts[child],
})
edge = pydot.Edge(nodes[parent], nodes[child])
graph.add_edge(edge)
print json.dumps(viz_out)
file_prefix = "call_graph_out"
graph.write('%s.dot' % file_prefix, format='raw', prog='dot')
graph.write_svg("%s.svg" % file_prefix, prog='dot')

View File

@@ -26,6 +26,7 @@ from synapse.server import HomeServer
from synapse.python_dependencies import check_requirements
from twisted.internet import reactor
from twisted.application import service
from twisted.enterprise import adbapi
from twisted.web.resource import Resource
from twisted.web.static import File
@@ -50,6 +51,8 @@ from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource
from daemonize import Daemonize
import twisted.manhole.telnet
from synapse.util.traceutil import Tracer
import synapse
import logging
@@ -60,6 +63,7 @@ import subprocess
import sqlite3
import syweb
logger = logging.getLogger(__name__)
@@ -295,10 +299,19 @@ def change_resource_limit(soft_file_no):
logger.warn("Failed to set file limit: %s", e)
def setup():
def setup(config_options):
"""
Args:
config_options_options: The options passed to Synapse. Usually
`sys.argv[1:]`.
should_run (bool): Whether to start the reactor.
Returns:
HomeServer
"""
config = HomeServerConfig.load_config(
"Synapse Homeserver",
sys.argv[1:],
config_options,
generate_section="Homeserver"
)
@@ -370,13 +383,45 @@ def setup():
hs.get_datastore().start_profiling()
hs.get_replication_layer().start_get_pdu_cache()
if config.daemonize:
print config.pid_file
return hs
class SynapseService(service.Service):
"""A twisted Service class that will start synapse. Used to run synapse
via twistd and a .tac.
"""
def __init__(self, config):
self.config = config
def startService(self):
hs = setup(self.config)
change_resource_limit(hs.config.soft_file_limit)
def stopService(self):
return self._port.stopListening()
def run(hs):
def in_thread():
try:
tracer = Tracer()
sys.settrace(tracer.process)
except Exception:
logger.exception("Failed to start tracer")
with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit)
reactor.run()
if hs.config.daemonize:
print hs.config.pid_file
daemon = Daemonize(
app="synapse-homeserver",
pid=config.pid_file,
action=lambda: run(config),
pid=hs.config.pid_file,
action=lambda: in_thread(),
auto_close_fds=False,
verbose=True,
logger=logger,
@@ -384,20 +429,14 @@ def setup():
daemon.start()
else:
run(config)
def run(config):
with LoggingContext("run"):
change_resource_limit(config.soft_file_limit)
reactor.run()
in_thread()
def main():
with LoggingContext("main"):
check_requirements()
setup()
hs = setup(sys.argv[1:])
run(hs)
if __name__ == '__main__':

View File

@@ -115,8 +115,8 @@ class TransactionQueue(object):
if not deferred.called:
deferred.errback(failure)
def log_failure(failure):
logger.warn("Failed to send pdu", failure.value)
def log_failure(f):
logger.warn("Failed to send pdu to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
@@ -143,8 +143,8 @@ class TransactionQueue(object):
if not deferred.called:
deferred.errback(failure)
def log_failure(failure):
logger.warn("Failed to send pdu", failure.value)
def log_failure(f):
logger.warn("Failed to send edu to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
@@ -174,7 +174,7 @@ class TransactionQueue(object):
deferred.errback(f)
def log_failure(f):
logger.warn("Failed to send pdu", f.value)
logger.warn("Failed to send failure to %s: %s", destination, f.value)
deferred.addErrback(log_failure)

View File

@@ -19,6 +19,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
from synapse.util.logutils import log_function
import functools
import logging
import simplejson as json
import re
@@ -30,8 +31,9 @@ logger = logging.getLogger(__name__)
class TransportLayerServer(object):
"""Handles incoming federation HTTP requests"""
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
def _authenticate_request(self, request):
def authenticate_request(self, request):
json_request = {
"method": request.method,
"uri": request.uri,
@@ -93,28 +95,6 @@ class TransportLayerServer(object):
defer.returnValue((origin, content))
def _with_authentication(self, handler):
@defer.inlineCallbacks
def new_handler(request, *args, **kwargs):
try:
(origin, content) = yield self._authenticate_request(request)
with self.ratelimiter.ratelimit(origin) as d:
yield d
response = yield handler(
origin, content, request.args, *args, **kwargs
)
except:
logger.exception("_authenticate_request failed")
raise
defer.returnValue(response)
return new_handler
def rate_limit_origin(self, handler):
def new_handler(origin, *args, **kwargs):
response = yield handler(origin, *args, **kwargs)
defer.returnValue(response)
return new_handler()
@log_function
def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data.
@@ -122,14 +102,12 @@ class TransportLayerServer(object):
Args:
handler (TransportReceivedHandler)
"""
self.received_handler = handler
# This is when someone is trying to send us a bunch of data.
self.server.register_path(
"PUT",
re.compile("^" + PREFIX + "/send/([^/]*)/$"),
self._with_authentication(self._on_send_request)
)
FederationSendServlet(
handler,
authenticator=self,
ratelimiter=self.ratelimiter,
server_name=self.server_name,
).register(self.server)
@log_function
def register_request_handler(self, handler):
@@ -138,136 +116,61 @@ class TransportLayerServer(object):
Args:
handler (TransportRequestHandler)
"""
self.request_handler = handler
for servletclass in SERVLET_CLASSES:
servletclass(
handler,
authenticator=self,
ratelimiter=self.ratelimiter,
).register(self.server)
# This is for when someone asks us for everything since version X
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/pull/$"),
self._with_authentication(
lambda origin, content, query:
handler.on_pull_request(query["origin"][0], query["v"])
)
)
# This is when someone asks for a data item for a given server
# data_id pair.
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/event/([^/]*)/$"),
self._with_authentication(
lambda origin, content, query, event_id:
handler.on_pdu_request(origin, event_id)
)
)
class BaseFederationServlet(object):
def __init__(self, handler, authenticator, ratelimiter):
self.handler = handler
self.authenticator = authenticator
self.ratelimiter = ratelimiter
# This is when someone asks for all data for a given context.
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/state/([^/]*)/$"),
self._with_authentication(
lambda origin, content, query, context:
handler.on_context_state_request(
origin,
context,
query.get("event_id", [None])[0],
)
)
)
def _wrap(self, code):
authenticator = self.authenticator
ratelimiter = self.ratelimiter
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
self._with_authentication(
lambda origin, content, query, context:
self._on_backfill_request(
origin, context, query["v"], query["limit"]
)
)
)
@defer.inlineCallbacks
@functools.wraps(code)
def new_code(request, *args, **kwargs):
try:
(origin, content) = yield authenticator.authenticate_request(request)
with ratelimiter.ratelimit(origin) as d:
yield d
response = yield code(
origin, content, request.args, *args, **kwargs
)
except:
logger.exception("authenticate_request failed")
raise
defer.returnValue(response)
return new_code
# This is when we receive a server-server Query
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/query/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, query_type:
handler.on_query_request(
query_type,
{k: v[0].decode("utf-8") for k, v in query.items()}
)
)
)
def register(self, server):
pattern = re.compile("^" + PREFIX + self.PATH + "$")
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, user_id:
self._on_make_join_request(
origin, content, query, context, user_id
)
)
)
for method in ("GET", "PUT", "POST"):
code = getattr(self, "on_%s" % (method), None)
if code is None:
continue
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, event_id:
handler.on_event_auth(
origin, context, event_id,
)
)
)
server.register_path(method, pattern, self._wrap(code))
self.server.register_path(
"PUT",
re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, event_id:
self._on_send_join_request(
origin, content, query,
)
)
)
self.server.register_path(
"PUT",
re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, event_id:
self._on_invite_request(
origin, content, query,
)
)
)
class FederationSendServlet(BaseFederationServlet):
PATH = "/send/([^/]*)/"
self.server.register_path(
"POST",
re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, event_id:
self._on_query_auth_request(
origin, content, event_id,
)
)
)
self.server.register_path(
"POST",
re.compile("^" + PREFIX + "/get_missing_events/([^/]*)/?$"),
self._with_authentication(
lambda origin, content, query, room_id:
self._get_missing_events(
origin, content, room_id,
)
)
)
def __init__(self, handler, server_name, **kwargs):
super(FederationSendServlet, self).__init__(handler, **kwargs)
self.server_name = server_name
# This is when someone is trying to send us a bunch of data.
@defer.inlineCallbacks
@log_function
def _on_send_request(self, origin, content, query, transaction_id):
def on_PUT(self, origin, content, query, transaction_id):
""" Called on PUT /send/<transaction_id>/
Args:
@@ -305,8 +208,7 @@ class TransportLayerServer(object):
return
try:
handler = self.received_handler
code, response = yield handler.on_incoming_transaction(
code, response = yield self.handler.on_incoming_transaction(
transaction_data
)
except:
@@ -315,65 +217,123 @@ class TransportLayerServer(object):
defer.returnValue((code, response))
@log_function
def _on_backfill_request(self, origin, context, v_list, limits):
class FederationPullServlet(BaseFederationServlet):
PATH = "/pull/"
# This is for when someone asks us for everything since version X
def on_GET(self, origin, content, query):
return self.handler.on_pull_request(query["origin"][0], query["v"])
class FederationEventServlet(BaseFederationServlet):
PATH = "/event/([^/]*)/"
# This is when someone asks for a data item for a given server data_id pair.
def on_GET(self, origin, content, query, event_id):
return self.handler.on_pdu_request(origin, event_id)
class FederationStateServlet(BaseFederationServlet):
PATH = "/state/([^/]*)/"
# This is when someone asks for all data for a given context.
def on_GET(self, origin, content, query, context):
return self.handler.on_context_state_request(
origin,
context,
query.get("event_id", [None])[0],
)
class FederationBackfillServlet(BaseFederationServlet):
PATH = "/backfill/([^/]*)/"
def on_GET(self, origin, content, query, context):
versions = query["v"]
limits = query["limit"]
if not limits:
return defer.succeed(
(400, {"error": "Did not include limit param"})
)
return defer.succeed((400, {"error": "Did not include limit param"}))
limit = int(limits[-1])
versions = v_list
return self.handler.on_backfill_request(origin, context, versions, limit)
return self.request_handler.on_backfill_request(
origin, context, versions, limit
class FederationQueryServlet(BaseFederationServlet):
PATH = "/query/([^/]*)"
# This is when we receive a server-server Query
def on_GET(self, origin, content, query, query_type):
return self.handler.on_query_request(
query_type,
{k: v[0].decode("utf-8") for k, v in query.items()}
)
class FederationMakeJoinServlet(BaseFederationServlet):
PATH = "/make_join/([^/]*)/([^/]*)"
@defer.inlineCallbacks
@log_function
def _on_make_join_request(self, origin, content, query, context, user_id):
content = yield self.request_handler.on_make_join_request(
context, user_id,
)
def on_GET(self, origin, content, query, context, user_id):
content = yield self.handler.on_make_join_request(context, user_id)
defer.returnValue((200, content))
@defer.inlineCallbacks
@log_function
def _on_send_join_request(self, origin, content, query):
content = yield self.request_handler.on_send_join_request(
origin, content,
)
class FederationEventAuthServlet(BaseFederationServlet):
PATH = "/event_auth/([^/]*)/([^/]*)"
def on_GET(self, origin, content, query, context, event_id):
return self.handler.on_event_auth(origin, context, event_id)
class FederationSendJoinServlet(BaseFederationServlet):
PATH = "/send_join/([^/]*)/([^/]*)"
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
content = yield self.handler.on_send_join_request(origin, content)
defer.returnValue((200, content))
@defer.inlineCallbacks
@log_function
def _on_invite_request(self, origin, content, query):
content = yield self.request_handler.on_invite_request(
origin, content,
)
class FederationInviteServlet(BaseFederationServlet):
PATH = "/invite/([^/]*)/([^/]*)"
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
content = yield self.handler.on_invite_request(origin, content)
defer.returnValue((200, content))
class FederationQueryAuthServlet(BaseFederationServlet):
PATH = "/query_auth/([^/]*)/([^/]*)"
@defer.inlineCallbacks
@log_function
def _on_query_auth_request(self, origin, content, event_id):
new_content = yield self.request_handler.on_query_auth_request(
def on_POST(self, origin, content, query, context, event_id):
new_content = yield self.handler.on_query_auth_request(
origin, content, event_id
)
defer.returnValue((200, new_content))
class FederationGetMissingEventsServlet(BaseFederationServlet):
# TODO(paul): Why does this path alone end with "/?" optional?
PATH = "/get_missing_events/([^/]*)/?"
@defer.inlineCallbacks
@log_function
def _get_missing_events(self, origin, content, room_id):
def on_POST(self, origin, content, query, room_id):
limit = int(content.get("limit", 10))
min_depth = int(content.get("min_depth", 0))
earliest_events = content.get("earliest_events", [])
latest_events = content.get("latest_events", [])
content = yield self.request_handler.on_get_missing_events(
content = yield self.handler.on_get_missing_events(
origin,
room_id=room_id,
earliest_events=earliest_events,
@@ -383,3 +343,18 @@ class TransportLayerServer(object):
)
defer.returnValue((200, content))
SERVLET_CLASSES = (
FederationPullServlet,
FederationEventServlet,
FederationStateServlet,
FederationBackfillServlet,
FederationQueryServlet,
FederationMakeJoinServlet,
FederationEventServlet,
FederationSendJoinServlet,
FederationInviteServlet,
FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
)

View File

@@ -71,7 +71,7 @@ class EventStreamHandler(BaseHandler):
self._streams_per_user[auth_user] += 1
rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(auth_user)
room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)
if timeout:
# If they've set a timeout set a minimum limit.

View File

@@ -73,7 +73,6 @@ class FederationHandler(BaseHandler):
# When joining a room we need to queue any events for that room up
self.room_queues = {}
@log_function
@defer.inlineCallbacks
def handle_new_event(self, event, destinations):
""" Takes in an event from the client to server side, that has already

View File

@@ -452,7 +452,7 @@ class PresenceHandler(BaseHandler):
# Also include people in all my rooms
rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(user)
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if state is None:
state = yield self.store.get_presence_state(user.localpart)
@@ -596,7 +596,7 @@ class PresenceHandler(BaseHandler):
localusers.add(user)
rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(user)
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if not localusers and not room_ids:
defer.returnValue(None)
@@ -663,7 +663,7 @@ class PresenceHandler(BaseHandler):
)
rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(user)
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if room_ids:
logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids)

View File

@@ -197,9 +197,8 @@ class ProfileHandler(BaseHandler):
self.ratelimit(user.to_string())
joins = yield self.store.get_rooms_for_user_where_membership_is(
joins = yield self.store.get_rooms_for_user(
user.to_string(),
[Membership.JOIN],
)
for j in joins:

View File

@@ -507,7 +507,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue((is_remote_invite_join, room_host))
@defer.inlineCallbacks
def get_rooms_for_user(self, user, membership_list=[Membership.JOIN]):
def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
membership states in."""
@@ -517,8 +517,8 @@ class RoomMemberHandler(BaseHandler):
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
else:
rooms = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user.to_string(), membership_list=membership_list
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
)
# For some reason the list of events contains duplicates

View File

@@ -96,7 +96,9 @@ class SyncHandler(BaseHandler):
return self.current_sync_for_user(sync_config, since_token)
rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(sync_config.user)
room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user
)
result = yield self.notifier.wait_for_events(
sync_config.user, room_ids,
sync_config.filter, timeout, current_sync_callback
@@ -227,7 +229,7 @@ class SyncHandler(BaseHandler):
logger.debug("Typing %r", typing_by_room)
rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(sync_config.user)
room_ids = yield rm_handler.get_joined_rooms_for_user(sync_config.user)
# TODO (mjark): Does public mean "published"?
published_rooms = yield self.store.get_rooms(is_public=True)

View File

@@ -165,6 +165,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
defer.returnValue((200, {}))
def trace(f):
f.should_trace = True
f.root_trace = True
return f
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(ClientV1RestServlet):
@@ -175,7 +181,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_type, txn_id=None):
import inspect
frame = inspect.currentframe()
logger.info("Frame: %s", id(frame))
user, client = yield self.auth.get_user_by_req(request)
logger.info("Frame: %s", id(inspect.currentframe()))
content = _parse_json(request)
msg_handler = self.handlers.message_handler
@@ -189,12 +199,14 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
client=client,
txn_id=txn_id,
)
logger.info("Frame: %s", id(inspect.currentframe()))
defer.returnValue((200, {"event_id": event.event_id}))
def on_GET(self, request, room_id, event_type, txn_id):
return (200, "Not implemented")
@trace
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, txn_id):
try:

286
synapse/util/traceutil.py Normal file
View File

@@ -0,0 +1,286 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import inspect
import logging
logger = logging.getLogger("Tracer")
class Tracer(object):
def __init__(self):
self.interested_deferreds = set()
self.next_id = 1
self.deferred_frames = {}
self.deferred_to_current_frames = {}
def process(self, frame, event, arg):
if event == 'call':
return self.process_call(frame)
def handle_inline_callbacks(self, frm):
argvalues = inspect.getargvalues(frm)
generator = argvalues.locals["g"]
deferred = argvalues.locals["deferred"]
if not hasattr(deferred, "syn_trace_defer_id"):
trace_id = self.get_next_id()
deferred.syn_trace_defer_id = trace_id
logger.info(
"%s named %s",
trace_id,
self.get_name_for_frame(generator.gi_frame)
)
logger.info(
"%s is deferred",
trace_id,
)
logger.info("%s start %d", trace_id, int(time.time() * 1000))
def do(res):
logger.info("%s end %d", trace_id, int(time.time() * 1000))
return res
deferred.addBoth(do)
back = frm.f_back
while back:
try:
name = self.get_name_for_frame(back)
if name == "twisted.internet.defer._inlineCallbacks":
argvalues = inspect.getargvalues(back)
deferred = argvalues.locals["deferred"]
d_id = getattr(deferred, "syn_trace_defer_id", None)
if d_id:
logger.info("%s in %s", trace_id, d_id)
curr_stack = self.deferred_to_current_frames.setdefault(
d_id, []
)
if curr_stack:
logger.info("%s calls %s", curr_stack[-1], trace_id)
else:
logger.info("%s calls %s", d_id, trace_id)
break
except:
pass
back = back.f_back
def are_interested(self, name):
if not name.startswith("synapse"):
return False
if name.startswith("synapse.util.logcontext"):
return False
if name.startswith("synapse.util.logutils"):
return False
if name.startswith("synapse.util.traceutil"):
return False
if name.startswith("synapse.events.FrozenEvent.get"):
return False
if name.startswith("synapse.events.EventBuilder.get"):
return False
if name.startswith("synapse.types"):
return False
if name.startswith("synapse.util.frozenutils.freeze"):
return False
if name.startswith("synapse.util.frozenutils.<dictcomp>"):
return False
if name.startswith("synapse.util.Clock"):
return False
if name.endswith("__repr__") or name.endswith("__str__"):
return False
if name.endswith("<genexpr>"):
return False
return True
def process_call(self, frame):
should_trace = False
try:
name = self.get_name_for_frame(frame)
if name == "twisted.internet.defer._inlineCallbacks":
self.handle_inline_callbacks(frame)
return
if not self.are_interested(name):
return
back_name = self.get_name_for_frame(frame.f_back)
if name == "synapse.api.auth.Auth.get_user_by_req":
logger.info(
"synapse.api.auth.Auth.get_user_by_req %s",
back_name
)
try:
if back_name == "twisted.internet.defer._inlineCallbacks":
def ret(f, event, result):
if event != "return":
return
argvalues = inspect.getargvalues(frame.f_back)
deferred = argvalues.locals["deferred"]
try:
logger.info(
"%s waits on %s",
deferred.syn_trace_defer_id,
result.syn_trace_defer_id
)
except:
pass
return ret
if back_name == "twisted.internet.defer.unwindGenerator":
return
except:
pass
try:
func = getattr(frame.f_locals["self"], frame.f_code.co_name)
if inspect.isgeneratorfunction(func):
return
except:
pass
try:
func = frame.f_globals[frame.f_code.co_name]
if inspect.isgeneratorfunction(func):
return
except:
pass
except:
return
back = frame
names = []
seen_deferreds = []
bottom_deferred = None
while back:
try:
name = self.get_name_for_frame(back)
if name.startswith("synapse"):
names.append(name)
# if name.startswith("twisted.internet.defer"):
# logger.info("Name: %s", name)
if name == "twisted.internet.defer._inlineCallbacks":
argvalues = inspect.getargvalues(back)
deferred = argvalues.locals["deferred"]
d_id = getattr(deferred, "syn_trace_defer_id", None)
if d_id:
seen_deferreds.append(d_id)
if not bottom_deferred:
bottom_deferred = deferred
if d_id in self.interested_deferreds:
should_trace = True
break
func = getattr(back.f_locals["self"], back.f_code.co_name)
if hasattr(func, "should_trace") or hasattr(func.im_func, "should_trace"):
should_trace = True
break
func.root_trace
should_trace = True
break
except:
pass
back = back.f_back
if not should_trace:
return
frame_id = self.get_next_id()
name = self.get_name_for_frame(frame)
logger.info("%s named %s", frame_id, name)
self.interested_deferreds.update(seen_deferreds)
names.reverse()
if bottom_deferred:
self.deferred_frames.setdefault(
bottom_deferred.syn_trace_defer_id, []
).append(names)
logger.info("%s in %s", frame_id, bottom_deferred.syn_trace_defer_id)
if not hasattr(bottom_deferred, "syn_trace_registered_cb"):
bottom_deferred.syn_trace_registered_cb = True
def do(res):
return res
bottom_deferred.addBoth(do)
curr_stack = self.deferred_to_current_frames.setdefault(
bottom_deferred.syn_trace_defer_id, []
)
if curr_stack:
logger.info("%s calls %s", curr_stack[-1], frame_id)
else:
logger.info("%s calls %s", bottom_deferred.syn_trace_defer_id, frame_id)
curr_stack.append(frame_id)
logger.info("%s start %d", frame_id, int(time.time() * 1000))
def p(frame, event, arg):
if event == "return":
curr_stack.pop()
logger.info("%s end %d", frame_id, int(time.time() * 1000))
return p
def get_name_for_frame(self, frame):
module_name = frame.f_globals["__name__"]
cls_instance = frame.f_locals.get("self", None)
if cls_instance:
cls_name = cls_instance.__class__.__name__
name = "%s.%s.%s" % (
module_name, cls_name, frame.f_code.co_name
)
else:
name = "%s.%s" % (
module_name, frame.f_code.co_name
)
return name
def get_next_id(self):
i = self.next_id
self.next_id += 1
return i

View File

@@ -100,7 +100,7 @@ class PresenceTestCase(unittest.TestCase):
self.room_members = []
room_member_handler = handlers.room_member_handler = Mock(spec=[
"get_rooms_for_user",
"get_joined_rooms_for_user",
"get_room_members",
"fetch_room_distributions_into",
])
@@ -111,7 +111,7 @@ class PresenceTestCase(unittest.TestCase):
return defer.succeed([self.room_id])
else:
return defer.succeed([])
room_member_handler.get_rooms_for_user = get_rooms_for_user
room_member_handler.get_joined_rooms_for_user = get_rooms_for_user
def get_room_members(room_id):
if room_id == self.room_id:

View File

@@ -64,7 +64,7 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
"set_presence_state",
"is_presence_visible",
"set_profile_displayname",
"get_rooms_for_user_where_membership_is",
"get_rooms_for_user",
]),
handlers=None,
resource_for_federation=Mock(),
@@ -124,9 +124,9 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
self.mock_update_client)
hs.handlers.room_member_handler = Mock(spec=[
"get_rooms_for_user",
"get_joined_rooms_for_user",
])
hs.handlers.room_member_handler.get_rooms_for_user = (
hs.handlers.room_member_handler.get_joined_rooms_for_user = (
lambda u: defer.succeed([]))
# Some local users to test with
@@ -138,7 +138,7 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
self.u_potato = UserID.from_string("@potato:remote")
self.mock_get_joined = (
self.datastore.get_rooms_for_user_where_membership_is
self.datastore.get_rooms_for_user
)
@defer.inlineCallbacks

View File

@@ -79,13 +79,13 @@ class PresenceStateTestCase(unittest.TestCase):
room_member_handler = hs.handlers.room_member_handler = Mock(
spec=[
"get_rooms_for_user",
"get_joined_rooms_for_user",
]
)
def get_rooms_for_user(user):
return defer.succeed([])
room_member_handler.get_rooms_for_user = get_rooms_for_user
room_member_handler.get_joined_rooms_for_user = get_rooms_for_user
presence.register_servlets(hs, self.mock_resource)
@@ -166,7 +166,7 @@ class PresenceListTestCase(unittest.TestCase):
hs.handlers.room_member_handler = Mock(
spec=[
"get_rooms_for_user",
"get_joined_rooms_for_user",
]
)
@@ -291,7 +291,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
return ["a-room"]
else:
return []
hs.handlers.room_member_handler.get_rooms_for_user = get_rooms_for_user
hs.handlers.room_member_handler.get_joined_rooms_for_user = get_rooms_for_user
self.mock_datastore = hs.get_datastore()
self.mock_datastore.get_app_service_by_token = Mock(return_value=None)