1
0

Compare commits

..

2 Commits

149 changed files with 2896 additions and 8993 deletions

17
.gitignore vendored
View File

@@ -26,18 +26,17 @@ htmlcov
demo/*.db
demo/*.log
demo/*.log.*
demo/*.pid
demo/media_store.*
demo/etc
graph/*.svg
graph/*.png
graph/*.dot
**/webclient/config.js
**/webclient/test/coverage/
**/webclient/test/environment-protractor.js
uploads
.idea/
media_store/
*.tac
build/
localhost-800*/

View File

@@ -1,78 +1,54 @@
Changes in synapse v0.7.0 (2015-02-12)
======================================
* Add initial implementation of the query auth federation API, allowing
servers to agree on whether an event should be allowed or rejected.
* Persist events we have rejected from federation, fixing the bug where
servers would keep requesting the same events.
* Various federation performance improvements, including:
- Add in memory caches on queries such as:
* Computing the state of a room at a point in time, used for
authorization on federation requests.
* Fetching events from the database.
* User's room membership, used for authorizing presence updates.
- Upgraded JSON library to improve parsing and serialisation speeds.
* Add default avatars to new user accounts using pydenticon library.
* Correctly time out federation requests.
* Retry federation requests against different servers.
* Add support for push and push rules.
* Add alpha versions of proposed new CSv2 APIs, including ``/sync`` API.
Changes in synapse 0.6.1 (2015-01-07)
=====================================
* Major optimizations to improve performance of initial sync and event sending
in large rooms (by up to 10x)
* Media repository now includes a Content-Length header on media downloads.
* Improve quality of thumbnails by changing resizing algorithm.
* Major optimizations to improve performance of initial sync and event sending
in large rooms (by up to 10x)
* Media repository now includes a Content-Length header on media downloads.
* Improve quality of thumbnails by changing resizing algorithm.
Changes in synapse 0.6.0 (2014-12-16)
=====================================
* Add new API for media upload and download that supports thumbnailing.
* Replicate media uploads over multiple homeservers so media is always served
to clients from their local homeserver. This obsoletes the
--content-addr parameter and confusion over accessing content directly
from remote homeservers.
* Implement exponential backoff when retrying federation requests when
sending to remote homeservers which are offline.
* Implement typing notifications.
* Fix bugs where we sent events with invalid signatures due to bugs where
we incorrectly persisted events.
* Improve performance of database queries involving retrieving events.
* Add new API for media upload and download that supports thumbnailing.
* Replicate media uploads over multiple homeservers so media is always served
to clients from their local homeserver. This obsoletes the
--content-addr parameter and confusion over accessing content directly
from remote homeservers.
* Implement exponential backoff when retrying federation requests when
sending to remote homeservers which are offline.
* Implement typing notifications.
* Fix bugs where we sent events with invalid signatures due to bugs where
we incorrectly persisted events.
* Improve performance of database queries involving retrieving events.
Changes in synapse 0.5.4a (2014-12-13)
======================================
* Fix bug while generating the error message when a file path specified in
the config doesn't exist.
* Fix bug while generating the error message when a file path specified in
the config doesn't exist.
Changes in synapse 0.5.4 (2014-12-03)
=====================================
* Fix presence bug where some rooms did not display presence updates for
remote users.
* Do not log SQL timing log lines when started with "-v"
* Fix potential memory leak.
* Fix presence bug where some rooms did not display presence updates for
remote users.
* Do not log SQL timing log lines when started with "-v"
* Fix potential memory leak.
Changes in synapse 0.5.3c (2014-12-02)
======================================
* Change the default value for the `content_addr` option to use the HTTP
listener, as by default the HTTPS listener will be using a self-signed
certificate.
* Change the default value for the `content_addr` option to use the HTTP
listener, as by default the HTTPS listener will be using a self-signed
certificate.
Changes in synapse 0.5.3 (2014-11-27)
=====================================
* Fix bug that caused joining a remote room to fail if a single event was not
signed correctly.
* Fix bug which caused servers to continuously try and fetch events from other
servers.
* Fix bug that caused joining a remote room to fail if a single event was not
signed correctly.
* Fix bug which caused servers to continuously try and fetch events from other
servers.
Changes in synapse 0.5.2 (2014-11-26)
=====================================

View File

@@ -1,14 +1,4 @@
include synctl
include LICENSE
include VERSION
include *.rst
include demo/README
recursive-include synapse/storage/schema *.sql
recursive-include demo *.dh
recursive-include demo *.py
recursive-include demo *.sh
recursive-include docs *
recursive-include scripts *
recursive-include tests *.py
recursive-include synapse/storage/schema *.sql
recursive-include syweb/webclient *

View File

@@ -95,35 +95,27 @@ Installing prerequisites on Ubuntu or Debian::
$ sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \
libssl-dev python-virtualenv libjpeg-dev
Installing prerequisites on ArchLinux::
$ sudo pacman -S base-devel python2 python-pip \
python-setuptools python-virtualenv sqlite3
libssl-dev
Installing prerequisites on Mac OS X::
$ xcode-select --install
$ sudo pip install virtualenv
To install the synapse homeserver run::
$ virtualenv ~/.synapse
$ source ~/.synapse/bin/activate
$ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
$ pip install --user --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
This installs synapse, along with the libraries it uses, into a virtual
environment under ``~/.synapse``.
This installs synapse, along with the libraries it uses, into
``$HOME/.local/lib/`` on Linux or ``$HOME/Library/Python/2.7/lib/`` on OSX.
To set up your homeserver, run (in your virtualenv, as before)::
Your python may not give priority to locally installed libraries over system
libraries, in which case you must add your local packages to your python path::
$ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
$ # on Linux:
$ export PYTHONPATH=$HOME/.local/lib/python2.7/site-packages:$PYTHONPATH
Substituting your host and domain name as appropriate.
$ # on OSX:
$ export PYTHONPATH=$HOME/Library/Python/2.7/lib/python/site-packages:$PYTHONPATH
For reliable VoIP calls to be routed via this homeserver, you MUST configure
a TURN server. See docs/turn-howto.rst for details.
@@ -136,56 +128,23 @@ you get errors about ``error: no such option: --process-dependency-links`` you
may need to manually upgrade it::
$ sudo pip install --upgrade pip
If pip crashes mid-installation for reason (e.g. lost terminal), pip may
refuse to run until you remove the temporary installation directory it
created. To reset the installation::
$ rm -rf /tmp/pip_install_matrix
pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
failing, e.g.::
$ pip install twisted
$ pip install --user twisted
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
will need to export CFLAGS=-Qunused-arguments.
ArchLinux
---------
Installation on ArchLinux may encounter a few hiccups as Arch defaults to
python 3, but synapse currently assumes python 2.7 by default.
pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 )::
$ sudo pip2.7 install --upgrade pip
You also may need to explicitly specify python 2.7 again during the install
request::
$ pip2.7 install --process-dependency-links \
https://github.com/matrix-org/synapse/tarball/master
If you encounter an error with lib bcrypt causing an Wrong ELF Class:
ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly
compile it under the right architecture. (This should not be needed if
installing under virtualenv)::
$ sudo pip2.7 uninstall py-bcrypt
$ sudo pip2.7 install py-bcrypt
During setup of homeserver you need to call python2.7 directly again::
$ python2.7 -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
...substituting your host and domain name as appropriate.
Windows Install
---------------
Synapse can be installed on Cygwin. It requires the following Cygwin packages:
@@ -196,7 +155,7 @@ Synapse can be installed on Cygwin. It requires the following Cygwin packages:
- openssl (and openssl-devel, python-openssl)
- python
- python-setuptools
The content repository requires additional packages and will be unable to process
uploads without them:
- libjpeg8
@@ -223,13 +182,23 @@ Running Your Homeserver
To actually run your new homeserver, pick a working directory for Synapse to run
(e.g. ``~/.synapse``), and::
$ mkdir ~/.synapse
$ cd ~/.synapse
$ source ./bin/activate
$ synctl start
$ # on Linux
$ ~/.local/bin/synctl start
$ # on OSX
$ ~/Library/Python/2.7/bin/synctl start
Troubleshooting Running
-----------------------
If ``synctl`` fails with ``pkg_resources.DistributionNotFound`` errors you may
need a newer version of setuptools than that provided by your OS.::
$ sudo pip install setuptools --upgrade
If synapse fails with ``missing "sodium.h"`` crypto errors, you may need
to manually upgrade PyNaCL, as synapse uses NaCl (http://nacl.cr.yp.to/) for
encryption and digital signatures.
@@ -245,14 +214,6 @@ fix try re-installing from PyPI or directly from
$ # Install from github
$ pip install --user https://github.com/pyca/pynacl/tarball/master
ArchLinux
---------
If running `$ synctl start` fails wit 'returned non-zero exit status 1', you will need to explicitly call Python2.7 - either running as::
$ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml --pid-file homeserver.pid
...or by editing synctl with the correct python executable.
Homeserver Development
======================
@@ -264,15 +225,13 @@ directory of your choice::
$ cd synapse
The homeserver has a number of external dependencies, that are easiest
to install using pip and a virtualenv::
to install by making setup.py do so, in --user mode::
$ virtualenv env
$ source env/bin/activate
$ python synapse/python_dependencies.py | xargs -n1 pip install
$ pip install setuptools_trial mock
$ python setup.py develop --user
This will run a process of downloading and installing all the needed
dependencies into a virtual env.
This will run a process of downloading and installing into your
user's .local/lib directory all of the required dependencies that are
missing.
Once this is done, you may wish to run the homeserver's unit tests, to
check that everything is installed as it should be::
@@ -293,7 +252,7 @@ IMPORTANT: Before upgrading an existing homeserver to a new version, please
refer to UPGRADE.rst for any additional instructions.
Otherwise, simply re-install the new codebase over the current one - e.g.
by ``pip install --process-dependency-links
by ``pip install --user --process-dependency-links
https://github.com/matrix-org/synapse/tarball/master``
if using pip, or by ``git pull`` if running off a git working copy.
@@ -320,9 +279,9 @@ For the first form, simply pass the required hostname (of the machine) as the
$ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--config-path homeserver.config \
--generate-config
$ python -m synapse.app.homeserver --config-path homeserver.yaml
$ python -m synapse.app.homeserver --config-path homeserver.config
Alternatively, you can run ``synctl start`` to guide you through the process.
@@ -342,9 +301,9 @@ SRV record, as that is the name other machines will expect it to have::
$ python -m synapse.app.homeserver \
--server-name YOURDOMAIN \
--bind-port 8448 \
--config-path homeserver.yaml \
--config-path homeserver.config \
--generate-config
$ python -m synapse.app.homeserver --config-path homeserver.yaml
$ python -m synapse.app.homeserver --config-path homeserver.config
You may additionally want to pass one or more "-v" options, in order to

View File

@@ -1,17 +1,3 @@
Upgrading to v0.7.0
===================
New dependencies are:
- pydenticon
- simplejson
- syutil
- matrix-angular-sdk
To pull in these dependencies in a virtual env, run::
python synapse/python_dependencies.py | xargs -n 1 pip install
Upgrading to v0.6.0
===================

1
VERSION Normal file
View File

@@ -0,0 +1 @@
0.6.1b

View File

@@ -39,43 +39,43 @@ ROOMDOMAIN="meet.jit.si"
#ROOMDOMAIN="conference.jitsi.vuc.me"
class TrivialMatrixClient:
def __init__(self, access_token):
self.token = None
self.access_token = access_token
def __init__(self, access_token):
self.token = None
self.access_token = access_token
def getEvent(self):
while True:
url = MATRIXBASE+'events?access_token='+self.access_token+"&timeout=60000"
if self.token:
url += "&from="+self.token
req = grequests.get(url)
resps = grequests.map([req])
obj = json.loads(resps[0].content)
print "incoming from matrix",obj
if 'end' not in obj:
continue
self.token = obj['end']
if len(obj['chunk']):
return obj['chunk'][0]
def getEvent(self):
while True:
url = MATRIXBASE+'events?access_token='+self.access_token+"&timeout=60000"
if self.token:
url += "&from="+self.token
req = grequests.get(url)
resps = grequests.map([req])
obj = json.loads(resps[0].content)
print "incoming from matrix",obj
if 'end' not in obj:
continue
self.token = obj['end']
if len(obj['chunk']):
return obj['chunk'][0]
def joinRoom(self, roomId):
url = MATRIXBASE+'rooms/'+roomId+'/join?access_token='+self.access_token
print url
headers={ 'Content-Type': 'application/json' }
req = grequests.post(url, headers=headers, data='{}')
resps = grequests.map([req])
obj = json.loads(resps[0].content)
print "response: ",obj
def joinRoom(self, roomId):
url = MATRIXBASE+'rooms/'+roomId+'/join?access_token='+self.access_token
print url
headers={ 'Content-Type': 'application/json' }
req = grequests.post(url, headers=headers, data='{}')
resps = grequests.map([req])
obj = json.loads(resps[0].content)
print "response: ",obj
def sendEvent(self, roomId, evType, event):
url = MATRIXBASE+'rooms/'+roomId+'/send/'+evType+'?access_token='+self.access_token
print url
print json.dumps(event)
headers={ 'Content-Type': 'application/json' }
req = grequests.post(url, headers=headers, data=json.dumps(event))
resps = grequests.map([req])
obj = json.loads(resps[0].content)
print "response: ",obj
def sendEvent(self, roomId, evType, event):
url = MATRIXBASE+'rooms/'+roomId+'/send/'+evType+'?access_token='+self.access_token
print url
print json.dumps(event)
headers={ 'Content-Type': 'application/json' }
req = grequests.post(url, headers=headers, data=json.dumps(event))
resps = grequests.map([req])
obj = json.loads(resps[0].content)
print "response: ",obj
@@ -83,178 +83,178 @@ xmppClients = {}
def matrixLoop():
while True:
ev = matrixCli.getEvent()
print ev
if ev['type'] == 'm.room.member':
print 'membership event'
if ev['membership'] == 'invite' and ev['state_key'] == MYUSERNAME:
roomId = ev['room_id']
print "joining room %s" % (roomId)
matrixCli.joinRoom(roomId)
elif ev['type'] == 'm.room.message':
if ev['room_id'] in xmppClients:
print "already have a bridge for that user, ignoring"
continue
print "got message, connecting"
xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id'])
gevent.spawn(xmppClients[ev['room_id']].xmppLoop)
elif ev['type'] == 'm.call.invite':
print "Incoming call"
#sdp = ev['content']['offer']['sdp']
#print "sdp: %s" % (sdp)
#xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id'])
#gevent.spawn(xmppClients[ev['room_id']].xmppLoop)
elif ev['type'] == 'm.call.answer':
print "Call answered"
sdp = ev['content']['answer']['sdp']
if ev['room_id'] not in xmppClients:
print "We didn't have a call for that room"
continue
# should probably check call ID too
xmppCli = xmppClients[ev['room_id']]
xmppCli.sendAnswer(sdp)
elif ev['type'] == 'm.call.hangup':
if ev['room_id'] in xmppClients:
xmppClients[ev['room_id']].stop()
del xmppClients[ev['room_id']]
while True:
ev = matrixCli.getEvent()
print ev
if ev['type'] == 'm.room.member':
print 'membership event'
if ev['membership'] == 'invite' and ev['state_key'] == MYUSERNAME:
roomId = ev['room_id']
print "joining room %s" % (roomId)
matrixCli.joinRoom(roomId)
elif ev['type'] == 'm.room.message':
if ev['room_id'] in xmppClients:
print "already have a bridge for that user, ignoring"
continue
print "got message, connecting"
xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id'])
gevent.spawn(xmppClients[ev['room_id']].xmppLoop)
elif ev['type'] == 'm.call.invite':
print "Incoming call"
#sdp = ev['content']['offer']['sdp']
#print "sdp: %s" % (sdp)
#xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id'])
#gevent.spawn(xmppClients[ev['room_id']].xmppLoop)
elif ev['type'] == 'm.call.answer':
print "Call answered"
sdp = ev['content']['answer']['sdp']
if ev['room_id'] not in xmppClients:
print "We didn't have a call for that room"
continue
# should probably check call ID too
xmppCli = xmppClients[ev['room_id']]
xmppCli.sendAnswer(sdp)
elif ev['type'] == 'm.call.hangup':
if ev['room_id'] in xmppClients:
xmppClients[ev['room_id']].stop()
del xmppClients[ev['room_id']]
class TrivialXmppClient:
def __init__(self, matrixRoom, userId):
self.rid = 0
self.matrixRoom = matrixRoom
self.userId = userId
self.running = True
def __init__(self, matrixRoom, userId):
self.rid = 0
self.matrixRoom = matrixRoom
self.userId = userId
self.running = True
def stop(self):
self.running = False
def stop(self):
self.running = False
def nextRid(self):
self.rid += 1
return '%d' % (self.rid)
def nextRid(self):
self.rid += 1
return '%d' % (self.rid)
def sendIq(self, xml):
fullXml = "<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' sid='%s'>%s</body>" % (self.nextRid(), self.sid, xml)
#print "\t>>>%s" % (fullXml)
return self.xmppPoke(fullXml)
def sendIq(self, xml):
fullXml = "<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' sid='%s'>%s</body>" % (self.nextRid(), self.sid, xml)
#print "\t>>>%s" % (fullXml)
return self.xmppPoke(fullXml)
def xmppPoke(self, xml):
headers = {'Content-Type': 'application/xml'}
req = grequests.post(HTTPBIND, verify=False, headers=headers, data=xml)
resps = grequests.map([req])
obj = BeautifulSoup(resps[0].content)
return obj
def xmppPoke(self, xml):
headers = {'Content-Type': 'application/xml'}
req = grequests.post(HTTPBIND, verify=False, headers=headers, data=xml)
resps = grequests.map([req])
obj = BeautifulSoup(resps[0].content)
return obj
def sendAnswer(self, answer):
print "sdp from matrix client",answer
p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--sdp'], stdin=subprocess.PIPE, stdout=subprocess.PIPE)
jingle, out_err = p.communicate(answer)
jingle = jingle % {
'tojid': self.callfrom,
'action': 'session-accept',
'initiator': self.callfrom,
'responder': self.jid,
'sid': self.callsid
}
print "answer jingle from sdp",jingle
res = self.sendIq(jingle)
print "reply from answer: ",res
self.ssrcs = {}
jingleSoup = BeautifulSoup(jingle)
for cont in jingleSoup.iq.jingle.findAll('content'):
if cont.description:
self.ssrcs[cont['name']] = cont.description['ssrc']
print "my ssrcs:",self.ssrcs
def sendAnswer(self, answer):
print "sdp from matrix client",answer
p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--sdp'], stdin=subprocess.PIPE, stdout=subprocess.PIPE)
jingle, out_err = p.communicate(answer)
jingle = jingle % {
'tojid': self.callfrom,
'action': 'session-accept',
'initiator': self.callfrom,
'responder': self.jid,
'sid': self.callsid
}
print "answer jingle from sdp",jingle
res = self.sendIq(jingle)
print "reply from answer: ",res
self.ssrcs = {}
jingleSoup = BeautifulSoup(jingle)
for cont in jingleSoup.iq.jingle.findAll('content'):
if cont.description:
self.ssrcs[cont['name']] = cont.description['ssrc']
print "my ssrcs:",self.ssrcs
gevent.joinall([
gevent.spawn(self.advertiseSsrcs)
])
def advertiseSsrcs(self):
gevent.joinall([
gevent.spawn(self.advertiseSsrcs)
])
def advertiseSsrcs(self):
time.sleep(7)
print "SSRC spammer started"
while self.running:
ssrcMsg = "<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>" % { 'tojid': "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), 'nick': self.userId, 'assrc': self.ssrcs['audio'], 'vssrc': self.ssrcs['video'] }
res = self.sendIq(ssrcMsg)
print "reply from ssrc announce: ",res
time.sleep(10)
print "SSRC spammer started"
while self.running:
ssrcMsg = "<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>" % { 'tojid': "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), 'nick': self.userId, 'assrc': self.ssrcs['audio'], 'vssrc': self.ssrcs['video'] }
res = self.sendIq(ssrcMsg)
print "reply from ssrc announce: ",res
time.sleep(10)
def xmppLoop(self):
self.matrixCallId = time.time()
res = self.xmppPoke("<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' to='%s' xml:lang='en' wait='60' hold='1' content='text/xml; charset=utf-8' ver='1.6' xmpp:version='1.0' xmlns:xmpp='urn:xmpp:xbosh'/>" % (self.nextRid(), HOST))
print res
self.sid = res.body['sid']
print "sid %s" % (self.sid)
def xmppLoop(self):
self.matrixCallId = time.time()
res = self.xmppPoke("<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' to='%s' xml:lang='en' wait='60' hold='1' content='text/xml; charset=utf-8' ver='1.6' xmpp:version='1.0' xmlns:xmpp='urn:xmpp:xbosh'/>" % (self.nextRid(), HOST))
res = self.sendIq("<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='ANONYMOUS'/>")
print res
self.sid = res.body['sid']
print "sid %s" % (self.sid)
res = self.xmppPoke("<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' sid='%s' to='%s' xml:lang='en' xmpp:restart='true' xmlns:xmpp='urn:xmpp:xbosh'/>" % (self.nextRid(), self.sid, HOST))
res = self.sendIq("<iq type='set' id='_bind_auth_2' xmlns='jabber:client'><bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/></iq>")
print res
res = self.sendIq("<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='ANONYMOUS'/>")
self.jid = res.body.iq.bind.jid.string
print "jid: %s" % (self.jid)
self.shortJid = self.jid.split('-')[0]
res = self.xmppPoke("<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' sid='%s' to='%s' xml:lang='en' xmpp:restart='true' xmlns:xmpp='urn:xmpp:xbosh'/>" % (self.nextRid(), self.sid, HOST))
res = self.sendIq("<iq type='set' id='_session_auth_2' xmlns='jabber:client'><session xmlns='urn:ietf:params:xml:ns:xmpp-session'/></iq>")
res = self.sendIq("<iq type='set' id='_bind_auth_2' xmlns='jabber:client'><bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/></iq>")
print res
#randomthing = res.body.iq['to']
#whatsitpart = randomthing.split('-')[0]
self.jid = res.body.iq.bind.jid.string
print "jid: %s" % (self.jid)
self.shortJid = self.jid.split('-')[0]
#print "other random bind thing: %s" % (randomthing)
res = self.sendIq("<iq type='set' id='_session_auth_2' xmlns='jabber:client'><session xmlns='urn:ietf:params:xml:ns:xmpp-session'/></iq>")
# advertise preence to the jitsi room, with our nick
res = self.sendIq("<iq type='get' to='%s' xmlns='jabber:client' id='1:sendIQ'><services xmlns='urn:xmpp:extdisco:1'><service host='%s'/></services></iq><presence to='%s@%s/d98f6c40' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%s</nick></presence>" % (HOST, TURNSERVER, ROOMNAME, ROOMDOMAIN, self.userId))
self.muc = {'users': []}
for p in res.body.findAll('presence'):
u = {}
u['shortJid'] = p['from'].split('/')[1]
if p.c and p.c.nick:
u['nick'] = p.c.nick.string
self.muc['users'].append(u)
print "muc: ",self.muc
#randomthing = res.body.iq['to']
#whatsitpart = randomthing.split('-')[0]
#print "other random bind thing: %s" % (randomthing)
# advertise preence to the jitsi room, with our nick
res = self.sendIq("<iq type='get' to='%s' xmlns='jabber:client' id='1:sendIQ'><services xmlns='urn:xmpp:extdisco:1'><service host='%s'/></services></iq><presence to='%s@%s/d98f6c40' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%s</nick></presence>" % (HOST, TURNSERVER, ROOMNAME, ROOMDOMAIN, self.userId))
self.muc = {'users': []}
for p in res.body.findAll('presence'):
u = {}
u['shortJid'] = p['from'].split('/')[1]
if p.c and p.c.nick:
u['nick'] = p.c.nick.string
self.muc['users'].append(u)
print "muc: ",self.muc
# wait for stuff
while True:
print "waiting..."
res = self.sendIq("")
print "got from stream: ",res
if res.body.iq:
jingles = res.body.iq.findAll('jingle')
if len(jingles):
self.callfrom = res.body.iq['from']
self.handleInvite(jingles[0])
elif 'type' in res.body and res.body['type'] == 'terminate':
self.running = False
del xmppClients[self.matrixRoom]
return
def handleInvite(self, jingle):
self.initiator = jingle['initiator']
self.callsid = jingle['sid']
p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--jingle'], stdin=subprocess.PIPE, stdout=subprocess.PIPE)
print "raw jingle invite",str(jingle)
sdp, out_err = p.communicate(str(jingle))
print "transformed remote offer sdp",sdp
inviteEvent = {
'offer': {
'type': 'offer',
'sdp': sdp
},
'call_id': self.matrixCallId,
'version': 0,
'lifetime': 30000
}
matrixCli.sendEvent(self.matrixRoom, 'm.call.invite', inviteEvent)
# wait for stuff
while True:
print "waiting..."
res = self.sendIq("")
print "got from stream: ",res
if res.body.iq:
jingles = res.body.iq.findAll('jingle')
if len(jingles):
self.callfrom = res.body.iq['from']
self.handleInvite(jingles[0])
elif 'type' in res.body and res.body['type'] == 'terminate':
self.running = False
del xmppClients[self.matrixRoom]
return
def handleInvite(self, jingle):
self.initiator = jingle['initiator']
self.callsid = jingle['sid']
p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--jingle'], stdin=subprocess.PIPE, stdout=subprocess.PIPE)
print "raw jingle invite",str(jingle)
sdp, out_err = p.communicate(str(jingle))
print "transformed remote offer sdp",sdp
inviteEvent = {
'offer': {
'type': 'offer',
'sdp': sdp
},
'call_id': self.matrixCallId,
'version': 0,
'lifetime': 30000
}
matrixCli.sendEvent(self.matrixRoom, 'm.call.invite', inviteEvent)
matrixCli = TrivialMatrixClient(ACCESS_TOKEN)
gevent.joinall([
gevent.spawn(matrixLoop)
gevent.spawn(matrixLoop)
])

View File

@@ -32,8 +32,7 @@ for port in 8080 8081 8082; do
-D --pid-file "$DIR/$port.pid" \
--manhole $((port + 1000)) \
--tls-dh-params-path "demo/demo.tls.dh" \
--media-store-path "demo/media_store.$port" \
$PARAMS $SYNAPSE_PARAMS \
$PARAMS $SYNAPSE_PARAMS
python -m synapse.app.homeserver \
--config-path "demo/etc/$port.config" \

View File

@@ -1,65 +0,0 @@
from synapse.events import FrozenEvent
from synapse.api.auth import Auth
from mock import Mock
import argparse
import itertools
import json
import sys
def check_auth(auth, auth_chain, events):
auth_chain.sort(key=lambda e: e.depth)
auth_map = {
e.event_id: e
for e in auth_chain
}
create_events = {}
for e in auth_chain:
if e.type == "m.room.create":
create_events[e.room_id] = e
for e in itertools.chain(auth_chain, events):
auth_events_list = [auth_map[i] for i, _ in e.auth_events]
auth_events = {
(e.type, e.state_key): e
for e in auth_events_list
}
auth_events[("m.room.create", "")] = create_events[e.room_id]
try:
auth.check(e, auth_events=auth_events)
except Exception as ex:
print "Failed:", e.event_id, e.type, e.state_key
print "Auth_events:", auth_events
print ex
print json.dumps(e.get_dict(), sort_keys=True, indent=4)
# raise
print "Success:", e.event_id, e.type, e.state_key
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'json',
nargs='?',
type=argparse.FileType('r'),
default=sys.stdin,
)
args = parser.parse_args()
js = json.load(args.json)
auth = Auth(Mock())
check_auth(
auth,
[FrozenEvent(d) for d in js["auth_chain"]],
[FrozenEvent(d) for d in js["pdus"]],
)

View File

@@ -97,11 +97,8 @@ def lookup(destination, path):
if ":" in destination:
return "https://%s%s" % (destination, path)
else:
try:
srv = srvlookup.lookup("matrix", "tcp", destination)[0]
return "https://%s:%d%s" % (srv.host, srv.port, path)
except:
return "https://%s:%d%s" % (destination, 8448, path)
srv = srvlookup.lookup("matrix", "tcp", destination)[0]
return "https://%s:%d%s" % (srv.host, srv.port, path)
def get_json(origin_name, origin_key, destination, path):
request_json = {

View File

@@ -1,39 +0,0 @@
#!/usr/bin/env perl
use strict;
use warnings;
use DBI;
use DBD::SQLite;
use JSON;
use Getopt::Long;
my $db; # = "homeserver.db";
my $server = "http://localhost:8008";
my $size = 320;
GetOptions("db|d=s", \$db,
"server|s=s", \$server,
"width|w=i", \$size) or usage();
usage() unless $db;
my $dbh = DBI->connect("dbi:SQLite:dbname=$db","","") || die $DBI::errstr;
my $res = $dbh->selectall_arrayref("select token, name from access_tokens, users where access_tokens.user_id = users.id group by user_id") || die $DBI::errstr;
foreach (@$res) {
my ($token, $mxid) = ($_->[0], $_->[1]);
my ($user_id) = ($mxid =~ m/@(.*):/);
my ($url) = $dbh->selectrow_array("select avatar_url from profiles where user_id=?", undef, $user_id);
if (!$url || $url =~ /#auto$/) {
`curl -s -o tmp.png "$server/_matrix/media/v1/identicon?name=${mxid}&width=$size&height=$size"`;
my $json = `curl -s -X POST -H "Content-Type: image/png" -T "tmp.png" $server/_matrix/media/v1/upload?access_token=$token`;
my $content_uri = from_json($json)->{content_uri};
`curl -X PUT -H "Content-Type: application/json" --data '{ "avatar_url": "${content_uri}#auto"}' $server/_matrix/client/api/v1/profile/${mxid}/avatar_url?access_token=$token`;
}
}
sub usage {
die "usage: ./make-identicons.pl\n\t-d database [e.g. homeserver.db]\n\t-s homeserver (default: http://localhost:8008)\n\t-w identicon size in pixels (default 320)";
}

View File

@@ -8,11 +8,3 @@ test = trial
[trial]
test_suite = tests
[check-manifest]
ignore =
contrib
contrib/*
docs/*
pylint.cfg
tox.ini

View File

@@ -18,42 +18,49 @@ import os
from setuptools import setup, find_packages
here = os.path.abspath(os.path.dirname(__file__))
def read_file(path_segments):
"""Read a file from the package. Takes a list of strings to join to
make the path"""
file_path = os.path.join(here, *path_segments)
with open(file_path) as f:
return f.read()
def exec_file(path_segments):
"""Execute a single python file to get the variables defined in it"""
result = {}
code = read_file(path_segments)
exec(code, result)
return result
version = exec_file(("synapse", "__init__.py"))["__version__"]
dependencies = exec_file(("synapse", "python_dependencies.py"))
long_description = read_file(("README.rst",))
# Utility function to read the README file.
# Used for the long_description. It's nice, because now 1) we have a top level
# README file and 2) it's easier to type in the README file than to put a raw
# string in below ...
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
setup(
name="matrix-synapse",
version=version,
version=read("VERSION").strip(),
packages=find_packages(exclude=["tests", "tests.*"]),
description="Reference Synapse Home Server",
install_requires=dependencies["REQUIREMENTS"].keys(),
install_requires=[
"syutil==0.0.2",
"matrix_angular_sdk==0.6.0",
"Twisted>=14.0.0",
"service_identity>=1.0.0",
"pyopenssl>=0.14",
"pyyaml",
"pyasn1",
"pynacl",
"daemonize",
"py-bcrypt",
"frozendict>=0.4",
"pillow",
],
dependency_links=[
"https://github.com/matrix-org/syutil/tarball/v0.0.2#egg=syutil-0.0.2",
"https://github.com/pyca/pynacl/tarball/d4d3175589b892f6ea7c22f466e0e223853516fa#egg=pynacl-0.3.0",
"https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.6.0/#egg=matrix_angular_sdk-0.6.0",
],
setup_requires=[
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
"setuptools_trial",
"setuptools>=1.0.0", # Needs setuptools that supports git+ssh.
# TODO: Do we need this now? we don't use git+ssh.
"mock"
],
dependency_links=dependencies["DEPENDENCY_LINKS"],
include_package_data=True,
zip_safe=False,
long_description=long_description,
scripts=["synctl"],
long_description=read("README.rst"),
entry_points="""
[console_scripts]
synctl=synapse.app.synctl:main
synapse-homeserver=synapse.app.homeserver:main
"""
)

View File

@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.7.0d"
__version__ = "0.6.1b"

View File

@@ -21,7 +21,6 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.types import UserID, ClientInfo
import logging
@@ -89,19 +88,12 @@ class Auth(object):
raise
@defer.inlineCallbacks
def check_joined_room(self, room_id, user_id, current_state=None):
if current_state:
member = current_state.get(
(EventTypes.Member, user_id),
None
)
else:
member = yield self.state.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
def check_joined_room(self, room_id, user_id):
member = yield self.state.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
self._check_joined_room(member, user_id, room_id)
defer.returnValue(member)
@@ -109,10 +101,10 @@ class Auth(object):
def check_host_in_room(self, room_id, host):
curr_state = yield self.state.get_current_state(room_id)
for event in curr_state.values():
for event in curr_state:
if event.type == EventTypes.Member:
try:
if UserID.from_string(event.state_key).domain != host:
if self.hs.parse_userid(event.state_key).domain != host:
continue
except:
logger.warn("state_key not user_id: %s", event.state_key)
@@ -297,9 +289,7 @@ class Auth(object):
Args:
request - An HTTP request with an access_token query parameter.
Returns:
tuple : of UserID and device string:
User ID object of the user making the request
Client ID object of the client instance the user is using
UserID : User ID object of the user making the request
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
@@ -308,8 +298,6 @@ class Auth(object):
access_token = request.args["access_token"][0]
user_info = yield self.get_user_by_token(access_token)
user = user_info["user"]
device_id = user_info["device_id"]
token_id = user_info["token_id"]
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
@@ -325,7 +313,7 @@ class Auth(object):
user_agent=user_agent
)
defer.returnValue((user, ClientInfo(device_id, token_id)))
defer.returnValue(user)
except KeyError:
raise AuthError(403, "Missing access token.")
@@ -349,8 +337,7 @@ class Auth(object):
user_info = {
"admin": bool(ret.get("admin", False)),
"device_id": ret.get("device_id"),
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
"user": self.hs.parse_userid(ret.get("name")),
}
defer.returnValue(user_info)
@@ -365,40 +352,26 @@ class Auth(object):
def add_auth_events(self, builder, context):
yield run_on_reactor()
auth_ids = self.compute_auth_events(builder, context.current_state)
auth_events_entries = yield self.store.add_event_hashes(
auth_ids
)
builder.auth_events = auth_events_entries
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
def compute_auth_events(self, event, current_state):
if event.type == EventTypes.Create:
return []
if builder.type == EventTypes.Create:
builder.auth_events = []
return
auth_ids = []
key = (EventTypes.PowerLevels, "", )
power_level_event = current_state.get(key)
power_level_event = context.current_state.get(key)
if power_level_event:
auth_ids.append(power_level_event.event_id)
key = (EventTypes.JoinRules, "", )
join_rule_event = current_state.get(key)
join_rule_event = context.current_state.get(key)
key = (EventTypes.Member, event.user_id, )
member_event = current_state.get(key)
key = (EventTypes.Member, builder.user_id, )
member_event = context.current_state.get(key)
key = (EventTypes.Create, "", )
create_event = current_state.get(key)
create_event = context.current_state.get(key)
if create_event:
auth_ids.append(create_event.event_id)
@@ -408,8 +381,8 @@ class Auth(object):
else:
is_public = False
if event.type == EventTypes.Member:
e_type = event.content["membership"]
if builder.type == EventTypes.Member:
e_type = builder.content["membership"]
if e_type in [Membership.JOIN, Membership.INVITE]:
if join_rule_event:
auth_ids.append(join_rule_event.event_id)
@@ -424,7 +397,17 @@ class Auth(object):
if member_event.content["membership"] == Membership.JOIN:
auth_ids.append(member_event.event_id)
return auth_ids
auth_events_entries = yield self.store.add_event_hashes(
auth_ids
)
builder.auth_events = auth_events_entries
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
@log_function
def _can_send_event(self, event, auth_events):
@@ -478,7 +461,7 @@ class Auth(object):
"You are not allowed to set others state"
)
else:
sender_domain = UserID.from_string(
sender_domain = self.hs.parse_userid(
event.user_id
).domain
@@ -513,7 +496,7 @@ class Auth(object):
# Validate users
for k, v in user_list.items():
try:
UserID.from_string(k)
self.hs.parse_userid(k)
except:
raise SynapseError(400, "Not a valid user_id: %s" % (k,))

View File

@@ -74,9 +74,3 @@ class EventTypes(object):
Message = "m.room.message"
Topic = "m.room.topic"
Name = "m.room.name"
class RejectedReason(object):
AUTH_ERROR = "auth_error"
REPLACED = "replaced"
NOT_ANCESTOR = "not_ancestor"

View File

@@ -21,7 +21,6 @@ logger = logging.getLogger(__name__)
class Codes(object):
UNRECOGNIZED = "M_UNRECOGNIZED"
UNAUTHORIZED = "M_UNAUTHORIZED"
FORBIDDEN = "M_FORBIDDEN"
BAD_JSON = "M_BAD_JSON"
@@ -35,11 +34,10 @@ class Codes(object):
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
MISSING_PARAM = "M_MISSING_PARAM",
TOO_LARGE = "M_TOO_LARGE"
class CodeMessageException(RuntimeError):
class CodeMessageException(Exception):
"""An exception with integer code and message string attributes."""
def __init__(self, code, msg):
@@ -83,35 +81,6 @@ class RegistrationError(SynapseError):
pass
class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.UNRECOGNIZED
message = None
if len(args) == 0:
message = "Unrecognized request"
else:
message = args[0]
super(UnrecognizedRequestError, self).__init__(
400,
message,
**kwargs
)
class NotFoundError(SynapseError):
"""An error indicating we can't find the thing you asked for"""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.NOT_FOUND
super(NotFoundError, self).__init__(
404,
"Not found",
**kwargs
)
class AuthError(SynapseError):
"""An error raised when there was a problem authorising an event."""
@@ -227,9 +196,3 @@ class FederationError(RuntimeError):
"affected": self.affected,
"source": self.source if self.source else self.affected,
}
class HttpResponseException(CodeMessageException):
def __init__(self, code, msg, response):
self.response = response
super(HttpResponseException, self).__init__(code, msg)

View File

@@ -1,229 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID
class Filtering(object):
def __init__(self, hs):
super(Filtering, self).__init__()
self.store = hs.get_datastore()
def get_user_filter(self, user_localpart, filter_id):
result = self.store.get_user_filter(user_localpart, filter_id)
result.addCallback(Filter)
return result
def add_user_filter(self, user_localpart, user_filter):
self._check_valid_filter(user_filter)
return self.store.add_user_filter(user_localpart, user_filter)
# TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for
# them however
def _check_valid_filter(self, user_filter_json):
"""Check if the provided filter is valid.
This inspects all definitions contained within the filter.
Args:
user_filter_json(dict): The filter
Raises:
SynapseError: If the filter is not valid.
"""
# NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
top_level_definitions = [
"public_user_data", "private_user_data", "server_data"
]
room_level_definitions = [
"state", "events", "ephemeral"
]
for key in top_level_definitions:
if key in user_filter_json:
self._check_definition(user_filter_json[key])
if "room" in user_filter_json:
for key in room_level_definitions:
if key in user_filter_json["room"]:
self._check_definition(user_filter_json["room"][key])
def _check_definition(self, definition):
"""Check if the provided definition is valid.
This inspects not only the types but also the values to make sure they
make sense.
Args:
definition(dict): The filter definition
Raises:
SynapseError: If there was a problem with this definition.
"""
# NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
if type(definition) != dict:
raise SynapseError(
400, "Expected JSON object, not %s" % (definition,)
)
# check rooms are valid room IDs
room_id_keys = ["rooms", "not_rooms"]
for key in room_id_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for room_id in definition[key]:
RoomID.from_string(room_id)
# check senders are valid user IDs
user_id_keys = ["senders", "not_senders"]
for key in user_id_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for user_id in definition[key]:
UserID.from_string(user_id)
# TODO: We don't limit event type values but we probably should...
# check types are valid event types
event_keys = ["types", "not_types"]
for key in event_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for event_type in definition[key]:
if not isinstance(event_type, basestring):
raise SynapseError(400, "Event type should be a string")
if "format" in definition:
event_format = definition["format"]
if event_format not in ["federation", "events"]:
raise SynapseError(400, "Invalid format: %s" % (event_format,))
if "select" in definition:
event_select_list = definition["select"]
for select_key in event_select_list:
if select_key not in ["event_id", "origin_server_ts",
"thread_id", "content", "content.body"]:
raise SynapseError(400, "Bad select: %s" % (select_key,))
if ("bundle_updates" in definition and
type(definition["bundle_updates"]) != bool):
raise SynapseError(400, "Bad bundle_updates: expected bool.")
class Filter(object):
def __init__(self, filter_json):
self.filter_json = filter_json
def filter_public_user_data(self, events):
return self._filter_on_key(events, ["public_user_data"])
def filter_private_user_data(self, events):
return self._filter_on_key(events, ["private_user_data"])
def filter_room_state(self, events):
return self._filter_on_key(events, ["room", "state"])
def filter_room_events(self, events):
return self._filter_on_key(events, ["room", "events"])
def filter_room_ephemeral(self, events):
return self._filter_on_key(events, ["room", "ephemeral"])
def _filter_on_key(self, events, keys):
filter_json = self.filter_json
if not filter_json:
return events
try:
# extract the right definition from the filter
definition = filter_json
for key in keys:
definition = definition[key]
return self._filter_with_definition(events, definition)
except KeyError:
# return all events if definition isn't specified.
return events
def _filter_with_definition(self, events, definition):
return [e for e in events if self._passes_definition(definition, e)]
def _passes_definition(self, definition, event):
"""Check if the event passes through the given definition.
Args:
definition(dict): The definition to check against.
event(Event): The event to check.
Returns:
True if the event passes through the filter.
"""
# Algorithm notes:
# For each key in the definition, check the event meets the criteria:
# * For types: Literal match or prefix match (if ends with wildcard)
# * For senders/rooms: Literal match only
# * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
# and 'not_types' then it is treated as only being in 'not_types')
# room checks
if hasattr(event, "room_id"):
room_id = event.room_id
allow_rooms = definition.get("rooms", None)
reject_rooms = definition.get("not_rooms", None)
if reject_rooms and room_id in reject_rooms:
return False
if allow_rooms and room_id not in allow_rooms:
return False
# sender checks
if hasattr(event, "sender"):
# Should we be including event.state_key for some event types?
sender = event.sender
allow_senders = definition.get("senders", None)
reject_senders = definition.get("not_senders", None)
if reject_senders and sender in reject_senders:
return False
if allow_senders and sender not in allow_senders:
return False
# type checks
if "not_types" in definition:
for def_type in definition["not_types"]:
if self._event_matches_type(event, def_type):
return False
if "types" in definition:
included = False
for def_type in definition["types"]:
if self._event_matches_type(event, def_type):
included = True
break
if not included:
return False
return True
def _event_matches_type(self, event, def_type):
if def_type.endswith("*"):
type_prefix = def_type[:-1]
return event.type.startswith(type_prefix)
else:
return event.type == def_type

View File

@@ -16,7 +16,6 @@
"""Contains the URL paths to prefix various aspects of the server with. """
CLIENT_PREFIX = "/_matrix/client/api/v1"
CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"
FEDERATION_PREFIX = "/_matrix/federation/v1"
WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content"

View File

@@ -26,19 +26,17 @@ from twisted.web.resource import Resource
from twisted.web.static import File
from twisted.web.server import Site
from synapse.http.server import JsonResource, RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.media.v0.content_repository import ContentRepoResource
from synapse.media.v1.media_repository import MediaRepositoryResource
from synapse.http.server_key_resource import LocalKey
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import (
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX
)
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext
from synapse.rest.client.v1 import ClientV1RestResource
from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource
from daemonize import Daemonize
import twisted.manhole.telnet
@@ -61,13 +59,10 @@ class SynapseHomeServer(HomeServer):
return MatrixFederationHttpClient(self)
def build_resource_for_client(self):
return ClientV1RestResource(self)
def build_resource_for_client_v2_alpha(self):
return ClientV2AlphaRestResource(self)
return JsonResource()
def build_resource_for_federation(self):
return JsonResource(self)
return JsonResource()
def build_resource_for_web_client(self):
syweb_path = os.path.dirname(syweb.__file__)
@@ -109,7 +104,6 @@ class SynapseHomeServer(HomeServer):
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
desired_tree = [
(CLIENT_PREFIX, self.get_resource_for_client()),
(CLIENT_V2_ALPHA_PREFIX, self.get_resource_for_client_v2_alpha()),
(FEDERATION_PREFIX, self.get_resource_for_federation()),
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
@@ -134,7 +128,7 @@ class SynapseHomeServer(HomeServer):
logger.info("Attaching %s to path %s", resource, full_path)
last_resource = self.root_resource
for path_seg in full_path.split('/')[1:-1]:
if path_seg not in last_resource.listNames():
if not path_seg in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
child_resource = Resource()
last_resource.putChild(path_seg, child_resource)
@@ -230,6 +224,8 @@ def setup():
content_addr=config.content_addr,
)
hs.register_servlets()
hs.create_resource_tree(
web_client=config.webclient,
redirect_root_to_web_client=True,
@@ -272,11 +268,6 @@ def setup():
bind_port = None
hs.start_listening(bind_port, config.unsecure_port)
hs.get_pusherpool().start()
hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling()
if config.daemonize:
print config.pid_file
daemon = Daemonize(

View File

@@ -27,16 +27,6 @@ class Config(object):
def __init__(self, args):
pass
@staticmethod
def parse_size(string):
sizes = {"K": 1024, "M": 1024 * 1024}
size = 1
suffix = string[-1]
if suffix in sizes:
string = string[:-1]
size = sizes[suffix]
return int(string) * size
@staticmethod
def abspath(file_path):
return os.path.abspath(file_path) if file_path else file_path
@@ -60,9 +50,8 @@ class Config(object):
)
return cls.abspath(file_path)
@classmethod
def ensure_directory(cls, dir_path):
dir_path = cls.abspath(dir_path)
@staticmethod
def ensure_directory(dir_path):
if not os.path.exists(dir_path):
os.makedirs(dir_path)
if not os.path.isdir(dir_path):

View File

@@ -24,7 +24,6 @@ class DatabaseConfig(Config):
self.database_path = ":memory:"
else:
self.database_path = self.abspath(args.database_path)
self.event_cache_size = self.parse_size(args.event_cache_size)
@classmethod
def add_arguments(cls, parser):
@@ -34,10 +33,6 @@ class DatabaseConfig(Config):
"-d", "--database-path", default="homeserver.db",
help="The database name."
)
db_group.add_argument(
"--event-cache-size", default="100K",
help="Number of events to cache in memory."
)
@classmethod
def generate_config(cls, args, config_dir_path):

View File

@@ -18,7 +18,6 @@ from synapse.util.logcontext import LoggingContextFilter
from twisted.python.log import PythonLoggingObserver
import logging
import logging.config
import yaml
class LoggingConfig(Config):
@@ -80,8 +79,7 @@ class LoggingConfig(Config):
logger.addHandler(handler)
logger.info("Test")
else:
with open(self.log_config, 'r') as f:
logging.config.dictConfig(yaml.load(f))
logging.config.fileConfig(self.log_config)
observer = PythonLoggingObserver()
observer.start()

View File

@@ -19,7 +19,7 @@ from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.logcontext import PreserveLoggingContext
import simplejson as json
import json
import logging
@@ -61,11 +61,9 @@ class SynapseKeyClientProtocol(HTTPClient):
def __init__(self):
self.remote_key = defer.Deferred()
self.host = None
def connectionMade(self):
self.host = self.transport.getHost()
logger.debug("Connected to %s", self.host)
logger.debug("Connected to %s", self.transport.getHost())
self.sendCommand(b"GET", b"/_matrix/key/v1/")
self.endHeaders()
self.timer = reactor.callLater(
@@ -75,7 +73,7 @@ class SynapseKeyClientProtocol(HTTPClient):
def handleStatus(self, version, status, message):
if status != b"200":
# logger.info("Non-200 response from %s: %s %s",
#logger.info("Non-200 response from %s: %s %s",
# self.transport.getHost(), status, message)
self.transport.abortConnection()
@@ -83,7 +81,7 @@ class SynapseKeyClientProtocol(HTTPClient):
try:
json_response = json.loads(response_body_bytes)
except ValueError:
# logger.info("Invalid JSON response from %s",
#logger.info("Invalid JSON response from %s",
# self.transport.getHost())
self.transport.abortConnection()
return
@@ -94,7 +92,8 @@ class SynapseKeyClientProtocol(HTTPClient):
self.timer.cancel()
def on_timeout(self):
logger.debug("Timeout waiting for response from %s", self.host)
logger.debug("Timeout waiting for response from %s",
self.transport.getHost())
self.remote_key.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection()

View File

@@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.frozenutils import freeze
from synapse.util.frozenutils import freeze, unfreeze
class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict):
self.__dict__ = dict(internal_metadata_dict)
self.__dict__ = internal_metadata_dict
def get_dict(self):
return dict(self.__dict__)
@@ -77,7 +77,7 @@ class EventBase(object):
return self.content["membership"]
def is_state(self):
return hasattr(self, "state_key") and self.state_key is not None
return hasattr(self, "state_key")
def get_dict(self):
d = dict(self._event_dict)
@@ -140,6 +140,10 @@ class FrozenEvent(EventBase):
return e
def get_dict(self):
# We need to unfreeze what we return
return unfreeze(super(FrozenEvent, self).get_dict())
def __str__(self):
return self.__repr__()

View File

@@ -23,15 +23,14 @@ import copy
class EventBuilder(EventBase):
def __init__(self, key_values={}, internal_metadata_dict={}):
def __init__(self, key_values={}):
signatures = copy.deepcopy(key_values.pop("signatures", {}))
unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
super(EventBuilder, self).__init__(
key_values,
signatures=signatures,
unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
unsigned=unsigned
)
def build(self):

View File

@@ -20,4 +20,3 @@ class EventContext(object):
self.current_state = current_state
self.auth_events = auth_events
self.state_group = None
self.rejected = False

View File

@@ -45,14 +45,12 @@ def prune_event(event):
"membership",
]
event_dict = event.get_dict()
new_content = {}
def add_fields(*fields):
for field in fields:
if field in event.content:
new_content[field] = event_dict["content"][field]
new_content[field] = event.content[field]
if event_type == EventTypes.Member:
add_fields("membership")
@@ -77,7 +75,7 @@ def prune_event(event):
allowed_fields = {
k: v
for k, v in event_dict.items()
for k, v in event.get_dict().items()
if k in allowed_keys
}
@@ -88,78 +86,56 @@ def prune_event(event):
if "age_ts" in event.unsigned:
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
return type(event)(
allowed_fields,
internal_metadata_dict=event.internal_metadata.get_dict()
)
return type(event)(allowed_fields)
def format_event_raw(d):
return d
def format_event_for_client_v1(d):
d["user_id"] = d.pop("sender", None)
move_keys = ("age", "redacted_because", "replaces_state", "prev_content")
for key in move_keys:
if key in d["unsigned"]:
d[key] = d["unsigned"][key]
drop_keys = (
"auth_events", "prev_events", "hashes", "signatures", "depth",
"unsigned", "origin", "prev_state"
)
for key in drop_keys:
d.pop(key, None)
return d
def format_event_for_client_v2(d):
drop_keys = (
"auth_events", "prev_events", "hashes", "signatures", "depth",
"origin", "prev_state",
)
for key in drop_keys:
d.pop(key, None)
return d
def format_event_for_client_v2_without_event_id(d):
d = format_event_for_client_v2(d)
d.pop("room_id", None)
d.pop("event_id", None)
return d
def serialize_event(e, time_now_ms, as_client_event=True,
event_format=format_event_for_client_v1,
token_id=None):
def serialize_event(hs, e, client_event=True):
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase):
return e
time_now_ms = int(time_now_ms)
# Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()}
if not client_event:
# set the age and keep all other keys
if "age_ts" in d["unsigned"]:
now = int(hs.get_clock().time_msec())
d["unsigned"]["age"] = now - d["unsigned"]["age_ts"]
return d
if "age_ts" in d["unsigned"]:
d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"]
now = int(hs.get_clock().time_msec())
d["age"] = now - d["unsigned"]["age_ts"]
del d["unsigned"]["age_ts"]
d["user_id"] = d.pop("sender", None)
if "redacted_because" in e.unsigned:
d["unsigned"]["redacted_because"] = serialize_event(
e.unsigned["redacted_because"], time_now_ms
d["redacted_because"] = serialize_event(
hs, e.unsigned["redacted_because"]
)
if token_id is not None:
if token_id == getattr(e.internal_metadata, "token_id", None):
txn_id = getattr(e.internal_metadata, "txn_id", None)
if txn_id is not None:
d["unsigned"]["transaction_id"] = txn_id
del d["unsigned"]["redacted_because"]
if as_client_event:
return event_format(d)
else:
return d
if "redacted_by" in e.unsigned:
d["redacted_by"] = e.unsigned["redacted_by"]
del d["unsigned"]["redacted_by"]
if "replaces_state" in e.unsigned:
d["replaces_state"] = e.unsigned["replaces_state"]
del d["unsigned"]["replaces_state"]
if "prev_content" in e.unsigned:
d["prev_content"] = e.unsigned["prev_content"]
del d["unsigned"]["prev_content"]
del d["auth_events"]
del d["prev_events"]
del d["hashes"]
del d["signatures"]
d.pop("depth", None)
d.pop("unsigned", None)
d.pop("origin", None)
return d

View File

@@ -1,133 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
import logging
logger = logging.getLogger(__name__)
class FederationBase(object):
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
that PDU.
If a PDU fails its content hash check then it is redacted.
The given list of PDUs are not modified, instead the function returns
a new list.
Args:
pdu (list)
outlier (bool)
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
"""
signed_pdus = []
@defer.inlineCallbacks
def do(pdu):
try:
new_pdu = yield self._check_sigs_and_hash(pdu)
signed_pdus.append(new_pdu)
except SynapseError:
# FIXME: We should handle signature failures more gracefully.
# Check local db.
new_pdu = yield self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
)
if new_pdu:
signed_pdus.append(new_pdu)
return
# Check pdu.origin
if pdu.origin != origin:
try:
new_pdu = yield self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
outlier=outlier,
)
if new_pdu:
signed_pdus.append(new_pdu)
return
except:
pass
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
)
yield defer.gatherResults(
[do(pdu) for pdu in pdus],
consumeErrors=True
)
defer.returnValue(signed_pdus)
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
"""Throws a SynapseError if the PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
try:
yield self.keyring.verify_json_for_server(
pdu.origin, redacted_pdu_json
)
except SynapseError:
logger.warn(
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
raise
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s, %s",
pdu.event_id, encode_canonical_json(pdu.get_dict())
)
defer.returnValue(redacted_event)
defer.returnValue(pdu)

View File

@@ -1,400 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .federation_base import FederationBase
from .units import Edu
from synapse.api.errors import CodeMessageException
from synapse.util.logutils import log_function
from synapse.events import FrozenEvent
import logging
logger = logging.getLogger(__name__)
class FederationClient(FederationBase):
@log_function
def send_pdu(self, pdu, destinations):
"""Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others.
TODO: Figure out when we should actually resolve the deferred.
Args:
pdu (Pdu): The new Pdu.
Returns:
Deferred: Completes when we have successfully processed the PDU
and replicated it to any interested remote home servers.
"""
order = self._order
self._order += 1
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, destinations, order)
logger.debug(
"[%s] transaction_layer.enqueue_pdu... done",
pdu.event_id
)
@log_function
def send_edu(self, destination, edu_type, content):
edu = Edu(
origin=self.server_name,
destination=destination,
edu_type=edu_type,
content=content,
)
# TODO, add errback, etc.
self._transaction_queue.enqueue_edu(edu)
return defer.succeed(None)
@log_function
def send_failure(self, failure, destination):
self._transaction_queue.enqueue_failure(failure, destination)
return defer.succeed(None)
@log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=True):
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
Args:
destination (str): Domain name of the remote homeserver
query_type (str): Category of the query type; should match the
handler name used in register_query_handler().
args (dict): Mapping of strings to strings containing the details
of the query request.
Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
)
@defer.inlineCallbacks
@log_function
def backfill(self, dest, context, limit, extremities):
"""Requests some more historic PDUs for the given context from the
given destination server.
Args:
dest (str): The remote home server to ask.
context (str): The context to backfill.
limit (int): The maximum number of PDUs to return.
extremities (list): List of PDU id and origins of the first pdus
we have seen from the context
Returns:
Deferred: Results in the received PDUs.
"""
logger.debug("backfill extrem=%s", extremities)
# If there are no extremeties then we've (probably) reached the start.
if not extremities:
return
transaction_data = yield self.transport_layer.backfill(
dest, context, extremities, limit)
logger.debug("backfill transaction_data=%s", repr(transaction_data))
pdus = [
self.event_from_pdu_json(p, outlier=False)
for p in transaction_data["pdus"]
]
for i, pdu in enumerate(pdus):
pdus[i] = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully.
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def get_pdu(self, destinations, event_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home
servers.
Will attempt to get the PDU from each destination in the list until
one succeeds.
This will persist the PDU locally upon receipt.
Args:
destinations (list): Which home servers to query
pdu_origin (str): The home server that originally sent the pdu.
event_id (str)
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
Returns:
Deferred: Results in the requested PDU.
"""
# TODO: Rate limit the number of times we try and get the same event.
pdu = None
for destination in destinations:
try:
transaction_data = yield self.transport_layer.get_event(
destination, event_id
)
logger.debug("transaction_data %r", transaction_data)
pdu_list = [
self.event_from_pdu_json(p, outlier=outlier)
for p in transaction_data["pdus"]
]
if pdu_list:
pdu = pdu_list[0]
# Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu)
break
except CodeMessageException:
raise
except Exception as e:
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
continue
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
def get_state_for_room(self, destination, room_id, event_id):
"""Requests all of the `current` state PDUs for a given room from
a remote home server.
Args:
destination (str): The remote homeserver to query for the state.
room_id (str): The id of the room we're interested in.
event_id (str): The id of the event we want the state at.
Returns:
Deferred: Results in a list of PDUs.
"""
result = yield self.transport_layer.get_room_state(
destination, room_id, event_id=event_id,
)
pdus = [
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
]
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in result.get("auth_chain", [])
]
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, pdus, outlier=True
)
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
signed_auth.sort(key=lambda e: e.depth)
defer.returnValue((signed_pdus, signed_auth))
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
res = yield self.transport_layer.get_event_auth(
destination, room_id, event_id,
)
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in res["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
signed_auth.sort(key=lambda e: e.depth)
defer.returnValue(signed_auth)
@defer.inlineCallbacks
def make_join(self, destinations, room_id, user_id):
for destination in destinations:
try:
ret = yield self.transport_layer.make_join(
destination, room_id, user_id
)
pdu_dict = ret["event"]
logger.debug("Got response to make_join: %s", pdu_dict)
defer.returnValue(
(destination, self.event_from_pdu_json(pdu_dict))
)
break
except CodeMessageException:
raise
except Exception as e:
logger.warn(
"Failed to make_join via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks
def send_join(self, destinations, pdu):
for destination in destinations:
try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
logger.debug("Got content: %s", content)
state = [
self.event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]
signed_state = yield self._check_sigs_and_hash_and_fetch(
destination, state, outlier=True
)
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
auth_chain.sort(key=lambda e: e.depth)
defer.returnValue({
"state": signed_state,
"auth_chain": signed_auth,
"origin": destination,
})
except CodeMessageException:
raise
except Exception as e:
logger.warn(
"Failed to send_join via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
time_now = self._clock.time_msec()
code, content = yield self.transport_layer.send_invite(
destination=destination,
room_id=room_id,
event_id=event_id,
content=pdu.get_pdu_json(time_now),
)
pdu_dict = content["event"]
logger.debug("Got response to send_invite: %s", pdu_dict)
pdu = self.event_from_pdu_json(pdu_dict)
# Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully.
defer.returnValue(pdu)
@defer.inlineCallbacks
def query_auth(self, destination, room_id, event_id, local_auth):
"""
Params:
destination (str)
event_it (str)
local_auth (list)
"""
time_now = self._clock.time_msec()
send_content = {
"auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
}
code, content = yield self.transport_layer.send_query_auth(
destination=destination,
room_id=room_id,
event_id=event_id,
content=send_content,
)
auth_chain = [
self.event_from_pdu_json(e)
for e in content["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
signed_auth.sort(key=lambda e: e.depth)
ret = {
"auth_chain": signed_auth,
"rejects": content.get("rejects", []),
"missing": content.get("missing", []),
}
defer.returnValue(ret)
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event

View File

@@ -1,441 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .federation_base import FederationBase
from .units import Transaction, Edu
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent
from synapse.api.errors import FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
import logging
logger = logging.getLogger(__name__)
class FederationServer(FederationBase):
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are
documented on :py:class:`.ReplicationHandler`.
"""
self.handler = handler
def register_edu_handler(self, edu_type, handler):
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation Query of the given type.
Args:
query_type (str): Category name of the query, which should match
the string used by make_query.
handler (callable): Invoked to handle incoming queries of this type
handler is invoked as:
result = handler(args)
where 'args' is a dict mapping strings to strings of the query
arguments. It should return a Deferred that will eventually yield an
object to encode as JSON.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, versions, limit):
pdus = yield self.handler.on_backfill_request(
origin, room_id, versions, limit
)
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
def on_incoming_transaction(self, transaction_data):
transaction = Transaction(**transaction_data)
for p in transaction.pdus:
if "unsigned" in p:
unsigned = p["unsigned"]
if "age" in unsigned:
p["age"] = unsigned["age"]
if "age" in p:
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
del p["age"]
pdu_list = [
self.event_from_pdu_json(p) for p in transaction.pdus
]
logger.debug("[%s] Got transaction", transaction.transaction_id)
response = yield self.transaction_actions.have_responded(transaction)
if response:
logger.debug(
"[%s] We've already responed to this request",
transaction.transaction_id
)
defer.returnValue(response)
return
logger.debug("[%s] Transaction is new", transaction.transaction_id)
with PreserveLoggingContext():
dl = []
for pdu in pdu_list:
dl.append(self._handle_new_pdu(transaction.origin, pdu))
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu(
transaction.origin,
edu.edu_type,
edu.content
)
results = yield defer.DeferredList(dl)
ret = []
for r in results:
if r[0]:
ret.append({})
else:
logger.exception(r[1])
ret.append({"error": str(r[1])})
logger.debug("Returning: %s", str(ret))
yield self.transaction_actions.set_response(
transaction,
200, response
)
defer.returnValue((200, response))
def received_edu(self, origin, edu_type, content):
if edu_type in self.edu_handlers:
self.edu_handlers[edu_type](origin, content)
else:
logger.warn("Received EDU of type %s with no handler", edu_type)
@defer.inlineCallbacks
@log_function
def on_context_state_request(self, origin, room_id, event_id):
if event_id:
pdus = yield self.handler.get_state_for_pdu(
origin, room_id, event_id,
)
auth_chain = yield self.store.get_auth_chain(
[pdu.event_id for pdu in pdus]
)
for event in auth_chain:
event.signatures.update(
compute_event_signature(
event,
self.hs.hostname,
self.hs.config.signing_key[0]
)
)
else:
raise NotImplementedError("Specify an event")
defer.returnValue((200, {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
}))
@defer.inlineCallbacks
@log_function
def on_pdu_request(self, origin, event_id):
pdu = yield self._get_persisted_pdu(origin, event_id)
if pdu:
defer.returnValue(
(200, self._transaction_from_pdus([pdu]).get_dict())
)
else:
defer.returnValue((404, ""))
@defer.inlineCallbacks
@log_function
def on_pull_request(self, origin, versions):
raise NotImplementedError("Pull transactions not implemented")
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
if query_type in self.query_handlers:
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
defer.returnValue(
(404, "No handler for Query type '%s'" % (query_type,))
)
@defer.inlineCallbacks
def on_make_join_request(self, room_id, user_id):
pdu = yield self.handler.on_make_join_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
pdu = self.event_from_pdu_json(content)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@defer.inlineCallbacks
def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content)
pdu = self.event_from_pdu_json(content)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {
"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
"auth_chain": [
p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
],
}))
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
time_now = self._clock.time_msec()
auth_pdus = yield self.handler.on_event_auth(event_id)
defer.returnValue((200, {
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
}))
@defer.inlineCallbacks
def on_query_auth_request(self, origin, content, event_id):
"""
Content is a dict with keys::
auth_chain (list): A list of events that give the auth chain.
missing (list): A list of event_ids indicating what the other
side (`origin`) think we're missing.
rejects (dict): A mapping from event_id to a 2-tuple of reason
string and a proof (or None) of why the event was rejected.
The keys of this dict give the list of events the `origin` has
rejected.
Args:
origin (str)
content (dict)
event_id (str)
Returns:
Deferred: Results in `dict` with the same format as `content`
"""
auth_chain = [
self.event_from_pdu_json(e)
for e in content["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
origin, auth_chain, outlier=True
)
ret = yield self.handler.on_query_auth(
origin,
event_id,
signed_auth,
content.get("rejects", []),
content.get("missing", []),
)
time_now = self._clock.time_msec()
send_content = {
"auth_chain": [
e.get_pdu_json(time_now)
for e in ret["auth_chain"]
],
"rejects": ret.get("rejects", []),
"missing": ret.get("missing", []),
}
defer.returnValue(
(200, send_content)
)
@log_function
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id.
Returns:
Deferred: Results in a `Pdu`.
"""
return self.handler.get_persisted_pdu(
origin, event_id, do_auth=do_auth
)
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
time_now = self._clock.time_msec()
pdus = [p.get_pdu_json(time_now) for p in pdu_list]
return Transaction(
origin=self.server_name,
pdus=pdus,
origin_server_ts=int(time_now),
destination=None,
)
@defer.inlineCallbacks
@log_function
def _handle_new_pdu(self, origin, pdu, max_recursion=10):
# We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(
origin, pdu.event_id, do_auth=False
)
# FIXME: Currently we fetch an event again when we already have it
# if it has been marked as an outlier.
already_seen = (
existing and (
not existing.internal_metadata.is_outlier()
or pdu.internal_metadata.is_outlier()
)
)
if already_seen:
logger.debug("Already seen pdu %s", pdu.event_id)
defer.returnValue({})
return
# Check signature.
try:
pdu = yield self._check_sigs_and_hash(pdu)
except SynapseError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=pdu.event_id,
)
state = None
auth_chain = []
have_seen = yield self.store.have_events(
[ev for ev, _ in pdu.prev_events]
)
fetch_state = False
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
min_depth = yield self.handler.get_min_depth_for_context(
pdu.room_id
)
logger.debug(
"_handle_new_pdu min_depth for %s: %d",
pdu.room_id, min_depth
)
if min_depth and pdu.depth > min_depth and max_recursion > 0:
for event_id, hashes in pdu.prev_events:
if event_id not in have_seen:
logger.debug(
"_handle_new_pdu requesting pdu %s",
event_id
)
try:
new_pdu = yield self.federation_client.get_pdu(
[origin, pdu.origin],
event_id=event_id,
)
if new_pdu:
yield self._handle_new_pdu(
origin,
new_pdu,
max_recursion=max_recursion-1
)
logger.debug("Processed pdu %s", event_id)
else:
logger.warn("Failed to get PDU %s", event_id)
fetch_state = True
except:
# TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU")
fetch_state = True
else:
prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys())
if prevs - seen:
fetch_state = True
else:
fetch_state = True
if fetch_state:
# We need to get the state at this event, since we haven't
# processed all the prev events.
logger.debug(
"_handle_new_pdu getting state for %s",
pdu.room_id
)
try:
state, auth_chain = yield self.get_state_for_room(
origin, pdu.room_id, pdu.event_id,
)
except:
logger.warn("Failed to get state for event: %s", pdu.event_id)
ret = yield self.handler.on_receive_pdu(
origin,
pdu,
backfilled=False,
state=state,
auth_chain=auth_chain,
)
defer.returnValue(ret)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event

View File

@@ -23,8 +23,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from syutil.jsonutil import encode_canonical_json
import json
import logging
@@ -71,7 +70,7 @@ class TransactionActions(object):
transaction.transaction_id,
transaction.origin,
code,
encode_canonical_json(response)
json.dumps(response)
)
@defer.inlineCallbacks
@@ -101,5 +100,5 @@ class TransactionActions(object):
transaction.transaction_id,
transaction.destination,
response_code,
encode_canonical_json(response_dict)
json.dumps(response_dict)
)

View File

@@ -17,20 +17,23 @@
a given transport.
"""
from .federation_client import FederationClient
from .federation_server import FederationServer
from twisted.internet import defer
from .transaction_queue import TransactionQueue
from .units import Transaction, Edu
from .persistence import TransactionActions
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent
import logging
logger = logging.getLogger(__name__)
class ReplicationLayer(FederationClient, FederationServer):
class ReplicationLayer(object):
"""This layer is responsible for replicating with remote home servers over
the given transport. I.e., does the sending and receiving of PDUs to
remote home servers.
@@ -51,28 +54,898 @@ class ReplicationLayer(FederationClient, FederationServer):
def __init__(self, hs, transport_layer):
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.transport_layer = transport_layer
self.transport_layer.register_received_handler(self)
self.transport_layer.register_request_handler(self)
self.federation_client = self
self.store = hs.get_datastore()
# self.pdu_actions = PduActions(self.store)
self.transaction_actions = TransactionActions(self.store)
self._transaction_queue = _TransactionQueue(
hs, self.transaction_actions, transport_layer
)
self.handler = None
self.edu_handlers = {}
self.query_handlers = {}
self._clock = hs.get_clock()
self.transaction_actions = TransactionActions(self.store)
self._transaction_queue = TransactionQueue(hs, transport_layer)
self._order = 0
self.hs = hs
self._clock = hs.get_clock()
self.event_builder_factory = hs.get_event_builder_factory()
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are
documented on :py:class:`.ReplicationHandler`.
"""
self.handler = handler
def register_edu_handler(self, edu_type, handler):
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation Query of the given type.
Args:
query_type (str): Category name of the query, which should match
the string used by make_query.
handler (callable): Invoked to handle incoming queries of this type
handler is invoked as:
result = handler(args)
where 'args' is a dict mapping strings to strings of the query
arguments. It should return a Deferred that will eventually yield an
object to encode as JSON.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@log_function
def send_pdu(self, pdu, destinations):
"""Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others.
TODO: Figure out when we should actually resolve the deferred.
Args:
pdu (Pdu): The new Pdu.
Returns:
Deferred: Completes when we have successfully processed the PDU
and replicated it to any interested remote home servers.
"""
order = self._order
self._order += 1
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, destinations, order)
logger.debug(
"[%s] transaction_layer.enqueue_pdu... done",
pdu.event_id
)
@log_function
def send_edu(self, destination, edu_type, content):
edu = Edu(
origin=self.server_name,
destination=destination,
edu_type=edu_type,
content=content,
)
# TODO, add errback, etc.
self._transaction_queue.enqueue_edu(edu)
return defer.succeed(None)
@log_function
def send_failure(self, failure, destination):
self._transaction_queue.enqueue_failure(failure, destination)
return defer.succeed(None)
@log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=True):
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
Args:
destination (str): Domain name of the remote homeserver
query_type (str): Category of the query type; should match the
handler name used in register_query_handler().
args (dict): Mapping of strings to strings containing the details
of the query request.
Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
)
@defer.inlineCallbacks
@log_function
def backfill(self, dest, context, limit, extremities):
"""Requests some more historic PDUs for the given context from the
given destination server.
Args:
dest (str): The remote home server to ask.
context (str): The context to backfill.
limit (int): The maximum number of PDUs to return.
extremities (list): List of PDU id and origins of the first pdus
we have seen from the context
Returns:
Deferred: Results in the received PDUs.
"""
logger.debug("backfill extrem=%s", extremities)
# If there are no extremeties then we've (probably) reached the start.
if not extremities:
return
transaction_data = yield self.transport_layer.backfill(
dest, context, extremities, limit)
logger.debug("backfill transaction_data=%s", repr(transaction_data))
transaction = Transaction(**transaction_data)
pdus = [
self.event_from_pdu_json(p, outlier=False)
for p in transaction.pdus
]
for pdu in pdus:
yield self._handle_new_pdu(dest, pdu, backfilled=True)
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def get_pdu(self, destination, event_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home
server.
This will persist the PDU locally upon receipt.
Args:
destination (str): Which home server to query
pdu_origin (str): The home server that originally sent the pdu.
event_id (str)
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
Returns:
Deferred: Results in the requested PDU.
"""
transaction_data = yield self.transport_layer.get_event(
destination, event_id
)
transaction = Transaction(**transaction_data)
pdu_list = [
self.event_from_pdu_json(p, outlier=outlier)
for p in transaction.pdus
]
pdu = None
if pdu_list:
pdu = pdu_list[0]
yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
def get_state_for_room(self, destination, room_id, event_id):
"""Requests all of the `current` state PDUs for a given room from
a remote home server.
Args:
destination (str): The remote homeserver to query for the state.
room_id (str): The id of the room we're interested in.
event_id (str): The id of the event we want the state at.
Returns:
Deferred: Results in a list of PDUs.
"""
result = yield self.transport_layer.get_room_state(
destination, room_id, event_id=event_id,
)
pdus = [
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
]
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in result.get("auth_chain", [])
]
defer.returnValue((pdus, auth_chain))
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
res = yield self.transport_layer.get_event_auth(
destination, room_id, event_id,
)
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in res["auth_chain"]
]
auth_chain.sort(key=lambda e: e.depth)
defer.returnValue(auth_chain)
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, versions, limit):
pdus = yield self.handler.on_backfill_request(
origin, room_id, versions, limit
)
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
def on_incoming_transaction(self, transaction_data):
transaction = Transaction(**transaction_data)
for p in transaction.pdus:
if "unsigned" in p:
unsigned = p["unsigned"]
if "age" in unsigned:
p["age"] = unsigned["age"]
if "age" in p:
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
del p["age"]
pdu_list = [
self.event_from_pdu_json(p) for p in transaction.pdus
]
logger.debug("[%s] Got transaction", transaction.transaction_id)
response = yield self.transaction_actions.have_responded(transaction)
if response:
logger.debug("[%s] We've already responed to this request",
transaction.transaction_id)
defer.returnValue(response)
return
logger.debug("[%s] Transaction is new", transaction.transaction_id)
with PreserveLoggingContext():
dl = []
for pdu in pdu_list:
dl.append(self._handle_new_pdu(transaction.origin, pdu))
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu(
transaction.origin,
edu.edu_type,
edu.content
)
results = yield defer.DeferredList(dl)
ret = []
for r in results:
if r[0]:
ret.append({})
else:
logger.exception(r[1])
ret.append({"error": str(r[1])})
logger.debug("Returning: %s", str(ret))
yield self.transaction_actions.set_response(
transaction,
200, response
)
defer.returnValue((200, response))
def received_edu(self, origin, edu_type, content):
if edu_type in self.edu_handlers:
self.edu_handlers[edu_type](origin, content)
else:
logger.warn("Received EDU of type %s with no handler", edu_type)
@defer.inlineCallbacks
@log_function
def on_context_state_request(self, origin, room_id, event_id):
if event_id:
pdus = yield self.handler.get_state_for_pdu(
origin, room_id, event_id,
)
auth_chain = yield self.store.get_auth_chain(
[pdu.event_id for pdu in pdus]
)
else:
raise NotImplementedError("Specify an event")
defer.returnValue((200, {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
}))
@defer.inlineCallbacks
@log_function
def on_pdu_request(self, origin, event_id):
pdu = yield self._get_persisted_pdu(origin, event_id)
if pdu:
defer.returnValue(
(200, self._transaction_from_pdus([pdu]).get_dict())
)
else:
defer.returnValue((404, ""))
@defer.inlineCallbacks
@log_function
def on_pull_request(self, origin, versions):
raise NotImplementedError("Pull transactions not implemented")
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
if query_type in self.query_handlers:
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
defer.returnValue(
(404, "No handler for Query type '%s'" % (query_type,))
)
@defer.inlineCallbacks
def on_make_join_request(self, room_id, user_id):
pdu = yield self.handler.on_make_join_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
pdu = self.event_from_pdu_json(content)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@defer.inlineCallbacks
def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content)
pdu = self.event_from_pdu_json(content)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {
"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
"auth_chain": [
p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
],
}))
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
time_now = self._clock.time_msec()
auth_pdus = yield self.handler.on_event_auth(event_id)
defer.returnValue((200, {
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
}))
@defer.inlineCallbacks
def make_join(self, destination, room_id, user_id):
ret = yield self.transport_layer.make_join(
destination, room_id, user_id
)
pdu_dict = ret["event"]
logger.debug("Got response to make_join: %s", pdu_dict)
defer.returnValue(self.event_from_pdu_json(pdu_dict))
@defer.inlineCallbacks
def send_join(self, destination, pdu):
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
logger.debug("Got content: %s", content)
state = [
self.event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]
auth_chain.sort(key=lambda e: e.depth)
defer.returnValue({
"state": state,
"auth_chain": auth_chain,
})
@defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
time_now = self._clock.time_msec()
code, content = yield self.transport_layer.send_invite(
destination=destination,
room_id=room_id,
event_id=event_id,
content=pdu.get_pdu_json(time_now),
)
pdu_dict = content["event"]
logger.debug("Got response to send_invite: %s", pdu_dict)
defer.returnValue(self.event_from_pdu_json(pdu_dict))
@log_function
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id.
Returns:
Deferred: Results in a `Pdu`.
"""
return self.handler.get_persisted_pdu(
origin, event_id, do_auth=do_auth
)
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
time_now = self._clock.time_msec()
pdus = [p.get_pdu_json(time_now) for p in pdu_list]
return Transaction(
origin=self.server_name,
pdus=pdus,
origin_server_ts=int(time_now),
destination=None,
)
@defer.inlineCallbacks
@log_function
def _handle_new_pdu(self, origin, pdu, backfilled=False):
# We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(
origin, pdu.event_id, do_auth=False
)
already_seen = (
existing and (
not existing.internal_metadata.is_outlier()
or pdu.internal_metadata.is_outlier()
)
)
if already_seen:
logger.debug("Already seen pdu %s", pdu.event_id)
defer.returnValue({})
return
state = None
auth_chain = []
# We need to make sure we have all the auth events.
# for e_id, _ in pdu.auth_events:
# exists = yield self._get_persisted_pdu(
# origin,
# e_id,
# do_auth=False
# )
#
# if not exists:
# try:
# logger.debug(
# "_handle_new_pdu fetch missing auth event %s from %s",
# e_id,
# origin,
# )
#
# yield self.get_pdu(
# origin,
# event_id=e_id,
# outlier=True,
# )
#
# logger.debug("Processed pdu %s", e_id)
# except:
# logger.warn(
# "Failed to get auth event %s from %s",
# e_id,
# origin
# )
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
min_depth = yield self.handler.get_min_depth_for_context(
pdu.room_id
)
logger.debug(
"_handle_new_pdu min_depth for %s: %d",
pdu.room_id, min_depth
)
if min_depth and pdu.depth > min_depth:
for event_id, hashes in pdu.prev_events:
exists = yield self._get_persisted_pdu(
origin,
event_id,
do_auth=False
)
if not exists:
logger.debug(
"_handle_new_pdu requesting pdu %s",
event_id
)
try:
yield self.get_pdu(
origin,
event_id=event_id,
)
logger.debug("Processed pdu %s", event_id)
except:
# TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU")
else:
# We need to get the state at this event, since we have reached
# a backward extremity edge.
logger.debug(
"_handle_new_pdu getting state for %s",
pdu.room_id
)
state, auth_chain = yield self.get_state_for_room(
origin, pdu.room_id, pdu.event_id,
)
if not backfilled:
ret = yield self.handler.on_receive_pdu(
origin,
pdu,
backfilled=backfilled,
state=state,
auth_chain=auth_chain,
)
else:
ret = None
# yield self.pdu_actions.mark_as_processed(pdu)
defer.returnValue(ret)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event
class _TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
a time for a given destination.
It batches pending PDUs into single transactions.
"""
def __init__(self, hs, transaction_actions, transport_layer):
self.server_name = hs.hostname
self.transaction_actions = transaction_actions
self.transport_layer = transport_layer
self._clock = hs.get_clock()
self.store = hs.get_datastore()
# Is a mapping from destinations -> deferreds. Used to keep track
# of which destinations have transactions in flight and when they are
# done
self.pending_transactions = {}
# Is a mapping from destination -> list of
# tuple(pending pdus, deferred, order)
self.pending_pdus_by_dest = {}
# destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = {}
# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}
# HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec())
@defer.inlineCallbacks
@log_function
def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
destinations = set(destinations)
destinations.discard(self.server_name)
destinations.discard("localhost")
logger.debug("Sending to: %s", str(destinations))
if not destinations:
return
deferreds = []
for destination in destinations:
deferred = defer.Deferred()
self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, deferred, order)
)
def eb(failure):
if not deferred.called:
deferred.errback(failure)
else:
logger.warn("Failed to send pdu", failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(eb)
deferreds.append(deferred)
yield defer.DeferredList(deferreds)
# NO inlineCallbacks
def enqueue_edu(self, edu):
destination = edu.destination
if destination == self.server_name:
return
deferred = defer.Deferred()
self.pending_edus_by_dest.setdefault(destination, []).append(
(edu, deferred)
)
def eb(failure):
if not deferred.called:
deferred.errback(failure)
else:
logger.warn("Failed to send edu", failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(eb)
return deferred
@defer.inlineCallbacks
def enqueue_failure(self, failure, destination):
deferred = defer.Deferred()
self.pending_failures_by_dest.setdefault(
destination, []
).append(
(failure, deferred)
)
yield deferred
@defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
(retry_last_ts, retry_interval) = (0, 0)
retry_timings = yield self.store.get_destination_retry_timings(
destination
)
if retry_timings:
(retry_last_ts, retry_interval) = (
retry_timings.retry_last_ts, retry_timings.retry_interval
)
if retry_last_ts + retry_interval > int(self._clock.time_msec()):
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
return
else:
logger.info("TX [%s] is ready for retry", destination)
logger.info("TX [%s] _attempt_new_transaction", destination)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these
# requests
return
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])
if pending_pdus:
logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
return
logger.debug(
"TX [%s] Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
)
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[2])
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
failures = [x[0].get_dict() for x in pending_failures]
deferreds = [
x[1]
for x in pending_pdus + pending_edus + pending_failures
]
try:
self.pending_transactions[destination] = 1
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
origin_server_ts=int(self._clock.time_msec()),
transaction_id=str(self._next_txn_id),
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)
self._next_txn_id += 1
yield self.transaction_actions.prepare_to_send(transaction)
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] Sending transaction [%s]",
destination,
transaction.transaction_id,
)
# Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self._clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
code, response = yield self.transport_layer.send_transaction(
transaction, json_data_cb
)
logger.info("TX [%s] got %d response", destination, code)
logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)
yield self.transaction_actions.delivered(
transaction, code, response
)
logger.debug("TX [%s] Marked as delivered", destination)
logger.debug("TX [%s] Yielding to callbacks...", destination)
for deferred in deferreds:
if code == 200:
if retry_last_ts:
# this host is alive! reset retry schedule
yield self.store.set_destination_retry_timings(
destination, 0, 0
)
deferred.callback(None)
else:
self.set_retrying(destination, retry_interval)
deferred.errback(RuntimeError("Got status %d" % code))
# Ensures we don't continue until all callbacks on that
# deferred have fired
try:
yield deferred
except:
pass
logger.debug("TX [%s] Yielded to callbacks", destination)
except Exception as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
self.set_retrying(destination, retry_interval)
for deferred in deferreds:
if not deferred.called:
deferred.errback(e)
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
# Check to see if there is anything else to send.
self._attempt_new_transaction(destination)
@defer.inlineCallbacks
def set_retrying(self, destination, retry_interval):
# track that this destination is having problems and we should
# give it a chance to recover before trying it again
if retry_interval:
retry_interval *= 2
# plateau at hourly retries for now
if retry_interval >= 60 * 60 * 1000:
retry_interval = 60 * 60 * 1000
else:
retry_interval = 2000 # try again at first after 2 seconds
yield self.store.set_destination_retry_timings(
destination,
int(self._clock.time_msec()),
retry_interval
)

View File

@@ -1,335 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .persistence import TransactionActions
from .units import Transaction
from synapse.api.errors import HttpResponseException
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
import logging
logger = logging.getLogger(__name__)
class TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
a time for a given destination.
It batches pending PDUs into single transactions.
"""
def __init__(self, hs, transport_layer):
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.transaction_actions = TransactionActions(self.store)
self.transport_layer = transport_layer
self._clock = hs.get_clock()
# Is a mapping from destinations -> deferreds. Used to keep track
# of which destinations have transactions in flight and when they are
# done
self.pending_transactions = {}
# Is a mapping from destination -> list of
# tuple(pending pdus, deferred, order)
self.pending_pdus_by_dest = {}
# destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = {}
# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}
# HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec())
@defer.inlineCallbacks
@log_function
def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
destinations = set(destinations)
destinations.discard(self.server_name)
destinations.discard("localhost")
logger.debug("Sending to: %s", str(destinations))
if not destinations:
return
deferreds = []
for destination in destinations:
deferred = defer.Deferred()
self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, deferred, order)
)
def eb(failure):
if not deferred.called:
deferred.errback(failure)
else:
logger.warn("Failed to send pdu", failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(eb)
deferreds.append(deferred)
yield defer.DeferredList(deferreds)
# NO inlineCallbacks
def enqueue_edu(self, edu):
destination = edu.destination
if destination == self.server_name:
return
deferred = defer.Deferred()
self.pending_edus_by_dest.setdefault(destination, []).append(
(edu, deferred)
)
def eb(failure):
if not deferred.called:
deferred.errback(failure)
else:
logger.warn("Failed to send edu", failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(eb)
return deferred
@defer.inlineCallbacks
def enqueue_failure(self, failure, destination):
deferred = defer.Deferred()
self.pending_failures_by_dest.setdefault(
destination, []
).append(
(failure, deferred)
)
yield deferred
@defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
(retry_last_ts, retry_interval) = (0, 0)
retry_timings = yield self.store.get_destination_retry_timings(
destination
)
if retry_timings:
(retry_last_ts, retry_interval) = (
retry_timings.retry_last_ts, retry_timings.retry_interval
)
if retry_last_ts + retry_interval > int(self._clock.time_msec()):
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
return
else:
logger.info("TX [%s] is ready for retry", destination)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these
# requests
logger.info(
"TX [%s] Transaction already in progress",
destination
)
return
logger.info("TX [%s] _attempt_new_transaction", destination)
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])
if pending_pdus:
logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
logger.info("TX [%s] Nothing to send", destination)
return
logger.debug(
"TX [%s] Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
)
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[2])
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
failures = [x[0].get_dict() for x in pending_failures]
deferreds = [
x[1]
for x in pending_pdus + pending_edus + pending_failures
]
try:
self.pending_transactions[destination] = 1
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
origin_server_ts=int(self._clock.time_msec()),
transaction_id=str(self._next_txn_id),
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)
self._next_txn_id += 1
yield self.transaction_actions.prepare_to_send(transaction)
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] Sending transaction [%s]",
destination,
transaction.transaction_id,
)
# Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self._clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
try:
response = yield self.transport_layer.send_transaction(
transaction, json_data_cb
)
code = 200
except HttpResponseException as e:
code = e.code
response = e.response
logger.info("TX [%s] got %d response", destination, code)
logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)
yield self.transaction_actions.delivered(
transaction, code, response
)
logger.debug("TX [%s] Marked as delivered", destination)
logger.debug("TX [%s] Yielding to callbacks...", destination)
for deferred in deferreds:
if code == 200:
if retry_last_ts:
# this host is alive! reset retry schedule
yield self.store.set_destination_retry_timings(
destination, 0, 0
)
deferred.callback(None)
else:
self.set_retrying(destination, retry_interval)
deferred.errback(RuntimeError("Got status %d" % code))
# Ensures we don't continue until all callbacks on that
# deferred have fired
try:
yield deferred
except:
pass
logger.debug("TX [%s] Yielded to callbacks", destination)
except RuntimeError as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
except Exception as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.exception(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
self.set_retrying(destination, retry_interval)
for deferred in deferreds:
if not deferred.called:
deferred.errback(e)
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
# Check to see if there is anything else to send.
self._attempt_new_transaction(destination)
@defer.inlineCallbacks
def set_retrying(self, destination, retry_interval):
# track that this destination is having problems and we should
# give it a chance to recover before trying it again
if retry_interval:
retry_interval *= 2
# plateau at hourly retries for now
if retry_interval >= 60 * 60 * 1000:
retry_interval = 60 * 60 * 1000
else:
retry_interval = 2000 # try again at first after 2 seconds
yield self.store.set_destination_retry_timings(
destination,
int(self._clock.time_msec()),
retry_interval
)

View File

@@ -19,6 +19,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.util.logutils import log_function
import logging
import json
logger = logging.getLogger(__name__)
@@ -128,7 +129,7 @@ class TransportLayerClient(object):
# generated by the json_data_callback.
json_data = transaction.get_dict()
response = yield self.client.put_json(
code, response = yield self.client.put_json(
transaction.destination,
path=PREFIX + "/send/%s/" % transaction.transaction_id,
data=json_data,
@@ -136,86 +137,79 @@ class TransportLayerClient(object):
)
logger.debug(
"send_data dest=%s, txid=%s, got response: 200",
transaction.destination, transaction.transaction_id,
"send_data dest=%s, txid=%s, got response: %d",
transaction.destination, transaction.transaction_id, code
)
defer.returnValue(response)
defer.returnValue((code, response))
@defer.inlineCallbacks
@log_function
def make_query(self, destination, query_type, args, retry_on_dns_fail):
path = PREFIX + "/query/%s" % query_type
content = yield self.client.get_json(
response = yield self.client.get_json(
destination=destination,
path=path,
args=args,
retry_on_dns_fail=retry_on_dns_fail,
)
defer.returnValue(content)
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def make_join(self, destination, room_id, user_id, retry_on_dns_fail=True):
path = PREFIX + "/make_join/%s/%s" % (room_id, user_id)
content = yield self.client.get_json(
response = yield self.client.get_json(
destination=destination,
path=path,
retry_on_dns_fail=retry_on_dns_fail,
)
defer.returnValue(content)
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def send_join(self, destination, room_id, event_id, content):
path = PREFIX + "/send_join/%s/%s" % (room_id, event_id)
response = yield self.client.put_json(
code, content = yield self.client.put_json(
destination=destination,
path=path,
data=content,
)
defer.returnValue(response)
if not 200 <= code < 300:
raise RuntimeError("Got %d from send_join", code)
defer.returnValue(json.loads(content))
@defer.inlineCallbacks
@log_function
def send_invite(self, destination, room_id, event_id, content):
path = PREFIX + "/invite/%s/%s" % (room_id, event_id)
response = yield self.client.put_json(
code, content = yield self.client.put_json(
destination=destination,
path=path,
data=content,
)
defer.returnValue(response)
if not 200 <= code < 300:
raise RuntimeError("Got %d from send_invite", code)
defer.returnValue(json.loads(content))
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id)
content = yield self.client.get_json(
response = yield self.client.get_json(
destination=destination,
path=path,
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def send_query_auth(self, destination, room_id, event_id, content):
path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
content = yield self.client.post_json(
destination=destination,
path=path,
data=content,
)
defer.returnValue(content)
defer.returnValue(response)

View File

@@ -20,7 +20,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.util.logutils import log_function
import logging
import simplejson as json
import json
import re
@@ -42,7 +42,7 @@ class TransportLayerServer(object):
content = None
origin = None
if request.method in ["PUT", "POST"]:
if request.method == "PUT":
# TODO: Handle other method types? other content types?
try:
content_bytes = request.content.read()
@@ -234,16 +234,6 @@ class TransportLayerServer(object):
)
)
)
self.server.register_path(
"POST",
re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, event_id:
self._on_query_auth_request(
origin, content, event_id,
)
)
)
@defer.inlineCallbacks
@log_function
@@ -335,12 +325,3 @@ class TransportLayerServer(object):
)
defer.returnValue((200, content))
@defer.inlineCallbacks
@log_function
def _on_query_auth_request(self, origin, content, event_id):
new_content = yield self.request_handler.on_query_auth_request(
origin, content, event_id
)
defer.returnValue((200, new_content))

View File

@@ -26,7 +26,6 @@ from .presence import PresenceHandler
from .directory import DirectoryHandler
from .typing import TypingNotificationHandler
from .admin import AdminHandler
from .sync import SyncHandler
class Handlers(object):
@@ -52,4 +51,3 @@ class Handlers(object):
self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs)
self.sync_handler = SyncHandler(hs)

View File

@@ -19,7 +19,6 @@ from synapse.api.errors import LimitExceededError, SynapseError
from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID
import logging
@@ -114,7 +113,7 @@ class BaseHandler(object):
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
invitee = UserID.from_string(event.state_key)
invitee = self.hs.parse_userid(event.state_key)
if not self.hs.is_mine(invitee):
# TODO: Can we add signature from remote server in a nicer
# way? If we have been invited by a remote server, we need
@@ -135,7 +134,7 @@ class BaseHandler(object):
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(
UserID.from_string(s.state_key).domain
self.hs.parse_userid(s.state_key).domain
)
except SynapseError:
logger.warn(

View File

@@ -19,7 +19,6 @@ from ._base import BaseHandler
from synapse.api.errors import SynapseError, Codes, CodeMessageException
from synapse.api.constants import EventTypes
from synapse.types import RoomAlias
import logging
@@ -113,16 +112,7 @@ class DirectoryHandler(BaseHandler):
)
extra_servers = yield self.store.get_joined_hosts_for_room(room_id)
servers = set(extra_servers) | set(servers)
# If this server is in the list of servers, return it first.
if self.server_name in servers:
servers = (
[self.server_name]
+ [s for s in servers if s != self.server_name]
)
else:
servers = list(servers)
servers = list(set(extra_servers) | set(servers))
defer.returnValue({
"room_id": room_id,
@@ -132,7 +122,7 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks
def on_directory_query(self, args):
room_alias = RoomAlias.from_string(args["room_alias"])
room_alias = self.hs.parse_roomalias(args["room_alias"])
if not self.hs.is_mine(room_alias):
raise SynapseError(
400, "Room Alias is not hosted on this Home Server"

View File

@@ -17,8 +17,6 @@ from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.types import UserID
from synapse.events.utils import serialize_event
from ._base import BaseHandler
@@ -49,25 +47,24 @@ class EventStreamHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0,
as_client_event=True, affect_presence=True):
auth_user = UserID.from_string(auth_user_id)
as_client_event=True):
auth_user = self.hs.parse_userid(auth_user_id)
try:
if affect_presence:
if auth_user not in self._streams_per_user:
self._streams_per_user[auth_user] = 0
if auth_user in self._stop_timer_per_user:
try:
self.clock.cancel_call_later(
self._stop_timer_per_user.pop(auth_user)
)
except:
logger.exception("Failed to cancel event timer")
else:
yield self.distributor.fire(
"started_user_eventstream", auth_user
if auth_user not in self._streams_per_user:
self._streams_per_user[auth_user] = 0
if auth_user in self._stop_timer_per_user:
try:
self.clock.cancel_call_later(
self._stop_timer_per_user.pop(auth_user)
)
self._streams_per_user[auth_user] += 1
except:
logger.exception("Failed to cancel event timer")
else:
yield self.distributor.fire(
"started_user_eventstream", auth_user
)
self._streams_per_user[auth_user] += 1
if pagin_config.from_token is None:
pagin_config.from_token = None
@@ -80,10 +77,8 @@ class EventStreamHandler(BaseHandler):
auth_user, room_ids, pagin_config, timeout
)
time_now = self.clock.time_msec()
chunks = [
serialize_event(e, time_now, as_client_event) for e in events
self.hs.serialize_event(e, as_client_event) for e in events
]
chunk = {
@@ -95,29 +90,28 @@ class EventStreamHandler(BaseHandler):
defer.returnValue(chunk)
finally:
if affect_presence:
self._streams_per_user[auth_user] -= 1
if not self._streams_per_user[auth_user]:
del self._streams_per_user[auth_user]
self._streams_per_user[auth_user] -= 1
if not self._streams_per_user[auth_user]:
del self._streams_per_user[auth_user]
# 10 seconds of grace to allow the client to reconnect again
# before we think they're gone
def _later():
logger.debug(
"_later stopped_user_eventstream %s", auth_user
)
self._stop_timer_per_user.pop(auth_user, None)
return self.distributor.fire(
"stopped_user_eventstream", auth_user
)
logger.debug("Scheduling _later: for %s", auth_user)
self._stop_timer_per_user[auth_user] = (
self.clock.call_later(30, _later)
# 10 seconds of grace to allow the client to reconnect again
# before we think they're gone
def _later():
logger.debug(
"_later stopped_user_eventstream %s", auth_user
)
self._stop_timer_per_user.pop(auth_user, None)
yield self.distributor.fire(
"stopped_user_eventstream", auth_user
)
logger.debug("Scheduling _later: for %s", auth_user)
self._stop_timer_per_user[auth_user] = (
self.clock.call_later(30, _later)
)
class EventHandler(BaseHandler):

View File

@@ -17,21 +17,21 @@
from ._base import BaseHandler
from synapse.events.utils import prune_event
from synapse.api.errors import (
AuthError, FederationError, StoreError,
AuthError, FederationError, SynapseError, StoreError,
)
from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.api.constants import EventTypes, Membership
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze
from synapse.crypto.event_signing import (
compute_event_signature, add_hashes_and_signatures,
compute_event_signature, check_event_content_hash,
add_hashes_and_signatures,
)
from synapse.types import UserID
from syutil.jsonutil import encode_canonical_json
from twisted.internet import defer
import itertools
import logging
@@ -112,6 +112,33 @@ class FederationHandler(BaseHandler):
logger.debug("Processing event: %s", event.event_id)
redacted_event = prune_event(event)
redacted_pdu_json = redacted_event.get_pdu_json()
try:
yield self.keyring.verify_json_for_server(
event.origin, redacted_pdu_json
)
except SynapseError as e:
logger.warn(
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event.event_id,
)
if not check_event_content_hash(event):
logger.warn(
"Event content has been tampered, redacting %s, %s",
event.event_id, encode_canonical_json(event.get_dict())
)
event = redacted_event
logger.debug("Event: %s", event)
# FIXME (erikj): Awful hack to make the case where we are not currently
@@ -121,38 +148,41 @@ class FederationHandler(BaseHandler):
event.room_id,
self.server_name
)
if not is_in_room and not event.internal_metadata.is_outlier():
if not is_in_room and not event.internal_metadata.outlier:
logger.debug("Got event for room we're not in.")
current_state = state
event_ids = set()
if state:
event_ids |= {e.event_id for e in state}
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
replication = self.replication_layer
seen_ids = set(
(yield self.store.have_events(event_ids)).keys()
)
if not state:
state, auth_chain = yield replication.get_state_for_room(
origin, context=event.room_id, event_id=event.event_id,
)
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)
for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
if not auth_chain:
auth_chain = yield replication.get_event_auth(
origin,
context=event.room_id,
event_id=event.event_id,
)
for e in auth_chain:
e.internal_metadata.outlier = True
try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, e, auth_events=auth
yield self._handle_new_event(e, fetch_auth_from=origin)
except:
logger.exception(
"Failed to handle auth event %s",
e.event_id,
)
seen_ids.add(e.event_id)
current_state = state
if state:
for e in state:
logging.info("A :) %r", e)
e.internal_metadata.outlier = True
try:
yield self._handle_new_event(e)
except:
logger.exception(
"Failed to handle state event %s",
@@ -161,7 +191,6 @@ class FederationHandler(BaseHandler):
try:
yield self._handle_new_event(
origin,
event,
state=state,
backfilled=backfilled,
@@ -198,7 +227,7 @@ class FederationHandler(BaseHandler):
extra_users = []
if event.type == EventTypes.Member:
target_user_id = event.state_key
target_user = UserID.from_string(target_user_id)
target_user = self.hs.parse_userid(target_user_id)
extra_users.append(target_user)
yield self.notifier.on_new_room_event(
@@ -207,7 +236,7 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
user = UserID.from_string(event.state_key)
user = self.hs.parse_userid(event.state_key)
yield self.distributor.fire(
"user_joined_room", user=user, room_id=event.room_id
)
@@ -276,7 +305,7 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def do_invite_join(self, target_hosts, room_id, joinee, content, snapshot):
def do_invite_join(self, target_host, room_id, joinee, content, snapshot):
""" Attempts to join the `joinee` to the room `room_id` via the
server `target_host`.
@@ -290,8 +319,8 @@ class FederationHandler(BaseHandler):
"""
logger.debug("Joining %s to %s", joinee, room_id)
origin, pdu = yield self.replication_layer.make_join(
target_hosts,
pdu = yield self.replication_layer.make_join(
target_host,
room_id,
joinee
)
@@ -312,7 +341,7 @@ class FederationHandler(BaseHandler):
self.room_queues[room_id] = []
builder = self.event_builder_factory.new(
unfreeze(event.get_pdu_json())
event.get_pdu_json()
)
handled_events = set()
@@ -333,20 +362,11 @@ class FederationHandler(BaseHandler):
new_event = builder.build()
# Try the host we successfully got a response to /make_join/
# request first.
try:
target_hosts.remove(origin)
target_hosts.insert(0, origin)
except ValueError:
pass
ret = yield self.replication_layer.send_join(
target_hosts,
target_host,
new_event
)
origin = ret["origin"]
state = ret["state"]
auth_chain = ret["auth_chain"]
auth_chain.sort(key=lambda e: e.depth)
@@ -372,19 +392,8 @@ class FederationHandler(BaseHandler):
for e in auth_chain:
e.internal_metadata.outlier = True
if e.event_id == event.event_id:
continue
try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, e, auth_events=auth
)
yield self._handle_new_event(e)
except:
logger.exception(
"Failed to handle auth event %s",
@@ -392,18 +401,11 @@ class FederationHandler(BaseHandler):
)
for e in state:
if e.event_id == event.event_id:
continue
# FIXME: Auth these.
e.internal_metadata.outlier = True
try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, e, auth_events=auth
e, fetch_auth_from=target_host
)
except:
logger.exception(
@@ -411,18 +413,10 @@ class FederationHandler(BaseHandler):
e.event_id,
)
auth_ids = [e_id for e_id, _ in event.auth_events]
auth_events = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin,
new_event,
state=state,
current_state=state,
auth_events=auth_events,
)
yield self.notifier.on_new_room_event(
@@ -486,7 +480,7 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False
context = yield self._handle_new_event(origin, event)
context = yield self._handle_new_event(event)
logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
@@ -497,7 +491,7 @@ class FederationHandler(BaseHandler):
extra_users = []
if event.type == EventTypes.Member:
target_user_id = event.state_key
target_user = UserID.from_string(target_user_id)
target_user = self.hs.parse_userid(target_user_id)
extra_users.append(target_user)
yield self.notifier.on_new_room_event(
@@ -506,7 +500,7 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN:
user = UserID.from_string(event.state_key)
user = self.hs.parse_userid(event.state_key)
yield self.distributor.fire(
"user_joined_room", user=user, room_id=event.room_id
)
@@ -520,15 +514,13 @@ class FederationHandler(BaseHandler):
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(
UserID.from_string(s.state_key).domain
self.hs.parse_userid(s.state_key).domain
)
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
destinations.discard(origin)
logger.debug(
"on_send_join_request: Sending event: %s, signatures: %s",
event.event_id,
@@ -573,7 +565,7 @@ class FederationHandler(BaseHandler):
backfilled=False,
)
target_user = UserID.from_string(event.state_key)
target_user = self.hs.parse_userid(event.state_key)
yield self.notifier.on_new_room_event(
event, extra_users=[target_user],
)
@@ -649,7 +641,6 @@ class FederationHandler(BaseHandler):
event = yield self.store.get_event(
event_id,
allow_none=True,
allow_rejected=True,
)
if event:
@@ -690,12 +681,11 @@ class FederationHandler(BaseHandler):
waiters.pop().callback(None)
@defer.inlineCallbacks
@log_function
def _handle_new_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None):
def _handle_new_event(self, event, state=None, backfilled=False,
current_state=None, fetch_auth_from=None):
logger.debug(
"_handle_new_event: %s, sigs: %s",
"_handle_new_event: Before annotate: %s, sigs: %s",
event.event_id, event.signatures,
)
@@ -703,46 +693,65 @@ class FederationHandler(BaseHandler):
event, old_state=state
)
if not auth_events:
auth_events = context.auth_events
logger.debug(
"_handle_new_event: %s, auth_events: %s",
event.event_id, auth_events,
"_handle_new_event: Before auth fetch: %s, sigs: %s",
event.event_id, event.signatures,
)
is_new_state = not event.internal_metadata.is_outlier()
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
known_ids = set(
[s.event_id for s in context.auth_events.values()]
)
for e_id, _ in event.auth_events:
if e_id not in known_ids:
e = yield self.store.get_event(e_id, allow_none=True)
if not e and fetch_auth_from is not None:
# Grab the auth_chain over federation if we are missing
# auth events.
auth_chain = yield self.replication_layer.get_event_auth(
fetch_auth_from, event.event_id, event.room_id
)
for auth_event in auth_chain:
yield self._handle_new_event(auth_event)
e = yield self.store.get_event(e_id, allow_none=True)
if not e:
# TODO: Do some conflict res to make sure that we're
# not the ones who are wrong.
logger.info(
"Rejecting %s as %s not in db or %s",
event.event_id, e_id, known_ids,
)
# FIXME: How does raising AuthError work with federation?
raise AuthError(403, "Cannot find auth event")
context.auth_events[(e.type, e.state_key)] = e
logger.debug(
"_handle_new_event: Before hack: %s, sigs: %s",
event.event_id, event.signatures,
)
if event.type == EventTypes.Member and not event.auth_events:
if len(event.prev_events) == 1:
c = yield self.store.get_event(event.prev_events[0][0])
if c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
context.auth_events[(c.type, c.state_key)] = c
try:
yield self.do_auth(
origin, event, context, auth_events=auth_events
)
except AuthError as e:
logger.warn(
"Rejecting %s because %s",
event.event_id, e.msg
)
logger.debug(
"_handle_new_event: Before auth check: %s, sigs: %s",
event.event_id, event.signatures,
)
context.rejected = RejectedReason.AUTH_ERROR
self.auth.check(event, auth_events=context.auth_events)
# FIXME: Don't store as rejected with AUTH_ERROR if we haven't
# seen all the auth events.
yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=False,
current_state=current_state,
)
raise
logger.debug(
"_handle_new_event: Before persist_event: %s, sigs: %s",
event.event_id, event.signatures,
)
yield self.store.persist_event(
event,
@@ -752,363 +761,9 @@ class FederationHandler(BaseHandler):
current_state=current_state,
)
logger.debug(
"_handle_new_event: After persist_event: %s, sigs: %s",
event.event_id, event.signatures,
)
defer.returnValue(context)
@defer.inlineCallbacks
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
missing):
# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.
for e in remote_auth_chain:
try:
yield self._handle_new_event(origin, e)
except AuthError:
pass
# Now get the current auth_chain for the event.
local_auth_chain = yield self.store.get_auth_chain([event_id])
# TODO: Check if we would now reject event_id. If so we need to tell
# everyone.
ret = yield self.construct_auth_difference(
local_auth_chain, remote_auth_chain
)
for event in ret["auth_chain"]:
event.signatures.update(
compute_event_signature(
event,
self.hs.hostname,
self.hs.config.signing_key[0]
)
)
logger.debug("on_query_auth returning: %s", ret)
defer.returnValue(ret)
@defer.inlineCallbacks
@log_function
def do_auth(self, origin, event, context, auth_events):
# Check if we have all the auth events.
have_events = yield self.store.have_events(
[e_id for e_id, _ in event.auth_events]
)
event_auth_events = set(e_id for e_id, _ in event.auth_events)
seen_events = set(have_events.keys())
missing_auth = event_auth_events - seen_events
if missing_auth:
logger.debug("Missing auth: %s", missing_auth)
# If we don't have all the auth events, we need to get them.
try:
remote_auth_chain = yield self.replication_layer.get_event_auth(
origin, event.room_id, event.event_id
)
seen_remotes = yield self.store.have_events(
[e.event_id for e in remote_auth_chain]
)
for e in remote_auth_chain:
if e.event_id in seen_remotes.keys():
continue
if e.event_id == event.event_id:
continue
try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in remote_auth_chain
if e.event_id in auth_ids
}
e.internal_metadata.outlier = True
logger.debug(
"do_auth %s missing_auth: %s",
event.event_id, e.event_id
)
yield self._handle_new_event(
origin, e, auth_events=auth
)
if e.event_id in event_auth_events:
auth_events[(e.type, e.state_key)] = e
except AuthError:
pass
have_events = yield self.store.have_events(
[e_id for e_id, _ in event.auth_events]
)
seen_events = set(have_events.keys())
except:
# FIXME:
logger.exception("Failed to get auth chain")
# FIXME: Assumes we have and stored all the state for all the
# prev_events
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
if different_auth and not event.internal_metadata.is_outlier():
# Do auth conflict res.
logger.debug("Different auth: %s", different_auth)
different_events = yield defer.gatherResults(
[
self.store.get_event(
d,
allow_none=True,
allow_rejected=False,
)
for d in different_auth
if d in have_events and not have_events[d]
],
consumeErrors=True
)
if different_events:
local_view = dict(auth_events)
remote_view = dict(auth_events)
remote_view.update({
(d.type, d.state_key) for d in different_events
})
new_state, prev_state = self.state_handler.resolve_events(
[local_view, remote_view],
event
)
auth_events.update(new_state)
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
context.current_state.update(auth_events)
context.state_group = None
if different_auth and not event.internal_metadata.is_outlier():
# Only do auth resolution if we have something new to say.
# We can't rove an auth failure.
do_resolution = False
provable = [
RejectedReason.NOT_ANCESTOR, RejectedReason.NOT_ANCESTOR,
]
for e_id in different_auth:
if e_id in have_events:
if have_events[e_id] in provable:
do_resolution = True
break
if do_resolution:
# 1. Get what we think is the auth chain.
auth_ids = self.auth.compute_auth_events(
event, context.current_state
)
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
try:
# 2. Get remote difference.
result = yield self.replication_layer.query_auth(
origin,
event.room_id,
event.event_id,
local_auth_chain,
)
seen_remotes = yield self.store.have_events(
[e.event_id for e in result["auth_chain"]]
)
# 3. Process any remote auth chain events we haven't seen.
for ev in result["auth_chain"]:
if ev.event_id in seen_remotes.keys():
continue
if ev.event_id == event.event_id:
continue
try:
auth_ids = [e_id for e_id, _ in ev.auth_events]
auth = {
(e.type, e.state_key): e
for e in result["auth_chain"]
if e.event_id in auth_ids
}
ev.internal_metadata.outlier = True
logger.debug(
"do_auth %s different_auth: %s",
event.event_id, e.event_id
)
yield self._handle_new_event(
origin, ev, auth_events=auth
)
if ev.event_id in event_auth_events:
auth_events[(ev.type, ev.state_key)] = ev
except AuthError:
pass
except:
# FIXME:
logger.exception("Failed to query auth chain")
# 4. Look at rejects and their proofs.
# TODO.
context.current_state.update(auth_events)
context.state_group = None
try:
self.auth.check(event, auth_events=auth_events)
except AuthError:
raise
@defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth):
""" Given a local and remote auth chain, find the differences. This
assumes that we have already processed all events in remote_auth
Params:
local_auth (list)
remote_auth (list)
Returns:
dict
"""
logger.debug("construct_auth_difference Start!")
# TODO: Make sure we are OK with local_auth or remote_auth having more
# auth events in them than strictly necessary.
def sort_fun(ev):
return ev.depth, ev.event_id
logger.debug("construct_auth_difference after sort_fun!")
# We find the differences by starting at the "bottom" of each list
# and iterating up on both lists. The lists are ordered by depth and
# then event_id, we iterate up both lists until we find the event ids
# don't match. Then we look at depth/event_id to see which side is
# missing that event, and iterate only up that list. Repeat.
remote_list = list(remote_auth)
remote_list.sort(key=sort_fun)
local_list = list(local_auth)
local_list.sort(key=sort_fun)
local_iter = iter(local_list)
remote_iter = iter(remote_list)
logger.debug("construct_auth_difference before get_next!")
def get_next(it, opt=None):
try:
return it.next()
except:
return opt
current_local = get_next(local_iter)
current_remote = get_next(remote_iter)
logger.debug("construct_auth_difference before while")
missing_remotes = []
missing_locals = []
while current_local or current_remote:
if current_remote is None:
missing_locals.append(current_local)
current_local = get_next(local_iter)
continue
if current_local is None:
missing_remotes.append(current_remote)
current_remote = get_next(remote_iter)
continue
if current_local.event_id == current_remote.event_id:
current_local = get_next(local_iter)
current_remote = get_next(remote_iter)
continue
if current_local.depth < current_remote.depth:
missing_locals.append(current_local)
current_local = get_next(local_iter)
continue
if current_local.depth > current_remote.depth:
missing_remotes.append(current_remote)
current_remote = get_next(remote_iter)
continue
# They have the same depth, so we fall back to the event_id order
if current_local.event_id < current_remote.event_id:
missing_locals.append(current_local)
current_local = get_next(local_iter)
if current_local.event_id > current_remote.event_id:
missing_remotes.append(current_remote)
current_remote = get_next(remote_iter)
continue
logger.debug("construct_auth_difference after while")
# missing locals should be sent to the server
# We should find why we are missing remotes, as they will have been
# rejected.
# Remove events from missing_remotes if they are referencing a missing
# remote. We only care about the "root" rejected ones.
missing_remote_ids = [e.event_id for e in missing_remotes]
base_remote_rejected = list(missing_remotes)
for e in missing_remotes:
for e_id, _ in e.auth_events:
if e_id in missing_remote_ids:
try:
base_remote_rejected.remove(e)
except ValueError:
pass
reason_map = {}
for e in base_remote_rejected:
reason = yield self.store.get_rejection_reason(e.event_id)
if reason is None:
# TODO: e is not in the current state, so we should
# construct some proof of that.
continue
reason_map[e.event_id] = reason
if reason == RejectedReason.AUTH_ERROR:
pass
elif reason == RejectedReason.REPLACED:
# TODO: Get proof
pass
elif reason == RejectedReason.NOT_ANCESTOR:
# TODO: Get proof.
pass
logger.debug("construct_auth_difference returning")
defer.returnValue({
"auth_chain": local_auth,
"rejects": {
e.event_id: {
"reason": reason_map[e.event_id],
"proof": None,
}
for e in base_remote_rejected
},
"missing": [e.event_id for e in missing_locals],
})

View File

@@ -16,12 +16,10 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import RoomError, SynapseError
from synapse.api.errors import RoomError
from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
from ._base import BaseHandler
@@ -35,7 +33,6 @@ class MessageHandler(BaseHandler):
def __init__(self, hs):
super(MessageHandler, self).__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
@@ -92,7 +89,7 @@ class MessageHandler(BaseHandler):
yield self.hs.get_event_sources().get_current_token()
)
user = UserID.from_string(user_id)
user = self.hs.parse_userid(user_id)
events, next_key = yield data_source.get_pagination_rows(
user, pagin_config.get_source_config("room"), room_id
@@ -102,11 +99,9 @@ class MessageHandler(BaseHandler):
"room_key", next_key
)
time_now = self.clock.time_msec()
chunk = {
"chunk": [
serialize_event(e, time_now, as_client_event) for e in events
self.hs.serialize_event(e, as_client_event) for e in events
],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
@@ -115,8 +110,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True,
client=None, txn_id=None):
def create_and_send_event(self, event_dict, ratelimit=True):
""" Given a dict from a client, create and handle a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
@@ -136,13 +130,13 @@ class MessageHandler(BaseHandler):
if ratelimit:
self.ratelimit(builder.user_id)
# TODO(paul): Why does 'event' not have a 'user' object?
user = UserID.from_string(builder.user_id)
user = self.hs.parse_userid(builder.user_id)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
if membership == Membership.JOIN:
joinee = UserID.from_string(builder.state_key)
joinee = self.hs.parse_userid(builder.state_key)
# If event doesn't include a display name, add one.
yield self.distributor.fire(
"collect_presencelike_data",
@@ -150,15 +144,6 @@ class MessageHandler(BaseHandler):
builder.content
)
if client is not None:
if client.token_id is not None:
builder.internal_metadata.token_id = client.token_id
if client.device_id is not None:
builder.internal_metadata.device_id = client.device_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
event, context = yield self._create_new_client_event(
builder=builder,
)
@@ -225,10 +210,7 @@ class MessageHandler(BaseHandler):
# TODO: This is duplicating logic from snapshot_all_rooms
current_state = yield self.state_handler.get_current_state(room_id)
now = self.clock.time_msec()
defer.returnValue(
[serialize_event(c, now) for c in current_state.values()]
)
defer.returnValue([self.hs.serialize_event(c) for c in current_state])
@defer.inlineCallbacks
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
@@ -255,7 +237,7 @@ class MessageHandler(BaseHandler):
membership_list=[Membership.INVITE, Membership.JOIN]
)
user = UserID.from_string(user_id)
user = self.hs.parse_userid(user_id)
rooms_ret = []
@@ -300,11 +282,10 @@ class MessageHandler(BaseHandler):
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
d["messages"] = {
"chunk": [
serialize_event(m, time_now, as_client_event)
self.hs.serialize_event(m, as_client_event)
for m in messages
],
"start": start_token.to_string(),
@@ -315,8 +296,7 @@ class MessageHandler(BaseHandler):
event.room_id
)
d["state"] = [
serialize_event(c, time_now, as_client_event)
for c in current_state.values()
self.hs.serialize_event(c) for c in current_state
]
except:
logger.exception("Failed to get snapshot")
@@ -332,27 +312,20 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def room_initial_sync(self, user_id, room_id, pagin_config=None,
feedback=False):
current_state = yield self.state.get_current_state(
room_id=room_id,
)
yield self.auth.check_joined_room(
room_id, user_id,
current_state=current_state
)
yield self.auth.check_joined_room(room_id, user_id)
# TODO(paul): I wish I was called with user objects not user_id
# strings...
auth_user = UserID.from_string(user_id)
auth_user = self.hs.parse_userid(user_id)
# TODO: These concurrently
time_now = self.clock.time_msec()
state = [
serialize_event(x, time_now)
for x in current_state.values()
]
state_tuples = yield self.state_handler.get_current_state(room_id)
state = [self.hs.serialize_event(x) for x in state_tuples]
member_event = current_state.get((EventTypes.Member, user_id,))
member_event = (yield self.store.get_room_member(
user_id=user_id,
room_id=room_id
))
now_token = yield self.hs.get_event_sources().get_current_token()
@@ -369,34 +342,28 @@ class MessageHandler(BaseHandler):
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
room_members = [
m for m in current_state.values()
if m.type == EventTypes.Member
and m.content["membership"] == Membership.JOIN
]
room_members = yield self.store.get_room_members(room_id)
presence_handler = self.hs.get_handlers().presence_handler
presence = []
for m in room_members:
try:
member_presence = yield presence_handler.get_state(
target_user=UserID.from_string(m.user_id),
target_user=self.hs.parse_userid(m.user_id),
auth_user=auth_user,
as_event=True,
)
presence.append(member_presence)
except SynapseError:
except Exception:
logger.exception(
"Failed to get member presence of %r", m.user_id
)
time_now = self.clock.time_msec()
defer.returnValue({
"membership": member_event.membership,
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"chunk": [self.hs.serialize_event(m) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},

View File

@@ -20,7 +20,6 @@ from synapse.api.constants import PresenceState
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
from ._base import BaseHandler
@@ -87,10 +86,6 @@ class PresenceHandler(BaseHandler):
"changed_presencelike_data", self.changed_presencelike_data
)
# outbound signal from the presence module to advertise when a user's
# presence has changed
distributor.declare("user_presence_changed")
self.distributor = distributor
self.federation = hs.get_replication_layer()
@@ -101,22 +96,22 @@ class PresenceHandler(BaseHandler):
self.federation.register_edu_handler(
"m.presence_invite",
lambda origin, content: self.invite_presence(
observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]),
observed_user=hs.parse_userid(content["observed_user"]),
observer_user=hs.parse_userid(content["observer_user"]),
)
)
self.federation.register_edu_handler(
"m.presence_accept",
lambda origin, content: self.accept_presence(
observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]),
observed_user=hs.parse_userid(content["observed_user"]),
observer_user=hs.parse_userid(content["observer_user"]),
)
)
self.federation.register_edu_handler(
"m.presence_deny",
lambda origin, content: self.deny_presence(
observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]),
observed_user=hs.parse_userid(content["observed_user"]),
observer_user=hs.parse_userid(content["observer_user"]),
)
)
@@ -423,7 +418,7 @@ class PresenceHandler(BaseHandler):
)
for p in presence:
observed_user = UserID.from_string(p.pop("observed_user_id"))
observed_user = self.hs.parse_userid(p.pop("observed_user_id"))
p["observed_user"] = observed_user
p.update(self._get_or_offline_usercache(observed_user).get_state())
if "last_active" in p:
@@ -446,7 +441,7 @@ class PresenceHandler(BaseHandler):
user.localpart, accepted=True
)
target_users = set([
UserID.from_string(x["observed_user_id"]) for x in presence
self.hs.parse_userid(x["observed_user_id"]) for x in presence
])
# Also include people in all my rooms
@@ -457,9 +452,9 @@ class PresenceHandler(BaseHandler):
if state is None:
state = yield self.store.get_presence_state(user.localpart)
else:
# statuscache = self._get_or_make_usercache(user)
# self._user_cachemap_latest_serial += 1
# statuscache.update(state, self._user_cachemap_latest_serial)
# statuscache = self._get_or_make_usercache(user)
# self._user_cachemap_latest_serial += 1
# statuscache.update(state, self._user_cachemap_latest_serial)
pass
yield self.push_update_to_local_and_remote(
@@ -608,7 +603,6 @@ class PresenceHandler(BaseHandler):
room_ids=room_ids,
statuscache=statuscache,
)
yield self.distributor.fire("user_presence_changed", user, statuscache)
@defer.inlineCallbacks
def _push_presence_remote(self, user, destination, state=None):
@@ -652,15 +646,13 @@ class PresenceHandler(BaseHandler):
deferreds = []
for push in content.get("push", []):
user = UserID.from_string(push["user_id"])
user = self.hs.parse_userid(push["user_id"])
logger.debug("Incoming presence update from %s", user)
observers = set(self._remote_recvmap.get(user, set()))
if observers:
logger.debug(
" | %d interested local observers %r", len(observers), observers
)
logger.debug(" | %d interested local observers %r", len(observers), observers)
rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(user)
@@ -702,14 +694,14 @@ class PresenceHandler(BaseHandler):
del self._user_cachemap[user]
for poll in content.get("poll", []):
user = UserID.from_string(poll)
user = self.hs.parse_userid(poll)
if not self.hs.is_mine(user):
continue
# TODO(paul) permissions checks
if user not in self._remote_sendmap:
if not user in self._remote_sendmap:
self._remote_sendmap[user] = set()
self._remote_sendmap[user].add(origin)
@@ -717,7 +709,7 @@ class PresenceHandler(BaseHandler):
deferreds.append(self._push_presence_remote(user, origin))
for unpoll in content.get("unpoll", []):
user = UserID.from_string(unpoll)
user = self.hs.parse_userid(unpoll)
if not self.hs.is_mine(user):
continue

View File

@@ -18,7 +18,6 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import EventTypes, Membership
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
from ._base import BaseHandler
@@ -170,7 +169,7 @@ class ProfileHandler(BaseHandler):
@defer.inlineCallbacks
def on_profile_query(self, args):
user = UserID.from_string(args["user_id"])
user = self.hs.parse_userid(args["user_id"])
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this Home Server")

View File

@@ -99,26 +99,6 @@ class RegistrationHandler(BaseHandler):
raise RegistrationError(
500, "Cannot generate user ID.")
# create a default avatar for the user
# XXX: ideally clients would explicitly specify one, but given they don't
# and we want consistent and pretty identicons for random users, we'll
# do it here.
try:
auth_user = UserID.from_string(user_id)
media_repository = self.hs.get_resource_for_media_repository()
identicon_resource = media_repository.getChildWithDefault("identicon", None)
upload_resource = media_repository.getChildWithDefault("upload", None)
identicon_bytes = identicon_resource.generate_identicon(user_id, 320, 320)
content_uri = yield upload_resource.create_content(
"image/png", None, identicon_bytes, len(identicon_bytes), auth_user
)
profile_handler = self.hs.get_handlers().profile_handler
profile_handler.set_avatar_url(
auth_user, auth_user, ("%s#auto" % (content_uri,))
)
except NotImplementedError:
pass # make tests pass without messing around creating default avatars
defer.returnValue((user_id, token))
@defer.inlineCallbacks
@@ -183,7 +163,7 @@ class RegistrationHandler(BaseHandler):
# each request
httpCli = SimpleHttpClient(self.hs)
# XXX: make this configurable!
trustedIdServers = ['matrix.org:8090', 'matrix.org']
trustedIdServers = ['matrix.org:8090']
if not creds['idServer'] in trustedIdServers:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', creds['idServer'])

View File

@@ -16,14 +16,12 @@
"""Contains functions for performing events on rooms."""
from twisted.internet import defer
from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import StoreError, SynapseError
from synapse.util import stringutils
from synapse.util.async import run_on_reactor
from synapse.events.utils import serialize_event
from ._base import BaseHandler
import logging
@@ -66,7 +64,7 @@ class RoomCreationHandler(BaseHandler):
invite_list = config.get("invite", [])
for i in invite_list:
try:
UserID.from_string(i)
self.hs.parse_userid(i)
except:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
@@ -116,7 +114,7 @@ class RoomCreationHandler(BaseHandler):
servers=[self.hs.hostname],
)
user = UserID.from_string(user_id)
user = self.hs.parse_userid(user_id)
creation_events = self._create_events_for_new_room(
user, room_id, is_public=is_public
)
@@ -248,9 +246,11 @@ class RoomMemberHandler(BaseHandler):
@defer.inlineCallbacks
def get_room_members(self, room_id):
hs = self.hs
users = yield self.store.get_users_in_room(room_id)
defer.returnValue([UserID.from_string(u) for u in users])
defer.returnValue([hs.parse_userid(u) for u in users])
@defer.inlineCallbacks
def fetch_room_distributions_into(self, room_id, localusers=None,
@@ -295,9 +295,8 @@ class RoomMemberHandler(BaseHandler):
yield self.auth.check_joined_room(room_id, user_id)
member_list = yield self.store.get_room_members(room_id=room_id)
time_now = self.clock.time_msec()
event_list = [
serialize_event(entry, time_now)
self.hs.serialize_event(entry)
for entry in member_list
]
chunk_data = {
@@ -369,7 +368,7 @@ class RoomMemberHandler(BaseHandler):
)
if prev_state and prev_state.membership == Membership.JOIN:
user = UserID.from_string(event.user_id)
user = self.hs.parse_userid(event.user_id)
self.distributor.fire(
"user_left_room", user=user, room_id=event.room_id
)
@@ -389,6 +388,8 @@ class RoomMemberHandler(BaseHandler):
if not hosts:
raise SynapseError(404, "No known servers")
host = hosts[0]
# If event doesn't include a display name, add one.
yield self.distributor.fire(
"collect_presencelike_data", joinee, content
@@ -405,13 +406,13 @@ class RoomMemberHandler(BaseHandler):
})
event, context = yield self._create_new_client_event(builder)
yield self._do_join(event, context, room_hosts=hosts, do_auth=True)
yield self._do_join(event, context, room_host=host, do_auth=True)
defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks
def _do_join(self, event, context, room_hosts=None, do_auth=True):
joinee = UserID.from_string(event.state_key)
def _do_join(self, event, context, room_host=None, do_auth=True):
joinee = self.hs.parse_userid(event.state_key)
# room_id = RoomID.from_string(event.room_id, self.hs)
room_id = event.room_id
@@ -439,7 +440,7 @@ class RoomMemberHandler(BaseHandler):
if is_host_in_room:
should_do_dance = False
elif room_hosts: # TODO: Shouldn't this be remote_room_host?
elif room_host: # TODO: Shouldn't this be remote_room_host?
should_do_dance = True
else:
# TODO(markjh): get prev_state from snapshot
@@ -451,7 +452,7 @@ class RoomMemberHandler(BaseHandler):
inviter = UserID.from_string(prev_state.user_id)
should_do_dance = not self.hs.is_mine(inviter)
room_hosts = [inviter.domain]
room_host = inviter.domain
else:
# return the same error as join_room_alias does
raise SynapseError(404, "No known servers")
@@ -459,10 +460,10 @@ class RoomMemberHandler(BaseHandler):
if should_do_dance:
handler = self.hs.get_handlers().federation_handler
yield handler.do_invite_join(
room_hosts,
room_host,
room_id,
event.user_id,
event.content, # FIXME To get a non-frozen dict
event.get_dict()["content"], # FIXME To get a non-frozen dict
context
)
else:
@@ -475,7 +476,7 @@ class RoomMemberHandler(BaseHandler):
do_auth=do_auth,
)
user = UserID.from_string(event.user_id)
user = self.hs.parse_userid(event.user_id)
yield self.distributor.fire(
"user_joined_room", user=user, room_id=room_id
)
@@ -525,7 +526,7 @@ class RoomMemberHandler(BaseHandler):
do_auth):
yield run_on_reactor()
target_user = UserID.from_string(event.state_key)
target_user = self.hs.parse_userid(event.state_key)
yield self.handle_new_client_event(
event,

View File

@@ -1,439 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseHandler
from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes
from twisted.internet import defer
import collections
import logging
logger = logging.getLogger(__name__)
SyncConfig = collections.namedtuple("SyncConfig", [
"user",
"client_info",
"limit",
"gap",
"sort",
"backfill",
"filter",
])
class RoomSyncResult(collections.namedtuple("RoomSyncResult", [
"room_id",
"limited",
"published",
"events",
"state",
"prev_batch",
"ephemeral",
])):
__slots__ = []
def __nonzero__(self):
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
return bool(self.events or self.state or self.ephemeral)
class SyncResult(collections.namedtuple("SyncResult", [
"next_batch", # Token for the next sync
"private_user_data", # List of private events for the user.
"public_user_data", # List of public events for all users.
"rooms", # RoomSyncResult for each room.
])):
__slots__ = []
def __nonzero__(self):
"""Make the result appear empty if there are no updates. This is used
to tell if the notifier needs to wait for more events when polling for
events.
"""
return bool(
self.private_user_data or self.public_user_data or self.rooms
)
class SyncHandler(BaseHandler):
def __init__(self, hs):
super(SyncHandler, self).__init__(hs)
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0):
"""Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result.
Returns:
A Deferred SyncResult.
"""
if timeout == 0 or since_token is None:
result = yield self.current_sync_for_user(sync_config, since_token)
defer.returnValue(result)
else:
def current_sync_callback():
return self.current_sync_for_user(sync_config, since_token)
rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(sync_config.user)
result = yield self.notifier.wait_for_events(
sync_config.user, room_ids,
sync_config.filter, timeout, current_sync_callback
)
defer.returnValue(result)
def current_sync_for_user(self, sync_config, since_token=None):
"""Get the sync for client needed to match what the server has now.
Returns:
A Deferred SyncResult.
"""
if since_token is None:
return self.initial_sync(sync_config)
else:
if sync_config.gap:
return self.incremental_sync_with_gap(sync_config, since_token)
else:
# TODO(mjark): Handle gapless sync
raise NotImplementedError()
@defer.inlineCallbacks
def initial_sync(self, sync_config):
"""Get a sync for a client which is starting without any state
Returns:
A Deferred SyncResult.
"""
if sync_config.sort == "timeline,desc":
# TODO(mjark): Handle going through events in reverse order?.
# What does "most recent events" mean when applying the limits mean
# in this case?
raise NotImplementedError()
now_token = yield self.event_sources.get_current_token()
presence_stream = self.event_sources.sources["presence"]
# TODO (mjark): This looks wrong, shouldn't we be getting the presence
# UP to the present rather than after the present?
pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows(
user=sync_config.user,
pagination_config=pagination_config.get_source_config("presence"),
key=None
)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=sync_config.user.to_string(),
membership_list=[Membership.INVITE, Membership.JOIN]
)
# TODO (mjark): Does public mean "published"?
published_rooms = yield self.store.get_rooms(is_public=True)
published_room_ids = set(r["room_id"] for r in published_rooms)
rooms = []
for event in room_list:
room_sync = yield self.initial_sync_for_room(
event.room_id, sync_config, now_token, published_room_ids
)
rooms.append(room_sync)
defer.returnValue(SyncResult(
public_user_data=presence,
private_user_data=[],
rooms=rooms,
next_batch=now_token,
))
@defer.inlineCallbacks
def initial_sync_for_room(self, room_id, sync_config, now_token,
published_room_ids):
"""Sync a room for a client which is starting without any state
Returns:
A Deferred RoomSyncResult.
"""
recents, prev_batch_token, limited = yield self.load_filtered_recents(
room_id, sync_config, now_token,
)
current_state = yield self.state_handler.get_current_state(
room_id
)
current_state_events = current_state.values()
defer.returnValue(RoomSyncResult(
room_id=room_id,
published=room_id in published_room_ids,
events=recents,
prev_batch=prev_batch_token,
state=current_state_events,
limited=limited,
ephemeral=[],
))
@defer.inlineCallbacks
def incremental_sync_with_gap(self, sync_config, since_token):
""" Get the incremental delta needed to bring the client up to
date with the server.
Returns:
A Deferred SyncResult.
"""
if sync_config.sort == "timeline,desc":
# TODO(mjark): Handle going through events in reverse order?.
# What does "most recent events" mean when applying the limits mean
# in this case?
raise NotImplementedError()
now_token = yield self.event_sources.get_current_token()
presence_source = self.event_sources.sources["presence"]
presence, presence_key = yield presence_source.get_new_events_for_user(
user=sync_config.user,
from_key=since_token.presence_key,
limit=sync_config.limit,
)
now_token = now_token.copy_and_replace("presence_key", presence_key)
typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events_for_user(
user=sync_config.user,
from_key=since_token.typing_key,
limit=sync_config.limit,
)
now_token = now_token.copy_and_replace("typing_key", typing_key)
typing_by_room = {event["room_id"]: [event] for event in typing}
for event in typing:
event.pop("room_id")
logger.debug("Typing %r", typing_by_room)
rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(sync_config.user)
# TODO (mjark): Does public mean "published"?
published_rooms = yield self.store.get_rooms(is_public=True)
published_room_ids = set(r["room_id"] for r in published_rooms)
room_events, _ = yield self.store.get_room_events_stream(
sync_config.user.to_string(),
from_key=since_token.room_key,
to_key=now_token.room_key,
room_id=None,
limit=sync_config.limit + 1,
)
rooms = []
if len(room_events) <= sync_config.limit:
# There is no gap in any of the rooms. Therefore we can just
# partition the new events by room and return them.
events_by_room_id = {}
for event in room_events:
events_by_room_id.setdefault(event.room_id, []).append(event)
for room_id in room_ids:
recents = events_by_room_id.get(room_id, [])
state = [event for event in recents if event.is_state()]
if recents:
prev_batch = now_token.copy_and_replace(
"room_key", recents[0].internal_metadata.before
)
else:
prev_batch = now_token
state = yield self.check_joined_room(
sync_config, room_id, state
)
room_sync = RoomSyncResult(
room_id=room_id,
published=room_id in published_room_ids,
events=recents,
prev_batch=prev_batch,
state=state,
limited=False,
ephemeral=typing_by_room.get(room_id, [])
)
if room_sync:
rooms.append(room_sync)
else:
for room_id in room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token,
published_room_ids, typing_by_room
)
if room_sync:
rooms.append(room_sync)
defer.returnValue(SyncResult(
public_user_data=presence,
private_user_data=[],
rooms=rooms,
next_batch=now_token,
))
@defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None):
limited = True
recents = []
filtering_factor = 2
load_limit = max(sync_config.limit * filtering_factor, 100)
max_repeat = 3 # Only try a few times per room, otherwise
room_key = now_token.room_key
end_key = room_key
while limited and len(recents) < sync_config.limit and max_repeat:
events, keys = yield self.store.get_recent_events_for_room(
room_id,
limit=load_limit + 1,
from_token=since_token.room_key if since_token else None,
end_token=end_key,
)
(room_key, _) = keys
end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter.filter_room_events(events)
loaded_recents.extend(recents)
recents = loaded_recents
if len(events) <= load_limit:
limited = False
max_repeat -= 1
if len(recents) > sync_config.limit:
recents = recents[-sync_config.limit:]
room_key = recents[0].internal_metadata.before
prev_batch_token = now_token.copy_and_replace(
"room_key", room_key
)
defer.returnValue((recents, prev_batch_token, limited))
@defer.inlineCallbacks
def incremental_sync_with_gap_for_room(self, room_id, sync_config,
since_token, now_token,
published_room_ids, typing_by_room):
""" Get the incremental delta needed to bring the client up to date for
the room. Gives the client the most recent events and the changes to
state.
Returns:
A Deferred RoomSyncResult
"""
# TODO(mjark): Check for redactions we might have missed.
recents, prev_batch_token, limited = yield self.load_filtered_recents(
room_id, sync_config, now_token, since_token,
)
logging.debug("Recents %r", recents)
# TODO(mjark): This seems racy since this isn't being passed a
# token to indicate what point in the stream this is
current_state = yield self.state_handler.get_current_state(
room_id
)
current_state_events = current_state.values()
state_at_previous_sync = yield self.get_state_at_previous_sync(
room_id, since_token=since_token
)
state_events_delta = yield self.compute_state_delta(
since_token=since_token,
previous_state=state_at_previous_sync,
current_state=current_state_events,
)
state_events_delta = yield self.check_joined_room(
sync_config, room_id, state_events_delta
)
room_sync = RoomSyncResult(
room_id=room_id,
published=room_id in published_room_ids,
events=recents,
prev_batch=prev_batch_token,
state=state_events_delta,
limited=limited,
ephemeral=typing_by_room.get(room_id, [])
)
logging.debug("Room sync: %r", room_sync)
defer.returnValue(room_sync)
@defer.inlineCallbacks
def get_state_at_previous_sync(self, room_id, since_token):
""" Get the room state at the previous sync the client made.
Returns:
A Deferred list of Events.
"""
last_events, token = yield self.store.get_recent_events_for_room(
room_id, end_token=since_token.room_key, limit=1,
)
if last_events:
last_event = last_events[0]
last_context = yield self.state_handler.compute_event_context(
last_event
)
if last_event.is_state():
state = [last_event] + last_context.current_state.values()
else:
state = last_context.current_state.values()
else:
state = ()
defer.returnValue(state)
def compute_state_delta(self, since_token, previous_state, current_state):
""" Works out the differnce in state between the current state and the
state the client got when it last performed a sync.
Returns:
A list of events.
"""
# TODO(mjark) Check if the state events were received by the server
# after the previous sync, since we need to include those state
# updates even if they occured logically before the previous event.
# TODO(mjark) Check for new redactions in the state events.
previous_dict = {event.event_id: event for event in previous_state}
state_delta = []
for event in current_state:
if event.event_id not in previous_dict:
state_delta.append(event)
return state_delta
@defer.inlineCallbacks
def check_joined_room(self, sync_config, room_id, state_delta):
joined = False
for event in state_delta:
if (
event.type == EventTypes.Member
and event.state_key == sync_config.user.to_string()
):
if event.content["membership"] == Membership.JOIN:
joined = True
if joined:
res = yield self.state_handler.get_current_state(room_id)
state_delta = res.values()
defer.returnValue(state_delta)

View File

@@ -18,7 +18,6 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.errors import SynapseError, AuthError
from synapse.types import UserID
import logging
@@ -186,7 +185,7 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks
def _recv_edu(self, origin, content):
room_id = content["room_id"]
user = UserID.from_string(content["user_id"])
user = self.homeserver.parse_userid(content["user_id"])
localusers = set()

View File

@@ -15,8 +15,6 @@
from synapse.http.agent_name import AGENT_NAME
from syutil.jsonutil import encode_canonical_json
from twisted.internet import defer, reactor
from twisted.web.client import (
Agent, readBody, FileBodyProducer, PartialDownloadError
@@ -25,7 +23,7 @@ from twisted.web.http_headers import Headers
from StringIO import StringIO
import simplejson as json
import json
import logging
import urllib
@@ -64,25 +62,6 @@ class SimpleHttpClient(object):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def post_json_get_json(self, uri, post_json):
json_str = encode_canonical_json(post_json)
logger.info("HTTP POST %s -> %s", json_str, uri)
response = yield self.agent.request(
"POST",
uri.encode("ascii"),
headers=Headers({
"Content-Type": ["application/json"]
}),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
body = yield readBody(response)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def get_json(self, uri, args={}):
""" Get's some json from the given host and path

View File

@@ -27,13 +27,11 @@ from synapse.util.logcontext import PreserveLoggingContext
from syutil.jsonutil import encode_canonical_json
from synapse.api.errors import (
SynapseError, Codes, HttpResponseException,
)
from synapse.api.errors import CodeMessageException, SynapseError, Codes
from syutil.crypto.jsonsign import sign_json
import simplejson as json
import json
import logging
import urllib
import urlparse
@@ -79,7 +77,6 @@ class MatrixFederationHttpClient(object):
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
self.agent = MatrixFederationHttpAgent(reactor)
self.clock = hs.get_clock()
@defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes,
@@ -119,7 +116,7 @@ class MatrixFederationHttpClient(object):
try:
with PreserveLoggingContext():
request_deferred = self.agent.request(
response = yield self.agent.request(
destination,
endpoint,
method,
@@ -130,11 +127,6 @@ class MatrixFederationHttpClient(object):
producer
)
response = yield self.clock.time_bound_deferred(
request_deferred,
time_out=60,
)
logger.debug("Got response to %s", method)
break
except Exception as e:
@@ -171,13 +163,13 @@ class MatrixFederationHttpClient(object):
)
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
pass
else:
# :'(
# Update transactions table?
body = yield readBody(response)
raise HttpResponseException(
response.code, response.phrase, body
raise CodeMessageException(
response.code, response.phrase
)
defer.returnValue(response)
@@ -246,66 +238,11 @@ class MatrixFederationHttpClient(object):
headers_dict={"Content-Type": ["application/json"]},
)
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
c_type = response.headers.getRawHeaders("Content-Type")
if "application/json" not in c_type:
raise RuntimeError(
"Content-Type not application/json"
)
logger.debug("Getting resp body")
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def post_json(self, destination, path, data={}):
""" Sends the specifed json data using POST
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. On a 4xx or 5xx error response a
CodeMessageException is raised.
"""
def body_callback(method, url_bytes, headers_dict):
self.sign_request(
destination, method, url_bytes, headers_dict, data
)
return _JsonProducer(data)
response = yield self._create_request(
destination.encode("ascii"),
"POST",
path.encode("ascii"),
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
)
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
c_type = response.headers.getRawHeaders("Content-Type")
if "application/json" not in c_type:
raise RuntimeError(
"Content-Type not application/json"
)
logger.debug("Getting resp body")
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue(json.loads(body))
defer.returnValue((response.code, body))
@defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
@@ -347,18 +284,7 @@ class MatrixFederationHttpClient(object):
retry_on_dns_fail=retry_on_dns_fail
)
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
c_type = response.headers.getRawHeaders("Content-Type")
if "application/json" not in c_type:
raise RuntimeError(
"Content-Type not application/json"
)
logger.debug("Getting resp body")
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue(json.loads(body))

View File

@@ -16,7 +16,7 @@
from synapse.http.agent_name import AGENT_NAME
from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
cs_exception, SynapseError, CodeMessageException
)
from synapse.util.logcontext import LoggingContext
@@ -69,10 +69,9 @@ class JsonResource(HttpServer, resource.Resource):
_PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"])
def __init__(self, hs):
def __init__(self):
resource.Resource.__init__(self)
self.clock = hs.get_clock()
self.path_regexs = {}
def register_path(self, method, path_pattern, callback):
@@ -112,8 +111,6 @@ class JsonResource(HttpServer, resource.Resource):
This checks if anyone has registered a callback for that method and
path.
"""
code = None
start = self.clock.time_msec()
try:
# Just say yes to OPTIONS.
if request.method == "OPTIONS":
@@ -133,11 +130,6 @@ class JsonResource(HttpServer, resource.Resource):
urllib.unquote(u).decode("UTF-8") for u in m.groups()
]
logger.info(
"Received request: %s %s",
request.method, request.path
)
code, response = yield path_entry.callback(
request,
*args
@@ -147,17 +139,19 @@ class JsonResource(HttpServer, resource.Resource):
return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
raise UnrecognizedRequestError()
self._send_response(
request,
400,
{"error": "Unrecognized request"}
)
except CodeMessageException as e:
if isinstance(e, SynapseError):
logger.info("%s SynapseError: %s - %s", request, e.code, e.msg)
else:
logger.exception(e)
code = e.code
self._send_response(
request,
code,
e.code,
cs_exception(e),
response_code_message=e.response_code_message
)
@@ -168,14 +162,6 @@ class JsonResource(HttpServer, resource.Resource):
500,
{"error": "Internal server error"}
)
finally:
code = str(code) if code else "-"
end = self.clock.time_msec()
logger.info(
"Processed request: %dms %s %s %s",
end-start, code, request.method, request.path
)
def _send_response(self, request, code, response_json_object,
response_code_message=None):

View File

@@ -1,113 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
from synapse.api.errors import SynapseError
import logging
logger = logging.getLogger(__name__)
class RestServlet(object):
""" A Synapse REST Servlet.
An implementing class can either provide its own custom 'register' method,
or use the automatic pattern handling provided by the base class.
To use this latter, the implementing class instead provides a `PATTERN`
class attribute containing a pre-compiled regular expression. The automatic
register method will then use this method to register any of the following
instance methods associated with the corresponding HTTP method:
on_GET
on_PUT
on_POST
on_DELETE
on_OPTIONS
Automatically handles turning CodeMessageExceptions thrown by these methods
into the appropriate HTTP response.
"""
def register(self, http_server):
""" Register this servlet with the given HTTP server. """
if hasattr(self, "PATTERN"):
pattern = self.PATTERN
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
if hasattr(self, "on_%s" % (method)):
method_handler = getattr(self, "on_%s" % (method))
http_server.register_path(method, pattern, method_handler)
else:
raise NotImplementedError("RestServlet must register something.")
@staticmethod
def parse_integer(request, name, default=None, required=False):
if name in request.args:
try:
return int(request.args[name][0])
except:
message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message)
else:
if required:
message = "Missing integer query parameter %r" % (name,)
raise SynapseError(400, message)
else:
return default
@staticmethod
def parse_boolean(request, name, default=None, required=False):
if name in request.args:
try:
return {
"true": True,
"false": False,
}[request.args[name][0]]
except:
message = (
"Boolean query parameter %r must be one of"
" ['true', 'false']"
) % (name,)
raise SynapseError(400, message)
else:
if required:
message = "Missing boolean query parameter %r" % (name,)
raise SynapseError(400, message)
else:
return default
@staticmethod
def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"):
if name in request.args:
value = request.args[name][0]
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values)
)
raise SynapseError(message)
else:
return value
else:
if required:
message = "Missing %s query parameter %r" % (param_type, name)
raise SynapseError(400, message)
else:
return default

View File

@@ -25,7 +25,7 @@ from twisted.web import server, resource
from twisted.internet import defer
import base64
import simplejson as json
import json
import logging
import os
import re
@@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource):
@defer.inlineCallbacks
def map_request_to_name(self, request):
# auth the user
auth_user, client = yield self.auth.get_user_by_req(request)
auth_user = yield self.auth.get_user_by_req(request)
# namespace all file uploads on the user
prefix = base64.urlsafe_b64encode(

View File

@@ -82,7 +82,7 @@ class BaseMediaResource(Resource):
raise SynapseError(
404,
"Invalid media id token %r" % (request.postpath,),
Codes.UNKNOWN,
Codes.UNKKOWN,
)
@staticmethod

View File

@@ -16,7 +16,6 @@
from .upload_resource import UploadResource
from .download_resource import DownloadResource
from .thumbnail_resource import ThumbnailResource
from .identicon_resource import IdenticonResource
from .filepath import MediaFilePaths
from twisted.web.resource import Resource
@@ -34,7 +33,6 @@ class MediaRepositoryResource(Resource):
=> POST /_matrix/media/v1/upload HTTP/1.1
Content-Type: <media-type>
Content-Length: <content-length>
<media>
@@ -77,4 +75,3 @@ class MediaRepositoryResource(Resource):
self.putChild("upload", UploadResource(hs, filepaths))
self.putChild("download", DownloadResource(hs, filepaths))
self.putChild("thumbnail", ThumbnailResource(hs, filepaths))
self.putChild("identicon", IdenticonResource())

View File

@@ -39,40 +39,10 @@ class UploadResource(BaseMediaResource):
respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET
@defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length,
auth_user):
media_id = random_string(24)
fname = self.filepaths.local_media_filepath(media_id)
self._makedirs(fname)
# This shouldn't block for very long because the content will have
# already been uploaded at this point.
with open(fname, "wb") as f:
f.write(content)
yield self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
)
media_info = {
"media_type": media_type,
"media_length": content_length,
}
yield self._generate_local_thumbnails(media_id, media_info)
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@defer.inlineCallbacks
def _async_render_POST(self, request):
try:
auth_user, client = yield self.auth.get_user_by_req(request)
auth_user = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length")
@@ -96,14 +66,36 @@ class UploadResource(BaseMediaResource):
code=400,
)
# if headers.hasHeader("Content-Disposition"):
# disposition = headers.getRawHeaders("Content-Disposition")[0]
#if headers.hasHeader("Content-Disposition"):
# disposition = headers.getRawHeaders("Content-Disposition")[0]
# TODO(markjh): parse content-dispostion
content_uri = yield self.create_content(
media_type, None, request.content.read(),
content_length, auth_user
media_id = random_string(24)
fname = self.filepaths.local_media_filepath(media_id)
self._makedirs(fname)
# This shouldn't block for very long because the content will have
# already been uploaded at this point.
with open(fname, "wb") as f:
f.write(request.content.read())
yield self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=None,
media_length=content_length,
user_id=auth_user,
)
media_info = {
"media_type": media_type,
"media_length": content_length,
}
yield self._generate_local_thumbnails(media_id, media_info)
content_uri = "mxc://%s/%s" % (self.server_name, media_id)
respond_with_json(
request, 200, {"content_uri": content_uri}, send_cors=True

View File

@@ -18,7 +18,6 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.async import run_on_reactor
from synapse.types import StreamToken
import logging
@@ -206,53 +205,6 @@ class Notifier(object):
[notify(l).addErrback(eb) for l in listeners]
)
@defer.inlineCallbacks
def wait_for_events(self, user, rooms, filter, timeout, callback):
"""Wait until the callback returns a non empty response or the
timeout fires.
"""
deferred = defer.Deferred()
from_token = StreamToken("s0", "0", "0")
listener = [_NotificationListener(
user=user,
rooms=rooms,
from_token=from_token,
limit=1,
timeout=timeout,
deferred=deferred,
)]
if timeout:
self._register_with_keys(listener[0])
result = yield callback()
if timeout:
timed_out = [False]
def _timeout_listener():
timed_out[0] = True
listener[0].notify(self, [], from_token, from_token)
self.clock.call_later(timeout/1000., _timeout_listener)
while not result and not timed_out[0]:
yield deferred
deferred = defer.Deferred()
listener[0] = _NotificationListener(
user=user,
rooms=rooms,
from_token=from_token,
limit=1,
timeout=timeout,
deferred=deferred,
)
self._register_with_keys(listener[0])
result = yield callback()
defer.returnValue(result)
def get_events_for(self, user, rooms, pagination_config, timeout):
""" For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any

View File

@@ -1,423 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken, UserID
import synapse.util.async
import baserules
import logging
import simplejson as json
import re
logger = logging.getLogger(__name__)
class Pusher(object):
INITIAL_BACKOFF = 1000
MAX_BACKOFF = 60 * 60 * 1000
GIVE_UP_AFTER = 24 * 60 * 60 * 1000
DEFAULT_ACTIONS = ['notify']
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def __init__(self, _hs, profile_tag, user_name, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since):
self.hs = _hs
self.evStreamHandler = self.hs.get_handlers().event_stream_handler
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.profile_tag = profile_tag
self.user_name = user_name
self.app_id = app_id
self.app_display_name = app_display_name
self.device_display_name = device_display_name
self.pushkey = pushkey
self.pushkey_ts = pushkey_ts
self.data = data
self.last_token = last_token
self.last_success = last_success # not actually used
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.failing_since = failing_since
self.alive = True
# The last value of last_active_time that we saw
self.last_last_active_time = 0
self.has_unread = True
@defer.inlineCallbacks
def _actions_for_event(self, ev):
"""
This should take into account notification settings that the user
has configured both globally and per-room when we have the ability
to do such things.
"""
if ev['user_id'] == self.user_name:
# let's assume you probably know about messages you sent yourself
defer.returnValue(['dont_notify'])
if ev['type'] == 'm.room.member':
if ev['state_key'] != self.user_name:
defer.returnValue(['dont_notify'])
rawrules = yield self.store.get_push_rules_for_user_name(self.user_name)
for r in rawrules:
r['conditions'] = json.loads(r['conditions'])
r['actions'] = json.loads(r['actions'])
user = UserID.from_string(self.user_name)
rules = baserules.list_with_base_rules(rawrules, user)
# get *our* member event for display name matching
member_events_for_room = yield self.store.get_current_state(
room_id=ev['room_id'],
event_type='m.room.member',
state_key=None
)
my_display_name = None
room_member_count = 0
for mev in member_events_for_room:
if mev.content['membership'] != 'join':
continue
# This loop does two things:
# 1) Find our current display name
if mev.state_key == self.user_name and 'displayname' in mev.content:
my_display_name = mev.content['displayname']
# and 2) Get the number of people in that room
room_member_count += 1
for r in rules:
matches = True
conditions = r['conditions']
actions = r['actions']
for c in conditions:
matches &= self._event_fulfills_condition(
ev, c, display_name=my_display_name,
room_member_count=room_member_count
)
# ignore rules with no actions (we have an explict 'dont_notify'
if len(actions) == 0:
logger.warn(
"Ignoring rule id %s with no actions for user %s" %
(r['rule_id'], r['user_name'])
)
continue
if matches:
defer.returnValue(actions)
defer.returnValue(Pusher.DEFAULT_ACTIONS)
@staticmethod
def _glob_to_regexp(glob):
r = re.escape(glob)
r = re.sub(r'\\\*', r'.*?', r)
r = re.sub(r'\\\?', r'.', r)
# handle [abc], [a-z] and [!a-z] style ranges.
r = re.sub(r'\\\[(\\\!|)(.*)\\\]',
lambda x: ('[%s%s]' % (x.group(1) and '^' or '',
re.sub(r'\\\-', '-', x.group(2)))), r)
return r
def _event_fulfills_condition(self, ev, condition, display_name, room_member_count):
if condition['kind'] == 'event_match':
if 'pattern' not in condition:
logger.warn("event_match condition with no pattern")
return False
# XXX: optimisation: cache our pattern regexps
if condition['key'] == 'content.body':
r = r'\b%s\b' % self._glob_to_regexp(condition['pattern'])
else:
r = r'^%s$' % self._glob_to_regexp(condition['pattern'])
val = _value_for_dotted_key(condition['key'], ev)
if val is None:
return False
return re.search(r, val, flags=re.IGNORECASE) is not None
elif condition['kind'] == 'device':
if 'profile_tag' not in condition:
return True
return condition['profile_tag'] == self.profile_tag
elif condition['kind'] == 'contains_display_name':
# This is special because display names can be different
# between rooms and so you can't really hard code it in a rule.
# Optimisation: we should cache these names and update them from
# the event stream.
if 'content' not in ev or 'body' not in ev['content']:
return False
if not display_name:
return False
return re.search(
"\b%s\b" % re.escape(display_name), ev['content']['body'],
flags=re.IGNORECASE
) is not None
elif condition['kind'] == 'room_member_count':
if 'is' not in condition:
return False
m = Pusher.INEQUALITY_EXPR.match(condition['is'])
if not m:
return False
ineq = m.group(1)
rhs = m.group(2)
if not rhs.isdigit():
return False
rhs = int(rhs)
if ineq == '' or ineq == '==':
return room_member_count == rhs
elif ineq == '<':
return room_member_count < rhs
elif ineq == '>':
return room_member_count > rhs
elif ineq == '>=':
return room_member_count >= rhs
elif ineq == '<=':
return room_member_count <= rhs
else:
return False
else:
return True
@defer.inlineCallbacks
def get_context_for_event(self, ev):
name_aliases = yield self.store.get_room_name_and_aliases(
ev['room_id']
)
ctx = {'aliases': name_aliases[1]}
if name_aliases[0] is not None:
ctx['name'] = name_aliases[0]
their_member_events_for_room = yield self.store.get_current_state(
room_id=ev['room_id'],
event_type='m.room.member',
state_key=ev['user_id']
)
for mev in their_member_events_for_room:
if mev.content['membership'] == 'join' and 'displayname' in mev.content:
dn = mev.content['displayname']
if dn is not None:
ctx['sender_display_name'] = dn
defer.returnValue(ctx)
@defer.inlineCallbacks
def start(self):
if not self.last_token:
# First-time setup: get a token to start from (we can't
# just start from no token, ie. 'now'
# because we need the result to be reproduceable in case
# we fail to dispatch the push)
config = PaginationConfig(from_token=None, limit='1')
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config, timeout=0)
self.last_token = chunk['end']
self.store.update_pusher_last_token(
self.user_name, self.pushkey, self.last_token)
logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token)
while self.alive:
from_tok = StreamToken.from_string(self.last_token)
config = PaginationConfig(from_token=from_tok, limit='1')
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config,
timeout=100*365*24*60*60*1000, affect_presence=False
)
# limiting to 1 may get 1 event plus 1 presence event, so
# pick out the actual event
single_event = None
for c in chunk['chunk']:
if 'event_id' in c: # Hmmm...
single_event = c
break
if not single_event:
self.last_token = chunk['end']
continue
if not self.alive:
continue
processed = False
actions = yield self._actions_for_event(single_event)
tweaks = _tweaks_for_actions(actions)
if len(actions) == 0:
logger.warn("Empty actions! Using default action.")
actions = Pusher.DEFAULT_ACTIONS
if 'notify' not in actions and 'dont_notify' not in actions:
logger.warn("Neither notify nor dont_notify in actions: adding default")
actions.extend(Pusher.DEFAULT_ACTIONS)
if 'dont_notify' in actions:
logger.debug(
"%s for %s: dont_notify",
single_event['event_id'], self.user_name
)
processed = True
else:
rejected = yield self.dispatch_push(single_event, tweaks)
self.has_unread = True
if isinstance(rejected, list) or isinstance(rejected, tuple):
processed = True
for pk in rejected:
if pk != self.pushkey:
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warn(
("Ignoring rejected pushkey %s because we"
" didn't send it"), pk
)
else:
logger.info(
"Pushkey %s was rejected: removing",
pk
)
yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk
)
if not self.alive:
continue
if processed:
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
self.store.update_pusher_last_token_and_success(
self.user_name,
self.pushkey,
self.last_token,
self.clock.time_msec()
)
if self.failing_since:
self.failing_since = None
self.store.update_pusher_failing_since(
self.user_name,
self.pushkey,
self.failing_since)
else:
if not self.failing_since:
self.failing_since = self.clock.time_msec()
self.store.update_pusher_failing_since(
self.user_name,
self.pushkey,
self.failing_since
)
if (self.failing_since and
self.failing_since <
self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
logger.warn("Giving up on a notification to user %s, "
"pushkey %s",
self.user_name, self.pushkey)
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
self.store.update_pusher_last_token(
self.user_name,
self.pushkey,
self.last_token
)
self.failing_since = None
self.store.update_pusher_failing_since(
self.user_name,
self.pushkey,
self.failing_since
)
else:
logger.warn("Failed to dispatch push for user %s "
"(failing for %dms)."
"Trying again in %dms",
self.user_name,
self.clock.time_msec() - self.failing_since,
self.backoff_delay)
yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
self.backoff_delay *= 2
if self.backoff_delay > Pusher.MAX_BACKOFF:
self.backoff_delay = Pusher.MAX_BACKOFF
def stop(self):
self.alive = False
def dispatch_push(self, p, tweaks):
"""
Overridden by implementing classes to actually deliver the notification
Args:
p: The event to notify for as a single event from the event stream
Returns: If the notification was delivered, an array containing any
pushkeys that were rejected by the push gateway.
False if the notification could not be delivered (ie.
should be retried).
"""
pass
def reset_badge_count(self):
pass
def presence_changed(self, state):
"""
We clear badge counts whenever a user's last_active time is bumped
This is by no means perfect but I think it's the best we can do
without read receipts.
"""
if 'last_active' in state.state:
last_active = state.state['last_active']
if last_active > self.last_last_active_time:
self.last_last_active_time = last_active
if self.has_unread:
logger.info("Resetting badge count for %s", self.user_name)
self.reset_badge_count()
self.has_unread = False
def _value_for_dotted_key(dotted_key, event):
parts = dotted_key.split(".")
val = event
while len(parts) > 0:
if parts[0] not in val:
return None
val = val[parts[0]]
parts = parts[1:]
return val
def _tweaks_for_actions(actions):
tweaks = {}
for a in actions:
if not isinstance(a, dict):
continue
if 'set_tweak' in a and 'value' in a:
tweaks[a['set_tweak']] = a['value']
return tweaks
class PusherConfigException(Exception):
def __init__(self, msg):
super(PusherConfigException, self).__init__(msg)

View File

@@ -1,97 +0,0 @@
from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
def list_with_base_rules(rawrules, user_name):
ruleslist = []
# shove the server default rules for each kind onto the end of each
current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
for r in rawrules:
if r['priority_class'] < current_prio_class:
while r['priority_class'] < current_prio_class:
ruleslist.extend(make_base_rules(
user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
current_prio_class -= 1
ruleslist.append(r)
while current_prio_class > 0:
ruleslist.extend(make_base_rules(
user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
current_prio_class -= 1
return ruleslist
def make_base_rules(user, kind):
rules = []
if kind == 'override':
rules = make_base_override_rules()
elif kind == 'content':
rules = make_base_content_rules(user)
for r in rules:
r['priority_class'] = PRIORITY_CLASS_MAP[kind]
r['default'] = True
return rules
def make_base_content_rules(user):
return [
{
'conditions': [
{
'kind': 'event_match',
'key': 'content.body',
'pattern': user.localpart, # Matrix ID match
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default',
}
]
},
]
def make_base_override_rules():
return [
{
'conditions': [
{
'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}
]
},
{
'conditions': [
{
'kind': 'room_member_count',
'is': '2'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}
]
}
]

View File

@@ -1,146 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.push import Pusher, PusherConfigException
from synapse.http.client import SimpleHttpClient
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
class HttpPusher(Pusher):
def __init__(self, _hs, profile_tag, user_name, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since):
super(HttpPusher, self).__init__(
_hs,
profile_tag,
user_name,
app_id,
app_display_name,
device_display_name,
pushkey,
pushkey_ts,
data,
last_token,
last_success,
failing_since
)
if 'url' not in data:
raise PusherConfigException(
"'url' required in data for HTTP pusher"
)
self.url = data['url']
self.httpCli = SimpleHttpClient(self.hs)
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url['url']
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks):
# we probably do not want to push for every presence update
# (we may want to be able to set up notifications when specific
# people sign in, but we'd want to only deliver the pertinent ones)
# Actually, presence events will not get this far now because we
# need to filter them out in the main Pusher code.
if 'event_id' not in event:
defer.returnValue(None)
ctx = yield self.get_context_for_event(event)
d = {
'notification': {
'id': event['event_id'],
'type': event['type'],
'sender': event['user_id'],
'counts': { # -- we don't mark messages as read yet so
# we have no way of knowing
# Just set the badge to 1 until we have read receipts
'unread': 1,
# 'missed_calls': 2
},
'devices': [
{
'app_id': self.app_id,
'pushkey': self.pushkey,
'pushkey_ts': long(self.pushkey_ts / 1000),
'data': self.data_minus_url,
'tweaks': tweaks
}
]
}
}
if event['type'] == 'm.room.member':
d['notification']['membership'] = event['content']['membership']
if 'content' in event:
d['notification']['content'] = event['content']
if len(ctx['aliases']):
d['notification']['room_alias'] = ctx['aliases'][0]
if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0:
d['notification']['sender_display_name'] = ctx['sender_display_name']
if 'name' in ctx and len(ctx['name']) > 0:
d['notification']['room_name'] = ctx['name']
defer.returnValue(d)
@defer.inlineCallbacks
def dispatch_push(self, event, tweaks):
notification_dict = yield self._build_notification_dict(event, tweaks)
if not notification_dict:
defer.returnValue([])
try:
resp = yield self.httpCli.post_json_get_json(self.url, notification_dict)
except:
logger.exception("Failed to push %s ", self.url)
defer.returnValue(False)
rejected = []
if 'rejected' in resp:
rejected = resp['rejected']
defer.returnValue(rejected)
@defer.inlineCallbacks
def reset_badge_count(self):
d = {
'notification': {
'id': '',
'type': None,
'sender': '',
'counts': {
'unread': 0,
'missed_calls': 0
},
'devices': [
{
'app_id': self.app_id,
'pushkey': self.pushkey,
'pushkey_ts': long(self.pushkey_ts / 1000),
'data': self.data_minus_url,
}
]
}
}
try:
resp = yield self.httpCli.post_json_get_json(self.url, d)
except:
logger.exception("Failed to push %s ", self.url)
defer.returnValue(False)
rejected = []
if 'rejected' in resp:
rejected = resp['rejected']
defer.returnValue(rejected)

View File

@@ -1,154 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from httppusher import HttpPusher
from synapse.push import PusherConfigException
from syutil.jsonutil import encode_canonical_json
import logging
import simplejson as json
logger = logging.getLogger(__name__)
class PusherPool:
def __init__(self, _hs):
self.hs = _hs
self.store = self.hs.get_datastore()
self.pushers = {}
self.last_pusher_started = -1
distributor = self.hs.get_distributor()
distributor.observe(
"user_presence_changed", self.user_presence_changed
)
@defer.inlineCallbacks
def user_presence_changed(self, user, state):
user_name = user.to_string()
# until we have read receipts, pushers use this to reset a user's
# badge counters to zero
for p in self.pushers.values():
if p.user_name == user_name:
yield p.presence_changed(state)
@defer.inlineCallbacks
def start(self):
pushers = yield self.store.get_all_pushers()
for p in pushers:
p['data'] = json.loads(p['data'])
self._start_pushers(pushers)
@defer.inlineCallbacks
def add_pusher(self, user_name, profile_tag, kind, app_id,
app_display_name, device_display_name, pushkey, lang, data):
# we try to create the pusher just to validate the config: it
# will then get pulled out of the database,
# recreated, added and started: this means we have only one
# code path adding pushers.
self._create_pusher({
"user_name": user_name,
"kind": kind,
"profile_tag": profile_tag,
"app_id": app_id,
"app_display_name": app_display_name,
"device_display_name": device_display_name,
"pushkey": pushkey,
"pushkey_ts": self.hs.get_clock().time_msec(),
"lang": lang,
"data": data,
"last_token": None,
"last_success": None,
"failing_since": None
})
yield self._add_pusher_to_store(
user_name, profile_tag, kind, app_id,
app_display_name, device_display_name,
pushkey, lang, data
)
@defer.inlineCallbacks
def _add_pusher_to_store(self, user_name, profile_tag, kind, app_id,
app_display_name, device_display_name,
pushkey, lang, data):
yield self.store.add_pusher(
user_name=user_name,
profile_tag=profile_tag,
kind=kind,
app_id=app_id,
app_display_name=app_display_name,
device_display_name=device_display_name,
pushkey=pushkey,
pushkey_ts=self.hs.get_clock().time_msec(),
lang=lang,
data=encode_canonical_json(data).decode("UTF-8"),
)
self._refresh_pusher((app_id, pushkey))
def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http':
return HttpPusher(
self.hs,
profile_tag=pusherdict['profile_tag'],
user_name=pusherdict['user_name'],
app_id=pusherdict['app_id'],
app_display_name=pusherdict['app_display_name'],
device_display_name=pusherdict['device_display_name'],
pushkey=pusherdict['pushkey'],
pushkey_ts=pusherdict['pushkey_ts'],
data=pusherdict['data'],
last_token=pusherdict['last_token'],
last_success=pusherdict['last_success'],
failing_since=pusherdict['failing_since']
)
else:
raise PusherConfigException(
"Unknown pusher type '%s' for user %s" %
(pusherdict['kind'], pusherdict['user_name'])
)
@defer.inlineCallbacks
def _refresh_pusher(self, app_id_pushkey):
p = yield self.store.get_pushers_by_app_id_and_pushkey(
app_id_pushkey
)
p['data'] = json.loads(p['data'])
self._start_pushers([p])
def _start_pushers(self, pushers):
logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers:
p = self._create_pusher(pusherdict)
if p:
fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey'])
if fullid in self.pushers:
self.pushers[fullid].stop()
self.pushers[fullid] = p
p.start()
@defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey):
fullid = "%s:%s" % (app_id, pushkey)
if fullid in self.pushers:
logger.info("Stopping pusher %s", fullid)
self.pushers[fullid].stop()
del self.pushers[fullid]
yield self.store.delete_pusher_by_app_id_pushkey(app_id, pushkey)

View File

@@ -1,8 +0,0 @@
PRIORITY_CLASS_MAP = {
'underride': 1,
'sender': 2,
'room': 3,
'content': 4,
'override': 5,
}
PRIORITY_CLASS_INVERSE_MAP = {v: k for k, v in PRIORITY_CLASS_MAP.items()}

View File

@@ -4,9 +4,9 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__)
REQUIREMENTS = {
"syutil>=0.0.3": ["syutil"],
"matrix_angular_sdk>=0.6.2": ["syweb>=0.6.2"],
"Twisted==14.0.2": ["twisted==14.0.2"],
"syutil==0.0.2": ["syutil"],
"matrix_angular_sdk==0.6.0": ["syweb==0.6.0"],
"Twisted>=14.0.0": ["twisted>=14.0.0"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"],
@@ -16,32 +16,9 @@ REQUIREMENTS = {
"py-bcrypt": ["bcrypt"],
"frozendict>=0.4": ["frozendict"],
"pillow": ["PIL"],
"pydenticon": ["pydenticon"],
}
def github_link(project, version, egg):
return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg)
DEPENDENCY_LINKS = [
github_link(
project="matrix-org/syutil",
version="v0.0.3",
egg="syutil-0.0.3",
),
github_link(
project="matrix-org/matrix-angular-sdk",
version="v0.6.2",
egg="matrix_angular_sdk-0.6.2",
),
github_link(
project="pyca/pynacl",
version="d4d3175589b892f6ea7c22f466e0e223853516fa",
egg="pynacl-0.3.0",
)
]
class MissingRequirementError(Exception):
pass
@@ -101,24 +78,3 @@ def check_requirements():
"Unexpected version of %r in %r. %r != %r"
% (dependency, file_path, version, required_version)
)
def list_requirements():
result = []
linked = []
for link in DEPENDENCY_LINKS:
egg = link.split("#egg=")[1]
linked.append(egg.split('-')[0])
result.append(link)
for requirement in REQUIREMENTS:
is_linked = False
for link in linked:
if requirement.replace('-', '_').startswith(link):
is_linked = True
if not is_linked:
result.append(requirement)
return result
if __name__ == "__main__":
import sys
sys.stdout.writelines(req + "\n" for req in list_requirements())

View File

@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,3 +12,36 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import (
room, events, register, login, profile, presence, initial_sync, directory,
voip, admin,
)
class RestServletFactory(object):
""" A factory for creating REST servlets.
These REST servlets represent the entire client-server REST API. Generally
speaking, they serve as wrappers around events and the handlers that
process them.
See synapse.events for information on synapse events.
"""
def __init__(self, hs):
client_resource = hs.get_resource_for_client()
# TODO(erikj): There *must* be a better way of doing this.
room.register_servlets(hs, client_resource)
events.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
login.register_servlets(hs, client_resource)
profile.register_servlets(hs, client_resource)
presence.register_servlets(hs, client_resource)
initial_sync.register_servlets(hs, client_resource)
directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource)
admin.register_servlets(hs, client_resource)

View File

@@ -16,22 +16,20 @@
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern
from base import RestServlet, client_path_pattern
import logging
logger = logging.getLogger(__name__)
class WhoisRestServlet(ClientV1RestServlet):
class WhoisRestServlet(RestServlet):
PATTERN = client_path_pattern("/admin/whois/(?P<user_id>[^/]*)")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
auth_user, client = yield self.auth.get_user_by_req(request)
target_user = self.hs.parse_userid(user_id)
auth_user = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(auth_user)
if not is_admin and target_user != auth_user:

80
synapse/rest/base.py Normal file
View File

@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
from synapse.api.urls import CLIENT_PREFIX
from synapse.rest.transactions import HttpTransactionStore
import re
import logging
logger = logging.getLogger(__name__)
def client_path_pattern(path_regex):
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
SRE_Pattern
"""
return re.compile("^" + CLIENT_PREFIX + path_regex)
class RestServlet(object):
""" A Synapse REST Servlet.
An implementing class can either provide its own custom 'register' method,
or use the automatic pattern handling provided by the base class.
To use this latter, the implementing class instead provides a `PATTERN`
class attribute containing a pre-compiled regular expression. The automatic
register method will then use this method to register any of the following
instance methods associated with the corresponding HTTP method:
on_GET
on_PUT
on_POST
on_DELETE
on_OPTIONS
Automatically handles turning CodeMessageExceptions thrown by these methods
into the appropriate HTTP response.
"""
def __init__(self, hs):
self.hs = hs
self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
self.txns = HttpTransactionStore()
def register(self, http_server):
""" Register this servlet with the given HTTP server. """
if hasattr(self, "PATTERN"):
pattern = self.PATTERN
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
if hasattr(self, "on_%s" % (method)):
method_handler = getattr(self, "on_%s" % (method))
http_server.register_path(method, pattern, method_handler)
else:
raise NotImplementedError("RestServlet must register something.")

View File

@@ -1,14 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -1,44 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import (
room, events, register, login, profile, presence, initial_sync, directory,
voip, admin, pusher, push_rule
)
from synapse.http.server import JsonResource
class ClientV1RestResource(JsonResource):
"""A resource for version 1 of the matrix client API."""
def __init__(self, hs):
JsonResource.__init__(self, hs)
self.register_servlets(self, hs)
@staticmethod
def register_servlets(client_resource, hs):
room.register_servlets(hs, client_resource)
events.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
login.register_servlets(hs, client_resource)
profile.register_servlets(hs, client_resource)
presence.register_servlets(hs, client_resource)
initial_sync.register_servlets(hs, client_resource)
directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource)
admin.register_servlets(hs, client_resource)
pusher.register_servlets(hs, client_resource)
push_rule.register_servlets(hs, client_resource)

View File

@@ -1,52 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base REST classes for constructing client v1 servlets.
"""
from synapse.http.servlet import RestServlet
from synapse.api.urls import CLIENT_PREFIX
from .transactions import HttpTransactionStore
import re
import logging
logger = logging.getLogger(__name__)
def client_path_pattern(path_regex):
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
SRE_Pattern
"""
return re.compile("^" + CLIENT_PREFIX + path_regex)
class ClientV1RestServlet(RestServlet):
"""A base Synapse REST Servlet for the client version 1 API.
"""
def __init__(self, hs):
self.hs = hs
self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
self.txns = HttpTransactionStore()

View File

@@ -1,411 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import (
SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError
)
from .base import ClientV1RestServlet, client_path_pattern
from synapse.storage.push_rule import (
InconsistentRuleException, RuleNotFoundException
)
import synapse.push.baserules as baserules
from synapse.push.rulekinds import (
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
)
import simplejson as json
class PushRuleRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/pushrules/.*$")
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash")
@defer.inlineCallbacks
def on_PUT(self, request):
spec = _rule_spec_from_path(request.postpath)
try:
priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e:
raise SynapseError(400, e.message)
user, _ = yield self.auth.get_user_by_req(request)
if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
raise SynapseError(400, "rule_id may not contain slashes")
content = _parse_json(request)
try:
(conditions, actions) = _rule_tuple_from_request_object(
spec['template'],
spec['rule_id'],
content,
device=spec['device'] if 'device' in spec else None
)
except InvalidRuleException as e:
raise SynapseError(400, e.message)
before = request.args.get("before", None)
if before and len(before):
before = before[0]
after = request.args.get("after", None)
if after and len(after):
after = after[0]
try:
yield self.hs.get_datastore().add_push_rule(
user_name=user.to_string(),
rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class,
conditions=conditions,
actions=actions,
before=before,
after=after
)
except InconsistentRuleException as e:
raise SynapseError(400, e.message)
except RuleNotFoundException as e:
raise SynapseError(400, e.message)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_DELETE(self, request):
spec = _rule_spec_from_path(request.postpath)
user, _ = yield self.auth.get_user_by_req(request)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try:
yield self.hs.get_datastore().delete_push_rule(
user.to_string(), namespaced_rule_id
)
defer.returnValue((200, {}))
except StoreError as e:
if e.code == 404:
raise NotFoundError()
else:
raise
@defer.inlineCallbacks
def on_GET(self, request):
user, _ = yield self.auth.get_user_by_req(request)
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
rawrules = yield self.hs.get_datastore().get_push_rules_for_user_name(
user.to_string()
)
for r in rawrules:
r["conditions"] = json.loads(r["conditions"])
r["actions"] = json.loads(r["actions"])
ruleslist = baserules.list_with_base_rules(rawrules, user)
rules = {'global': {}, 'device': {}}
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
for r in ruleslist:
rulearray = None
template_name = _priority_class_to_template_name(r['priority_class'])
if r['priority_class'] > PRIORITY_CLASS_MAP['override']:
# per-device rule
profile_tag = _profile_tag_from_conditions(r["conditions"])
r = _strip_device_condition(r)
if not profile_tag:
continue
if profile_tag not in rules['device']:
rules['device'][profile_tag] = {}
rules['device'][profile_tag] = (
_add_empty_priority_class_arrays(
rules['device'][profile_tag]
)
)
rulearray = rules['device'][profile_tag][template_name]
else:
rulearray = rules['global'][template_name]
template_rule = _rule_to_template(r)
if template_rule:
rulearray.append(template_rule)
path = request.postpath[1:]
if path == []:
# we're a reference impl: pedantry is our job.
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
defer.returnValue((200, rules))
elif path[0] == 'global':
path = path[1:]
result = _filter_ruleset_with_path(rules['global'], path)
defer.returnValue((200, result))
elif path[0] == 'device':
path = path[1:]
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
defer.returnValue((200, rules['device']))
profile_tag = path[0]
path = path[1:]
if profile_tag not in rules['device']:
ret = {}
ret = _add_empty_priority_class_arrays(ret)
defer.returnValue((200, ret))
ruleset = rules['device'][profile_tag]
result = _filter_ruleset_with_path(ruleset, path)
defer.returnValue((200, result))
else:
raise UnrecognizedRequestError()
def on_OPTIONS(self, _):
return 200, {}
def _rule_spec_from_path(path):
if len(path) < 2:
raise UnrecognizedRequestError()
if path[0] != 'pushrules':
raise UnrecognizedRequestError()
scope = path[1]
path = path[2:]
if scope not in ['global', 'device']:
raise UnrecognizedRequestError()
device = None
if scope == 'device':
if len(path) == 0:
raise UnrecognizedRequestError()
device = path[0]
path = path[1:]
if len(path) == 0:
raise UnrecognizedRequestError()
template = path[0]
path = path[1:]
if len(path) == 0:
raise UnrecognizedRequestError()
rule_id = path[0]
spec = {
'scope': scope,
'template': template,
'rule_id': rule_id
}
if device:
spec['profile_tag'] = device
return spec
def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None):
if rule_template in ['override', 'underride']:
if 'conditions' not in req_obj:
raise InvalidRuleException("Missing 'conditions'")
conditions = req_obj['conditions']
for c in conditions:
if 'kind' not in c:
raise InvalidRuleException("Condition without 'kind'")
elif rule_template == 'room':
conditions = [{
'kind': 'event_match',
'key': 'room_id',
'pattern': rule_id
}]
elif rule_template == 'sender':
conditions = [{
'kind': 'event_match',
'key': 'user_id',
'pattern': rule_id
}]
elif rule_template == 'content':
if 'pattern' not in req_obj:
raise InvalidRuleException("Content rule missing 'pattern'")
pat = req_obj['pattern']
conditions = [{
'kind': 'event_match',
'key': 'content.body',
'pattern': pat
}]
else:
raise InvalidRuleException("Unknown rule template: %s" % (rule_template,))
if device:
conditions.append({
'kind': 'device',
'profile_tag': device
})
if 'actions' not in req_obj:
raise InvalidRuleException("No actions found")
actions = req_obj['actions']
for a in actions:
if a in ['notify', 'dont_notify', 'coalesce']:
pass
elif isinstance(a, dict) and 'set_sound' in a:
pass
else:
raise InvalidRuleException("Unrecognised action")
return conditions, actions
def _add_empty_priority_class_arrays(d):
for pc in PRIORITY_CLASS_MAP.keys():
d[pc] = []
return d
def _profile_tag_from_conditions(conditions):
"""
Given a list of conditions, return the profile tag of the
device rule if there is one
"""
for c in conditions:
if c['kind'] == 'device':
return c['profile_tag']
return None
def _filter_ruleset_with_path(ruleset, path):
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
return ruleset
template_kind = path[0]
if template_kind not in ruleset:
raise UnrecognizedRequestError()
path = path[1:]
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
return ruleset[template_kind]
rule_id = path[0]
for r in ruleset[template_kind]:
if r['rule_id'] == rule_id:
return r
raise NotFoundError
def _priority_class_from_spec(spec):
if spec['template'] not in PRIORITY_CLASS_MAP.keys():
raise InvalidRuleException("Unknown template: %s" % (spec['kind']))
pc = PRIORITY_CLASS_MAP[spec['template']]
if spec['scope'] == 'device':
pc += len(PRIORITY_CLASS_MAP)
return pc
def _priority_class_to_template_name(pc):
if pc > PRIORITY_CLASS_MAP['override']:
# per-device
prio_class_index = pc - len(PushRuleRestServlet.PRIORITY_CLASS_MAP)
return PRIORITY_CLASS_INVERSE_MAP[prio_class_index]
else:
return PRIORITY_CLASS_INVERSE_MAP[pc]
def _rule_to_template(rule):
unscoped_rule_id = None
if 'rule_id' in rule:
unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
template_name = _priority_class_to_template_name(rule['priority_class'])
if template_name in ['override', 'underride']:
templaterule = {k: rule[k] for k in ["conditions", "actions"]}
elif template_name in ["sender", "room"]:
templaterule = {'actions': rule['actions']}
unscoped_rule_id = rule['conditions'][0]['pattern']
elif template_name == 'content':
if len(rule["conditions"]) != 1:
return None
thecond = rule["conditions"][0]
if "pattern" not in thecond:
return None
templaterule = {'actions': rule['actions']}
templaterule["pattern"] = thecond["pattern"]
if unscoped_rule_id:
templaterule['rule_id'] = unscoped_rule_id
if 'default' in rule:
templaterule['default'] = rule['default']
return templaterule
def _strip_device_condition(rule):
for i, c in enumerate(rule['conditions']):
if c['kind'] == 'device':
del rule['conditions'][i]
return rule
def _namespaced_rule_id_from_spec(spec):
if spec['scope'] == 'global':
scope = 'global'
else:
scope = 'device/%s' % (spec['profile_tag'])
return "%s/%s/%s" % (scope, spec['template'], spec['rule_id'])
def _rule_id_from_namespaced(in_rule_id):
return in_rule_id.split('/')[-1]
class InvalidRuleException(Exception):
pass
# XXX: C+ped from rest/room.py - surely this should be common?
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_servlets(hs, http_server):
PushRuleRestServlet(hs).register(http_server)

View File

@@ -1,89 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes
from synapse.push import PusherConfigException
from .base import ClientV1RestServlet, client_path_pattern
import simplejson as json
class PusherRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/pushers/set$")
@defer.inlineCallbacks
def on_POST(self, request):
user, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
pusher_pool = self.hs.get_pusherpool()
if ('pushkey' in content and 'app_id' in content
and 'kind' in content and
content['kind'] is None):
yield pusher_pool.remove_pusher(
content['app_id'], content['pushkey']
)
defer.returnValue((200, {}))
reqd = ['profile_tag', 'kind', 'app_id', 'app_display_name',
'device_display_name', 'pushkey', 'lang', 'data']
missing = []
for i in reqd:
if i not in content:
missing.append(i)
if len(missing):
raise SynapseError(400, "Missing parameters: "+','.join(missing),
errcode=Codes.MISSING_PARAM)
try:
yield pusher_pool.add_pusher(
user_name=user.to_string(),
profile_tag=content['profile_tag'],
kind=content['kind'],
app_id=content['app_id'],
app_display_name=content['app_display_name'],
device_display_name=content['device_display_name'],
pushkey=content['pushkey'],
lang=content['lang'],
data=content['data']
)
except PusherConfigException as pce:
raise SynapseError(400, "Config Error: "+pce.message,
errcode=Codes.MISSING_PARAM)
defer.returnValue((200, {}))
def on_OPTIONS(self, _):
return 200, {}
# XXX: C+ped from rest/room.py - surely this should be common?
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_servlets(hs, http_server):
PusherRestServlet(hs).register(http_server)

View File

@@ -1,34 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import (
sync,
filter
)
from synapse.http.server import JsonResource
class ClientV2AlphaRestResource(JsonResource):
"""A resource for version 2 alpha of the matrix client API."""
def __init__(self, hs):
JsonResource.__init__(self, hs)
self.register_servlets(self, hs)
@staticmethod
def register_servlets(client_resource, hs):
sync.register_servlets(hs, client_resource)
filter.register_servlets(hs, client_resource)

View File

@@ -1,38 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base REST classes for constructing client v1 servlets.
"""
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
import re
import logging
logger = logging.getLogger(__name__)
def client_v2_pattern(path_regex):
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
SRE_Pattern
"""
return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)

View File

@@ -1,104 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
from synapse.http.servlet import RestServlet
from synapse.types import UserID
from ._base import client_v2_pattern
import simplejson as json
import logging
logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
def __init__(self, hs):
super(GetFilterRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
@defer.inlineCallbacks
def on_GET(self, request, user_id, filter_id):
target_user = UserID.from_string(user_id)
auth_user, client = yield self.auth.get_user_by_req(request)
if target_user != auth_user:
raise AuthError(403, "Cannot get filters for other users")
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only get filters for local users")
try:
filter_id = int(filter_id)
except:
raise SynapseError(400, "Invalid filter_id")
try:
filter = yield self.filtering.get_user_filter(
user_localpart=target_user.localpart,
filter_id=filter_id,
)
defer.returnValue((200, filter.filter_json))
except KeyError:
raise SynapseError(400, "No such filter")
class CreateFilterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter")
def __init__(self, hs):
super(CreateFilterRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
@defer.inlineCallbacks
def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id)
auth_user, client = yield self.auth.get_user_by_req(request)
if target_user != auth_user:
raise AuthError(403, "Cannot create filters for other users")
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only create filters for local users")
try:
content = json.loads(request.content.read())
# TODO(paul): check for required keys and invalid keys
except:
raise SynapseError(400, "Invalid filter definition")
filter_id = yield self.filtering.add_user_filter(
user_localpart=target_user.localpart,
user_filter=content,
)
defer.returnValue((200, {"filter_id": str(filter_id)}))
def register_servlets(hs, http_server):
GetFilterRestServlet(hs).register(http_server)
CreateFilterRestServlet(hs).register(http_server)

View File

@@ -1,207 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.http.servlet import RestServlet
from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken
from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_event_id,
)
from synapse.api.filtering import Filter
from ._base import client_v2_pattern
import logging
logger = logging.getLogger(__name__)
class SyncRestServlet(RestServlet):
"""
GET parameters::
timeout(int): How long to wait for new events in milliseconds.
limit(int): Maxiumum number of events per room to return.
gap(bool): Create gaps the message history if limit is exceeded to
ensure that the client has the most recent messages. Defaults to
"true".
sort(str,str): tuple of sort key (e.g. "timeline") and direction
(e.g. "asc", "desc"). Defaults to "timeline,asc".
since(batch_token): Batch token when asking for incremental deltas.
set_presence(str): What state the device presence should be set to.
default is "online".
backfill(bool): Should the HS request message history from other
servers. This may take a long time making it unsuitable for clients
expecting a prompt response. Defaults to "true".
filter(filter_id): A filter to apply to the events returned.
filter_*: Filter override parameters.
Response JSON::
{
"next_batch": // batch token for the next /sync
"private_user_data": // private events for this user.
"public_user_data": // public events for all users including the
// public events for this user.
"rooms": [{ // List of rooms with updates.
"room_id": // Id of the room being updated
"limited": // Was the per-room event limit exceeded?
"published": // Is the room published by our HS?
"event_map": // Map of EventID -> event JSON.
"events": { // The recent events in the room if gap is "true"
// otherwise the next events in the room.
"batch": [] // list of EventIDs in the "event_map".
"prev_batch": // back token for getting previous events.
}
"state": [] // list of EventIDs updating the current state to
// be what it should be at the end of the batch.
"ephemeral": []
}]
}
"""
PATTERN = client_v2_pattern("/sync$")
ALLOWED_SORT = set(["timeline,asc", "timeline,desc"])
ALLOWED_PRESENCE = set(["online", "offline", "idle"])
def __init__(self, hs):
super(SyncRestServlet, self).__init__()
self.auth = hs.get_auth()
self.sync_handler = hs.get_handlers().sync_handler
self.clock = hs.get_clock()
self.filtering = hs.get_filtering()
@defer.inlineCallbacks
def on_GET(self, request):
user, client = yield self.auth.get_user_by_req(request)
timeout = self.parse_integer(request, "timeout", default=0)
limit = self.parse_integer(request, "limit", required=True)
gap = self.parse_boolean(request, "gap", default=True)
sort = self.parse_string(
request, "sort", default="timeline,asc",
allowed_values=self.ALLOWED_SORT
)
since = self.parse_string(request, "since")
set_presence = self.parse_string(
request, "set_presence", default="online",
allowed_values=self.ALLOWED_PRESENCE
)
backfill = self.parse_boolean(request, "backfill", default=False)
filter_id = self.parse_string(request, "filter", default=None)
logger.info(
"/sync: user=%r, timeout=%r, limit=%r, gap=%r, sort=%r, since=%r,"
" set_presence=%r, backfill=%r, filter_id=%r" % (
user, timeout, limit, gap, sort, since, set_presence,
backfill, filter_id
)
)
# TODO(mjark): Load filter and apply overrides.
try:
filter = yield self.filtering.get_user_filter(
user.localpart, filter_id
)
except:
filter = Filter({})
# filter = filter.apply_overrides(http_request)
# if filter.matches(event):
# # stuff
sync_config = SyncConfig(
user=user,
client_info=client,
gap=gap,
limit=limit,
sort=sort,
backfill=backfill,
filter=filter,
)
if since is not None:
since_token = StreamToken.from_string(since)
else:
since_token = None
sync_result = yield self.sync_handler.wait_for_sync_for_user(
sync_config, since_token=since_token, timeout=timeout
)
time_now = self.clock.time_msec()
response_content = {
"public_user_data": self.encode_user_data(
sync_result.public_user_data, filter, time_now
),
"private_user_data": self.encode_user_data(
sync_result.private_user_data, filter, time_now
),
"rooms": self.encode_rooms(
sync_result.rooms, filter, time_now, client.token_id
),
"next_batch": sync_result.next_batch.to_string(),
}
defer.returnValue((200, response_content))
def encode_user_data(self, events, filter, time_now):
return events
def encode_rooms(self, rooms, filter, time_now, token_id):
return [
self.encode_room(room, filter, time_now, token_id)
for room in rooms
]
@staticmethod
def encode_room(room, filter, time_now, token_id):
event_map = {}
state_events = filter.filter_room_state(room.state)
recent_events = filter.filter_room_events(room.events)
state_event_ids = []
recent_event_ids = []
for event in state_events:
# TODO(mjark): Respect formatting requirements in the filter.
event_map[event.event_id] = serialize_event(
event, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_event_id,
)
state_event_ids.append(event.event_id)
for event in recent_events:
# TODO(mjark): Respect formatting requirements in the filter.
event_map[event.event_id] = serialize_event(
event, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_event_id,
)
recent_event_ids.append(event.event_id)
result = {
"room_id": room.room_id,
"event_map": event_map,
"events": {
"batch": recent_event_ids,
"prev_batch": room.prev_batch.to_string(),
},
"state": state_event_ids,
"limited": room.limited,
"published": room.published,
"ephemeral": room.ephemeral,
}
return result
def register_servlets(hs, http_server):
SyncRestServlet(hs).register(http_server)

View File

@@ -17,10 +17,9 @@
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.types import RoomAlias
from .base import ClientV1RestServlet, client_path_pattern
from base import RestServlet, client_path_pattern
import simplejson as json
import json
import logging
@@ -31,12 +30,12 @@ def register_servlets(hs, http_server):
ClientDirectoryServer(hs).register(http_server)
class ClientDirectoryServer(ClientV1RestServlet):
class ClientDirectoryServer(RestServlet):
PATTERN = client_path_pattern("/directory/room/(?P<room_alias>[^/]*)$")
@defer.inlineCallbacks
def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
room_alias = self.hs.parse_roomalias(room_alias)
dir_handler = self.handlers.directory_handler
res = yield dir_handler.get_association(room_alias)
@@ -45,16 +44,16 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_alias):
user, client = yield self.auth.get_user_by_req(request)
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
if "room_id" not in content:
if not "room_id" in content:
raise SynapseError(400, "Missing room_id key",
errcode=Codes.BAD_JSON)
logger.debug("Got content: %s", content)
room_alias = RoomAlias.from_string(room_alias)
room_alias = self.hs.parse_roomalias(room_alias)
logger.debug("Got room name: %s", room_alias.to_string())
@@ -85,7 +84,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks
def on_DELETE(self, request, room_alias):
user, client = yield self.auth.get_user_by_req(request)
user = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(user)
if not is_admin:
@@ -93,7 +92,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
dir_handler = self.handlers.directory_handler
room_alias = RoomAlias.from_string(room_alias)
room_alias = self.hs.parse_roomalias(room_alias)
yield dir_handler.delete_association(
user.to_string(), room_alias

View File

@@ -18,8 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_pattern
from synapse.events.utils import serialize_event
from synapse.rest.base import RestServlet, client_path_pattern
import logging
@@ -27,14 +26,14 @@ import logging
logger = logging.getLogger(__name__)
class EventStreamRestServlet(ClientV1RestServlet):
class EventStreamRestServlet(RestServlet):
PATTERN = client_path_pattern("/events$")
DEFAULT_LONGPOLL_TIME_MS = 30000
@defer.inlineCallbacks
def on_GET(self, request):
auth_user, client = yield self.auth.get_user_by_req(request)
auth_user = yield self.auth.get_user_by_req(request)
try:
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request)
@@ -62,22 +61,17 @@ class EventStreamRestServlet(ClientV1RestServlet):
# TODO: Unit test gets, with and without auth, with different kinds of events.
class EventRestServlet(ClientV1RestServlet):
class EventRestServlet(RestServlet):
PATTERN = client_path_pattern("/events/(?P<event_id>[^/]*)$")
def __init__(self, hs):
super(EventRestServlet, self).__init__(hs)
self.clock = hs.get_clock()
@defer.inlineCallbacks
def on_GET(self, request, event_id):
auth_user, client = yield self.auth.get_user_by_req(request)
auth_user = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler
event = yield handler.get_event(auth_user, event_id)
time_now = self.clock.time_msec()
if event:
defer.returnValue((200, serialize_event(event, time_now)))
defer.returnValue((200, self.hs.serialize_event(event)))
else:
defer.returnValue((404, "Event not found."))

View File

@@ -16,16 +16,16 @@
from twisted.internet import defer
from synapse.streams.config import PaginationConfig
from base import ClientV1RestServlet, client_path_pattern
from base import RestServlet, client_path_pattern
# TODO: Needs unit testing
class InitialSyncRestServlet(ClientV1RestServlet):
class InitialSyncRestServlet(RestServlet):
PATTERN = client_path_pattern("/initialSync$")
@defer.inlineCallbacks
def on_GET(self, request):
user, client = yield self.auth.get_user_by_req(request)
user = yield self.auth.get_user_by_req(request)
with_feedback = "feedback" in request.args
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)

View File

@@ -17,12 +17,12 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern
from base import RestServlet, client_path_pattern
import simplejson as json
import json
class LoginRestServlet(ClientV1RestServlet):
class LoginRestServlet(RestServlet):
PATTERN = client_path_pattern("/login$")
PASS_TYPE = "m.login.password"
@@ -64,7 +64,7 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result))
class LoginFallbackRestServlet(ClientV1RestServlet):
class LoginFallbackRestServlet(RestServlet):
PATTERN = client_path_pattern("/login/fallback$")
def on_GET(self, request):
@@ -73,7 +73,7 @@ class LoginFallbackRestServlet(ClientV1RestServlet):
return (200, {})
class PasswordResetRestServlet(ClientV1RestServlet):
class PasswordResetRestServlet(RestServlet):
PATTERN = client_path_pattern("/login/reset")
@defer.inlineCallbacks

View File

@@ -1,51 +0,0 @@
from pydenticon import Generator
from twisted.web.resource import Resource
FOREGROUND = [
"rgb(45,79,255)",
"rgb(254,180,44)",
"rgb(226,121,234)",
"rgb(30,179,253)",
"rgb(232,77,65)",
"rgb(49,203,115)",
"rgb(141,69,170)"
]
BACKGROUND = "rgb(224,224,224)"
SIZE = 5
class IdenticonResource(Resource):
isLeaf = True
def __init__(self):
Resource.__init__(self)
self.generator = Generator(
SIZE, SIZE, foreground=FOREGROUND, background=BACKGROUND,
)
def generate_identicon(self, name, width, height):
v_padding = width % SIZE
h_padding = height % SIZE
top_padding = v_padding // 2
left_padding = h_padding // 2
bottom_padding = v_padding - top_padding
right_padding = h_padding - left_padding
width -= v_padding
height -= h_padding
padding = (top_padding, bottom_padding, left_padding, right_padding)
identicon = self.generator.generate(
name, width, height, padding=padding
)
return identicon
def render_GET(self, request):
name = "/".join(request.postpath)
width = int(request.args.get("width", [96])[0])
height = int(request.args.get("height", [96])[0])
identicon_bytes = self.generate_identicon(name, width, height)
request.setHeader(b"Content-Type", b"image/png")
request.setHeader(
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
)
return identicon_bytes

View File

@@ -18,22 +18,21 @@
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.types import UserID
from .base import ClientV1RestServlet, client_path_pattern
from base import RestServlet, client_path_pattern
import simplejson as json
import json
import logging
logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(ClientV1RestServlet):
class PresenceStatusRestServlet(RestServlet):
PATTERN = client_path_pattern("/presence/(?P<user_id>[^/]*)/status")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
state = yield self.handlers.presence_handler.get_state(
target_user=user, auth_user=auth_user)
@@ -42,8 +41,8 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
state = {}
try:
@@ -72,13 +71,13 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
return (200, {})
class PresenceListRestServlet(ClientV1RestServlet):
class PresenceListRestServlet(RestServlet):
PATTERN = client_path_pattern("/presence/list/(?P<user_id>[^/]*)")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server")
@@ -97,8 +96,8 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, user_id):
auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server")
@@ -119,7 +118,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Bad invite value.")
if len(u) == 0:
continue
invited_user = UserID.from_string(u)
invited_user = self.hs.parse_userid(u)
yield self.handlers.presence_handler.send_invite(
observer_user=user, observed_user=invited_user
)
@@ -130,7 +129,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Bad drop value.")
if len(u) == 0:
continue
dropped_user = UserID.from_string(u)
dropped_user = self.hs.parse_userid(u)
yield self.handlers.presence_handler.drop(
observer_user=user, observed_user=dropped_user
)

View File

@@ -16,18 +16,17 @@
""" This module contains REST servlets to do with profile: /profile/<paths> """
from twisted.internet import defer
from .base import ClientV1RestServlet, client_path_pattern
from synapse.types import UserID
from base import RestServlet, client_path_pattern
import simplejson as json
import json
class ProfileDisplaynameRestServlet(ClientV1RestServlet):
class ProfileDisplaynameRestServlet(RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/displayname")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
user = self.hs.parse_userid(user_id)
displayname = yield self.handlers.profile_handler.get_displayname(
user,
@@ -37,8 +36,8 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
try:
content = json.loads(request.content.read())
@@ -55,12 +54,12 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
return (200, {})
class ProfileAvatarURLRestServlet(ClientV1RestServlet):
class ProfileAvatarURLRestServlet(RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/avatar_url")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
user = self.hs.parse_userid(user_id)
avatar_url = yield self.handlers.profile_handler.get_avatar_url(
user,
@@ -70,8 +69,8 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
try:
content = json.loads(request.content.read())
@@ -88,12 +87,12 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
return (200, {})
class ProfileRestServlet(ClientV1RestServlet):
class ProfileRestServlet(RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
user = self.hs.parse_userid(user_id)
displayname = yield self.handlers.profile_handler.get_displayname(
user,

View File

@@ -18,14 +18,14 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes
from synapse.api.constants import LoginType
from base import ClientV1RestServlet, client_path_pattern
from base import RestServlet, client_path_pattern
import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor
from hashlib import sha1
import hmac
import simplejson as json
import json
import logging
import urllib
@@ -42,7 +42,7 @@ else:
compare_digest = lambda a, b: a == b
class RegisterRestServlet(ClientV1RestServlet):
class RegisterRestServlet(RestServlet):
"""Handles registration with the home server.
This servlet is in control of the registration flow; the registration

View File

@@ -16,14 +16,12 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
from twisted.internet import defer
from base import ClientV1RestServlet, client_path_pattern
from base import RestServlet, client_path_pattern
from synapse.api.errors import SynapseError, Codes
from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event
import simplejson as json
import json
import logging
import urllib
@@ -31,7 +29,7 @@ import urllib
logger = logging.getLogger(__name__)
class RoomCreateRestServlet(ClientV1RestServlet):
class RoomCreateRestServlet(RestServlet):
# No PATTERN; we have custom dispatch rules here
def register(self, http_server):
@@ -62,7 +60,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
auth_user, client = yield self.auth.get_user_by_req(request)
auth_user = yield self.auth.get_user_by_req(request)
room_config = self.get_room_config(request)
info = yield self.make_room(room_config, auth_user, None)
@@ -95,7 +93,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(ClientV1RestServlet):
class RoomStateEventRestServlet(RestServlet):
def register(self, http_server):
# /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@@ -125,7 +123,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key):
user, client = yield self.auth.get_user_by_req(request)
user = yield self.auth.get_user_by_req(request)
msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data(
@@ -142,8 +140,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
defer.returnValue((200, data.get_dict()["content"]))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
user, client = yield self.auth.get_user_by_req(request)
def on_PUT(self, request, room_id, event_type, state_key):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
@@ -158,15 +156,13 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
event_dict["state_key"] = state_key
msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event(
event_dict, client=client, txn_id=txn_id,
)
yield msg_handler.create_and_send_event(event_dict)
defer.returnValue((200, {}))
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(ClientV1RestServlet):
class RoomSendEventRestServlet(RestServlet):
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
@@ -174,8 +170,8 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server, with_get=True)
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_type, txn_id=None):
user, client = yield self.auth.get_user_by_req(request)
def on_POST(self, request, room_id, event_type):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
msg_handler = self.handlers.message_handler
@@ -185,9 +181,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
"content": content,
"room_id": room_id,
"sender": user.to_string(),
},
client=client,
txn_id=txn_id,
}
)
defer.returnValue((200, {"event_id": event.event_id}))
@@ -204,14 +198,14 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
except KeyError:
pass
response = yield self.on_POST(request, room_id, event_type, txn_id)
response = yield self.on_POST(request, room_id, event_type)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ClientV1RestServlet):
class JoinRoomAliasServlet(RestServlet):
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
@@ -219,8 +213,8 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_identifier, txn_id=None):
user, client = yield self.auth.get_user_by_req(request)
def on_POST(self, request, room_identifier):
user = yield self.auth.get_user_by_req(request)
# the identifier could be a room alias or a room id. Try one then the
# other if it fails to parse, without swallowing other valid
@@ -229,10 +223,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
identifier = None
is_room_alias = False
try:
identifier = RoomAlias.from_string(room_identifier)
identifier = self.hs.parse_roomalias(room_identifier)
is_room_alias = True
except SynapseError:
identifier = RoomID.from_string(room_identifier)
identifier = self.hs.parse_roomid(room_identifier)
# TODO: Support for specifying the home server to join with?
@@ -249,9 +243,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
"room_id": identifier.to_string(),
"sender": user.to_string(),
"state_key": user.to_string(),
},
client=client,
txn_id=txn_id,
}
)
defer.returnValue((200, {"room_id": identifier.to_string()}))
@@ -265,14 +257,14 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
except KeyError:
pass
response = yield self.on_POST(request, room_identifier, txn_id)
response = yield self.on_POST(request, room_identifier)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
# TODO: Needs unit testing
class PublicRoomListRestServlet(ClientV1RestServlet):
class PublicRoomListRestServlet(RestServlet):
PATTERN = client_path_pattern("/publicRooms$")
@defer.inlineCallbacks
@@ -283,13 +275,13 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
class RoomMemberListRestServlet(ClientV1RestServlet):
class RoomMemberListRestServlet(RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/members$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
user, client = yield self.auth.get_user_by_req(request)
user = yield self.auth.get_user_by_req(request)
handler = self.handlers.room_member_handler
members = yield handler.get_room_members_as_pagination_chunk(
room_id=room_id,
@@ -297,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
for event in members["chunk"]:
# FIXME: should probably be state_key here, not user_id
target_user = UserID.from_string(event["user_id"])
target_user = self.hs.parse_userid(event["user_id"])
# Presence is an optional cache; don't fail if we can't fetch it
try:
presence_handler = self.handlers.presence_handler
@@ -312,12 +304,12 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
class RoomMessageListRestServlet(ClientV1RestServlet):
class RoomMessageListRestServlet(RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user, client = yield self.auth.get_user_by_req(request)
user = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(
request, default_limit=10,
)
@@ -336,12 +328,12 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
class RoomStateRestServlet(ClientV1RestServlet):
class RoomStateRestServlet(RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/state$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user, client = yield self.auth.get_user_by_req(request)
user = yield self.auth.get_user_by_req(request)
handler = self.handlers.message_handler
# Get all the current state for this room
events = yield handler.get_state_events(
@@ -352,12 +344,12 @@ class RoomStateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
class RoomInitialSyncRestServlet(ClientV1RestServlet):
class RoomInitialSyncRestServlet(RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/initialSync$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user, client = yield self.auth.get_user_by_req(request)
user = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync(
room_id=room_id,
@@ -367,13 +359,9 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
defer.returnValue((200, content))
class RoomTriggerBackfill(ClientV1RestServlet):
class RoomTriggerBackfill(RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/backfill$")
def __init__(self, hs):
super(RoomTriggerBackfill, self).__init__(hs)
self.clock = hs.get_clock()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
remote_server = urllib.unquote(
@@ -385,14 +373,12 @@ class RoomTriggerBackfill(ClientV1RestServlet):
handler = self.handlers.federation_handler
events = yield handler.backfill(remote_server, room_id, limit)
time_now = self.clock.time_msec()
res = [serialize_event(event, time_now) for event in events]
res = [self.hs.serialize_event(event) for event in events]
defer.returnValue((200, res))
# TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet):
class RoomMembershipRestServlet(RestServlet):
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
@@ -401,8 +387,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action, txn_id=None):
user, client = yield self.auth.get_user_by_req(request)
def on_POST(self, request, room_id, membership_action):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
@@ -424,9 +410,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
"room_id": room_id,
"sender": user.to_string(),
"state_key": state_key,
},
client=client,
txn_id=txn_id,
}
)
defer.returnValue((200, {}))
@@ -440,22 +424,20 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
except KeyError:
pass
response = yield self.on_POST(
request, room_id, membership_action, txn_id
)
response = yield self.on_POST(request, room_id, membership_action)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
class RoomRedactEventRestServlet(ClientV1RestServlet):
class RoomRedactEventRestServlet(RestServlet):
def register(self, http_server):
PATTERN = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id, txn_id=None):
user, client = yield self.auth.get_user_by_req(request)
def on_POST(self, request, room_id, event_id):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
msg_handler = self.handlers.message_handler
@@ -466,9 +448,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
"room_id": room_id,
"sender": user.to_string(),
"redacts": event_id,
},
client=client,
txn_id=txn_id,
}
)
defer.returnValue((200, {"event_id": event.event_id}))
@@ -482,23 +462,23 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
except KeyError:
pass
response = yield self.on_POST(request, room_id, event_id, txn_id)
response = yield self.on_POST(request, room_id, event_id)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
class RoomTypingRestServlet(ClientV1RestServlet):
class RoomTypingRestServlet(RestServlet):
PATTERN = client_path_pattern(
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
)
@defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id):
auth_user, client = yield self.auth.get_user_by_req(request)
auth_user = yield self.auth.get_user_by_req(request)
room_id = urllib.unquote(room_id)
target_user = UserID.from_string(urllib.unquote(user_id))
target_user = self.hs.parse_userid(urllib.unquote(user_id))
content = _parse_json(request)

View File

@@ -15,7 +15,7 @@
from twisted.internet import defer
from base import ClientV1RestServlet, client_path_pattern
from base import RestServlet, client_path_pattern
import hmac
@@ -23,12 +23,12 @@ import hashlib
import base64
class VoipRestServlet(ClientV1RestServlet):
class VoipRestServlet(RestServlet):
PATTERN = client_path_pattern("/voip/turnServer$")
@defer.inlineCallbacks
def on_GET(self, request):
auth_user, client = yield self.auth.get_user_by_req(request)
auth_user = yield self.auth.get_user_by_req(request)
turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret

View File

@@ -20,20 +20,21 @@
# Imports required for the default HomeServer() implementation
from synapse.federation import initialize_http_replication
from synapse.events.utils import serialize_event
from synapse.notifier import Notifier
from synapse.api.auth import Auth
from synapse.handlers import Handlers
from synapse.rest import RestServletFactory
from synapse.state import StateHandler
from synapse.storage import DataStore
from synapse.types import UserID, RoomAlias, RoomID, EventID
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.lockutils import LockManager
from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring
from synapse.push.pusherpool import PusherPool
from synapse.events.builder import EventBuilderFactory
from synapse.api.filtering import Filtering
class BaseHomeServer(object):
@@ -71,7 +72,6 @@ class BaseHomeServer(object):
'notifier',
'distributor',
'resource_for_client',
'resource_for_client_v2_alpha',
'resource_for_federation',
'resource_for_web_client',
'resource_for_content_repo',
@@ -80,9 +80,7 @@ class BaseHomeServer(object):
'event_sources',
'ratelimiter',
'keyring',
'pusherpool',
'event_builder_factory',
'filtering',
]
def __init__(self, hostname, **kwargs):
@@ -127,6 +125,33 @@ class BaseHomeServer(object):
setattr(BaseHomeServer, "get_%s" % (depname), _get)
# TODO: Why are these parse_ methods so high up along with other globals?
# Surely these should be in a util package or in the api package?
# Other utility methods
def parse_userid(self, s):
"""Parse the string given by 's' as a User ID and return a UserID
object."""
return UserID.from_string(s)
def parse_roomalias(self, s):
"""Parse the string given by 's' as a Room Alias and return a RoomAlias
object."""
return RoomAlias.from_string(s)
def parse_roomid(self, s):
"""Parse the string given by 's' as a Room ID and return a RoomID
object."""
return RoomID.from_string(s)
def parse_eventid(self, s):
"""Parse the string given by 's' as a Event ID and return a EventID
object."""
return EventID.from_string(s)
def serialize_event(self, e, as_client_event=True):
return serialize_event(self, e, as_client_event)
def get_ip_from_request(self, request):
# May be an X-Forwarding-For header depending on config
ip_addr = request.getClientIP()
@@ -178,6 +203,9 @@ class HomeServer(BaseHomeServer):
def build_auth(self):
return Auth(self)
def build_rest_servlet_factory(self):
return RestServletFactory(self)
def build_state_handler(self):
return StateHandler(self)
@@ -202,8 +230,8 @@ class HomeServer(BaseHomeServer):
hostname=self.hostname,
)
def build_filtering(self):
return Filtering(self)
def build_pusherpool(self):
return PusherPool(self)
def register_servlets(self):
""" Register all servlets associated with this HomeServer.
"""
# Simply building the ServletFactory is sufficient to have it register
self.get_rest_servlet_factory()

View File

@@ -19,7 +19,6 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.events.snapshot import EventContext
from collections import namedtuple
@@ -37,44 +36,12 @@ def _get_state_key_from_event(event):
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
AuthEventTypes = (
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
EventTypes.JoinRules,
)
SIZE_OF_CACHE = 1000
EVICTION_TIMEOUT_SECONDS = 20
class _StateCacheEntry(object):
def __init__(self, state, state_group, ts):
self.state = state
self.state_group = state_group
self.ts = ts
class StateHandler(object):
""" Responsible for doing state conflict resolution.
"""
def __init__(self, hs):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.hs = hs
# dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None
def start_caching(self):
logger.debug("start_caching")
self._state_cache = {}
def f():
self._prune_cache()
self.clock.looping_call(f, 5*1000)
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
@@ -95,22 +62,13 @@ class StateHandler(object):
for e_id, _, _ in events
]
cache = None
if self._state_cache is not None:
cache = self._state_cache.get(frozenset(event_ids), None)
if cache:
cache.ts = self.clock.time_msec()
state = cache.state
else:
res = yield self.resolve_state_groups(event_ids)
state = res[1]
res = yield self.resolve_state_groups(event_ids)
if event_type:
defer.returnValue(state.get((event_type, state_key)))
defer.returnValue(res[1].get((event_type, state_key)))
return
defer.returnValue(state)
defer.returnValue(res[1].values())
@defer.inlineCallbacks
def compute_event_context(self, event, old_state=None):
@@ -137,9 +95,7 @@ class StateHandler(object):
context.state_group = None
if hasattr(event, "auth_events") and event.auth_events:
auth_ids = self.hs.get_auth().compute_auth_events(
event, context.current_state
)
auth_ids = zip(*event.auth_events)[0]
context.auth_events = {
k: v
for k, v in context.current_state.items()
@@ -185,9 +141,7 @@ class StateHandler(object):
event.unsigned["replaces_state"] = replaces.event_id
if hasattr(event, "auth_events") and event.auth_events:
auth_ids = self.hs.get_auth().compute_auth_events(
event, context.current_state
)
auth_ids = zip(*event.auth_events)[0]
context.auth_events = {
k: v
for k, v in context.current_state.items()
@@ -209,31 +163,10 @@ class StateHandler(object):
first is the name of a state group if one and only one is involved,
otherwise `None`.
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
if self._state_cache is not None:
cache = self._state_cache.get(frozenset(event_ids), None)
if cache and cache.state_group:
cache.ts = self.clock.time_msec()
prev_state = cache.state.get((event_type, state_key), None)
if prev_state:
prev_state = prev_state.event_id
prev_states = [prev_state]
else:
prev_states = []
defer.returnValue(
(cache.state_group, cache.state, prev_states)
)
state_groups = yield self.store.get_state_groups(
event_ids
)
logger.debug(
"resolve_state_groups state_groups %s",
state_groups.keys()
)
group_names = set(state_groups.keys())
if len(group_names) == 1:
name, state_list = state_groups.items().pop()
@@ -248,48 +181,15 @@ class StateHandler(object):
else:
prev_states = []
if self._state_cache is not None:
cache = _StateCacheEntry(
state=state,
state_group=name,
ts=self.clock.time_msec()
)
self._state_cache[frozenset(event_ids)] = cache
defer.returnValue((name, state, prev_states))
new_state, prev_states = self._resolve_events(
state_groups.values(), event_type, state_key
)
if self._state_cache is not None:
cache = _StateCacheEntry(
state=new_state,
state_group=None,
ts=self.clock.time_msec()
)
self._state_cache[frozenset(event_ids)] = cache
defer.returnValue((None, new_state, prev_states))
def resolve_events(self, state_sets, event):
if event.is_state():
return self._resolve_events(
state_sets, event.type, event.state_key
)
else:
return self._resolve_events(state_sets)
def _resolve_events(self, state_sets, event_type=None, state_key=""):
state = {}
for st in state_sets:
for e in st:
for group, g_state in state_groups.items():
for s in g_state:
state.setdefault(
(e.type, e.state_key),
(s.type, s.state_key),
{}
)[e.event_id] = e
)[s.event_id] = s
unconflicted_state = {
k: v.values()[0] for k, v in state.items()
@@ -310,133 +210,64 @@ class StateHandler(object):
else:
prev_states = []
auth_events = {
k: e for k, e in unconflicted_state.items()
if k[0] in AuthEventTypes
}
try:
resolved_state = self._resolve_state_events(
conflicted_state, auth_events
)
new_state = {}
new_state.update(unconflicted_state)
for key, events in conflicted_state.items():
new_state[key] = self._resolve_state_events(events)
except:
logger.exception("Failed to resolve state")
raise
new_state = unconflicted_state
new_state.update(resolved_state)
defer.returnValue((None, new_state, prev_states))
return new_state, prev_states
def _get_power_level_from_event_state(self, event, user_id):
if hasattr(event, "old_state_events") and event.old_state_events:
key = (EventTypes.PowerLevels, "", )
power_level_event = event.old_state_events.get(key)
level = None
if power_level_event:
level = power_level_event.content.get("users", {}).get(
user_id
)
if not level:
level = power_level_event.content.get("users_default", 0)
return level
else:
return 0
@log_function
def _resolve_state_events(self, conflicted_state, auth_events):
""" This is where we actually decide which of the conflicted state to
use.
def _resolve_state_events(self, events):
curr_events = events
We resolve conflicts in the following order:
1. power levels
2. memberships
3. other events.
"""
resolved_state = {}
power_key = (EventTypes.PowerLevels, "")
if power_key in conflicted_state.items():
power_levels = conflicted_state[power_key]
resolved_state[power_key] = self._resolve_auth_events(power_levels)
new_powers = [
self._get_power_level_from_event_state(e, e.user_id)
for e in curr_events
]
auth_events.update(resolved_state)
new_powers = [
int(p) if p else 0 for p in new_powers
]
for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
resolved_state[key] = self._resolve_auth_events(
events,
auth_events
)
max_power = max(new_powers)
auth_events.update(resolved_state)
curr_events = [
z[0] for z in zip(curr_events, new_powers)
if z[1] == max_power
]
for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
resolved_state[key] = self._resolve_auth_events(
events,
auth_events
)
if not curr_events:
raise RuntimeError("Max didn't get a max?")
elif len(curr_events) == 1:
return curr_events[0]
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key not in resolved_state:
resolved_state[key] = self._resolve_normal_events(
events, auth_events
)
return resolved_state
def _resolve_auth_events(self, events, auth_events):
reverse = [i for i in reversed(self._ordered_events(events))]
auth_events = dict(auth_events)
prev_event = reverse[0]
for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
self.hs.get_auth().check(event, auth_events)
prev_event = event
except AuthError:
return prev_event
return event
def _resolve_normal_events(self, events, auth_events):
for event in self._ordered_events(events):
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
self.hs.get_auth().check(event, auth_events)
return event
except AuthError:
pass
# Use the last event (the one with the least depth) if they all fail
# the auth check.
return event
def _ordered_events(self, events):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
return sorted(events, key=key_func)
def _prune_cache(self):
logger.debug(
"_prune_cache. before len: %d",
len(self._state_cache.keys())
)
now = self.clock.time_msec()
if len(self._state_cache.keys()) > SIZE_OF_CACHE:
sorted_entries = sorted(
self._state_cache.items(),
key=lambda k, v: v.ts,
)
for k, _ in sorted_entries[SIZE_OF_CACHE:]:
self._state_cache.pop(k)
keys_to_delete = set()
for key, cache_entry in self._state_cache.items():
if now - cache_entry.ts > EVICTION_TIMEOUT_SECONDS*1000:
keys_to_delete.add(key)
for k in keys_to_delete:
self._state_cache.pop(k)
logger.debug(
"_prune_cache. after len: %d",
len(self._state_cache.keys())
# TODO: For now, just choose the one with the largest event_id.
return (
sorted(
curr_events,
key=lambda e: hashlib.sha1(
e.event_id + e.user_id + e.room_id + e.type
).hexdigest()
)[0]
)

View File

@@ -29,14 +29,10 @@ from .stream import StreamStore
from .transactions import TransactionStore
from .keys import KeyStore
from .event_federation import EventFederationStore
from .pusher import PusherStore
from .push_rule import PushRuleStore
from .media_repository import MediaRepositoryStore
from .rejections import RejectionsStore
from .state import StateStore
from .signatures import SignatureStore
from .filtering import FilteringStore
from syutil.base64util import decode_base64
from syutil.jsonutil import encode_canonical_json
@@ -44,6 +40,7 @@ from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import compute_event_reference_hash
import json
import logging
import os
@@ -63,16 +60,13 @@ SCHEMAS = [
"state",
"event_edges",
"event_signatures",
"pusher",
"media_repository",
"filtering",
"rejections",
]
# Remember to update this number every time an incompatible change is made to
# database schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 12
SCHEMA_VERSION = 11
class _RollbackButIsFineException(Exception):
@@ -88,10 +82,6 @@ class DataStore(RoomMemberStore, RoomStore,
DirectoryStore, KeyStore, StateStore, SignatureStore,
EventFederationStore,
MediaRepositoryStore,
RejectionsStore,
FilteringStore,
PusherStore,
PushRuleStore
):
def __init__(self, hs):
@@ -127,147 +117,21 @@ class DataStore(RoomMemberStore, RoomStore,
pass
@defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False,
allow_none=False):
"""Get an event from the database by event_id.
def get_event(self, event_id, allow_none=False):
events = yield self._get_events([event_id])
Args:
event_id (str): The event_id of the event to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
allow_none (bool): If True, return None if no event found, if
False throw an exception.
if not events:
if allow_none:
defer.returnValue(None)
else:
raise RuntimeError("Could not find event %s" % (event_id,))
Returns:
Deferred : A FrozenEvent.
"""
event = yield self.runInteraction(
"get_event", self._get_event_txn,
event_id,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
if not event and not allow_none:
raise RuntimeError("Could not find event %s" % (event_id,))
defer.returnValue(event)
defer.returnValue(events[0])
@log_function
def _persist_event_txn(self, txn, event, context, backfilled,
stream_ordering=None, is_new_state=True,
current_state=None):
# Remove the any existing cache entries for the event_id
self._get_event_cache.pop(event.event_id)
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
txn.execute(
"DELETE FROM current_state_events WHERE room_id = ?",
(event.room_id,)
)
for s in current_state:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": s.event_id,
"room_id": s.room_id,
"type": s.type,
"state_key": s.state_key,
},
or_replace=True,
)
if event.is_state() and is_new_state:
if not backfilled and not context.rejected:
self._simple_insert_txn(
txn,
table="state_forward_extremities",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
or_replace=True,
)
for prev_state_id, _ in event.prev_state:
self._simple_delete_txn(
txn,
table="state_forward_extremities",
keyvalues={
"event_id": prev_state_id,
}
)
outlier = event.internal_metadata.is_outlier()
if not outlier:
self._store_state_groups_txn(txn, event, context)
self._update_min_depth_for_room_txn(
txn,
event.room_id,
event.depth
)
self._handle_prev_events(
txn,
outlier=outlier,
event_id=event.event_id,
prev_events=event.prev_events,
room_id=event.room_id,
)
have_persisted = self._simple_select_one_onecol_txn(
txn,
table="event_json",
keyvalues={"event_id": event.event_id},
retcol="event_id",
allow_none=True,
)
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
)
# If we have already persisted this event, we don't need to do any
# more processing.
# The processing above must be done on every call to persist event,
# since they might not have happened on previous calls. For example,
# if we are persisting an event that we had persisted as an outlier,
# but is no longer one.
if have_persisted:
if not outlier:
sql = (
"UPDATE event_json SET internal_metadata = ?"
" WHERE event_id = ?"
)
txn.execute(
sql,
(metadata_json.decode("UTF-8"), event.event_id,)
)
sql = (
"UPDATE events SET outlier = 0"
" WHERE event_id = ?"
)
txn.execute(
sql,
(event.event_id,)
)
return
if event.type == EventTypes.Member:
self._store_room_member_txn(txn, event)
elif event.type == EventTypes.Feedback:
@@ -279,6 +143,8 @@ class DataStore(RoomMemberStore, RoomStore,
elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event)
outlier = event.internal_metadata.is_outlier()
event_dict = {
k: v
for k, v in event.get_dict().items()
@@ -288,6 +154,10 @@ class DataStore(RoomMemberStore, RoomStore,
]
}
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
)
self._simple_insert_txn(
txn,
table="event_json",
@@ -300,16 +170,12 @@ class DataStore(RoomMemberStore, RoomStore,
or_replace=True,
)
content = encode_canonical_json(
event.content
).decode("UTF-8")
vals = {
"topological_ordering": event.depth,
"event_id": event.event_id,
"type": event.type,
"room_id": event.room_id,
"content": content,
"content": json.dumps(event.get_dict()["content"]),
"processed": True,
"outlier": outlier,
"depth": event.depth,
@@ -329,10 +195,7 @@ class DataStore(RoomMemberStore, RoomStore,
"prev_events",
]
}
vals["unrecognized_keys"] = encode_canonical_json(
unrec
).decode("UTF-8")
vals["unrecognized_keys"] = json.dumps(unrec)
try:
self._simple_insert_txn(
@@ -350,10 +213,38 @@ class DataStore(RoomMemberStore, RoomStore,
)
raise _RollbackButIsFineException("_persist_event")
if context.rejected:
self._store_rejections_txn(txn, event.event_id, context.rejected)
self._handle_prev_events(
txn,
outlier=outlier,
event_id=event.event_id,
prev_events=event.prev_events,
room_id=event.room_id,
)
if event.is_state():
if not outlier:
self._store_state_groups_txn(txn, event, context)
if current_state:
txn.execute(
"DELETE FROM current_state_events WHERE room_id = ?",
(event.room_id,)
)
for s in current_state:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": s.event_id,
"room_id": s.room_id,
"type": s.type,
"state_key": s.state_key,
},
or_replace=True,
)
is_state = hasattr(event, "state_key") and event.state_key is not None
if is_state:
vals = {
"event_id": event.event_id,
"room_id": event.room_id,
@@ -361,7 +252,6 @@ class DataStore(RoomMemberStore, RoomStore,
"state_key": event.state_key,
}
# TODO: How does this work with backfilling?
if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state
@@ -372,7 +262,7 @@ class DataStore(RoomMemberStore, RoomStore,
or_replace=True,
)
if is_new_state and not context.rejected:
if is_new_state:
self._simple_insert_txn(
txn,
"current_state_events",
@@ -398,6 +288,28 @@ class DataStore(RoomMemberStore, RoomStore,
or_ignore=True,
)
if not backfilled:
self._simple_insert_txn(
txn,
table="state_forward_extremities",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
or_replace=True,
)
for prev_state_id, _ in event.prev_state:
self._simple_delete_txn(
txn,
table="state_forward_extremities",
keyvalues={
"event_id": prev_state_id,
}
)
for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_event_content_hash_txn(
@@ -428,9 +340,14 @@ class DataStore(RoomMemberStore, RoomStore,
txn, event.event_id, ref_alg, ref_hash_bytes
)
if not outlier:
self._update_min_depth_for_room_txn(
txn,
event.room_id,
event.depth
)
def _store_redaction(self, txn, event):
# invalidate the cache for the redacted event
self._get_event_cache.pop(event.redacts)
txn.execute(
"INSERT OR IGNORE INTO redactions "
"(event_id, redacts) VALUES (?,?)",
@@ -453,12 +370,9 @@ class DataStore(RoomMemberStore, RoomStore,
"redacted": del_sql,
}
if event_type and state_key is not None:
if event_type:
sql += " AND s.type = ? AND s.state_key = ? "
args = (room_id, event_type, state_key)
elif event_type:
sql += " AND s.type = ?"
args = (room_id, event_type)
else:
args = (room_id, )
@@ -467,41 +381,6 @@ class DataStore(RoomMemberStore, RoomStore,
events = yield self._parse_events(results)
defer.returnValue(events)
@defer.inlineCallbacks
def get_room_name_and_aliases(self, room_id):
del_sql = (
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
"LIMIT 1"
)
sql = (
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
"INNER JOIN state_events as s ON e.event_id = s.event_id "
"WHERE c.room_id = ? "
) % {
"redacted": del_sql,
}
sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
sql += " OR s.type = 'm.room.aliases')"
args = (room_id,)
results = yield self._execute_and_decode(sql, *args)
events = yield self._parse_events(results)
name = None
aliases = []
for e in events:
if e.type == 'm.room.name':
name = e.content['name']
elif e.type == 'm.room.aliases':
aliases.extend(e.content['aliases'])
defer.returnValue((name, aliases))
@defer.inlineCallbacks
def _get_min_token(self):
row = yield self._execute(
@@ -538,38 +417,6 @@ class DataStore(RoomMemberStore, RoomStore,
],
)
def have_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Returns:
dict: Has an entry for each event id we already have seen. Maps to
the rejected reason string if we rejected the event, else maps to
None.
"""
if not event_ids:
return defer.succeed({})
def f(txn):
sql = (
"SELECT e.event_id, reason FROM events as e "
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
"WHERE e.event_id = ?"
)
res = {}
for event_id in event_ids:
txn.execute(sql, (event_id,))
row = txn.fetchone()
if row:
_, rejected = row
res[event_id] = rejected
return res
return self.runInteraction(
"have_events", f,
)
def schema_path(schema):
""" Get a filesystem path for the named database schema

View File

@@ -19,12 +19,11 @@ from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
from synapse.util.lrucache import LruCache
from twisted.internet import defer
import collections
import simplejson as json
import json
import sys
import time
@@ -78,43 +77,6 @@ class LoggingTransaction(object):
sql_logger.debug("[SQL time] {%s} %f", self.name, end - start)
class PerformanceCounters(object):
def __init__(self):
self.current_counters = {}
self.previous_counters = {}
def update(self, key, start_time, end_time=None):
if end_time is None:
end_time = time.time() * 1000
duration = end_time - start_time
count, cum_time = self.current_counters.get(key, (0, 0))
count += 1
cum_time += duration
self.current_counters[key] = (count, cum_time)
return end_time
def interval(self, interval_duration, limit=3):
counters = []
for name, (count, cum_time) in self.current_counters.items():
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append((
(cum_time - prev_time) / interval_duration,
count - prev_count,
name
))
self.previous_counters = dict(self.current_counters)
counters.sort(reverse=True)
top_n_counters = ", ".join(
"%s(%d): %.3f%%" % (name, count, 100 * ratio)
for ratio, count, name in counters[:limit]
)
return top_n_counters
class SQLBaseStore(object):
_TXN_ID = 0
@@ -123,43 +85,6 @@ class SQLBaseStore(object):
self._db_pool = hs.get_db_pool()
self._clock = hs.get_clock()
self._previous_txn_total_time = 0
self._current_txn_total_time = 0
self._previous_loop_ts = 0
self._txn_perf_counters = PerformanceCounters()
self._get_event_counters = PerformanceCounters()
self._get_event_cache = LruCache(hs.config.event_cache_size)
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
def loop():
curr = self._current_txn_total_time
prev = self._previous_txn_total_time
self._previous_txn_total_time = curr
time_now = self._clock.time_msec()
time_then = self._previous_loop_ts
self._previous_loop_ts = time_now
ratio = (curr - prev)/(time_now - time_then)
top_three_counters = self._txn_perf_counters.interval(
time_now - time_then, limit=3
)
top_3_event_counters = self._get_event_counters.interval(
time_now - time_then, limit=3
)
logger.info(
"Total database time: %.3f%% {%s} {%s}",
ratio * 100, top_three_counters, top_3_event_counters
)
self._clock.looping_call(loop, 10000)
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
@@ -169,7 +94,7 @@ class SQLBaseStore(object):
with LoggingContext("runInteraction") as context:
current_context.copy_to(context)
start = time.time() * 1000
txn_id = self._TXN_ID
txn_id = SQLBaseStore._TXN_ID
# We don't really need these to be unique, so lets stop it from
# growing really large.
@@ -189,10 +114,6 @@ class SQLBaseStore(object):
"[TXN END] {%s} %f",
name, end - start
)
self._current_txn_total_time += end - start
self._txn_perf_counters.update(desc, start, end)
with PreserveLoggingContext():
result = yield self._db_pool.runInteraction(
inner_func, *args, **kwargs
@@ -272,50 +193,6 @@ class SQLBaseStore(object):
txn.execute(sql, values.values())
return txn.lastrowid
def _simple_upsert(self, table, keyvalues, values):
"""
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values
Returns: A deferred
"""
return self.runInteraction(
"_simple_upsert",
self._simple_upsert_txn, table, keyvalues, values
)
def _simple_upsert_txn(self, txn, table, keyvalues, values):
# Try to update
sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in values),
" AND ".join("%s = ?" % (k,) for k in keyvalues)
)
sqlargs = values.values() + keyvalues.values()
logger.debug(
"[SQL] %s Args=%s",
sql, sqlargs,
)
txn.execute(sql, sqlargs)
if txn.rowcount == 0:
# We didn't update and rows so insert a new one
allvalues = {}
allvalues.update(keyvalues)
allvalues.update(values)
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues)
)
logger.debug(
"[SQL] %s Args=%s",
sql, keyvalues.values(),
)
txn.execute(sql, allvalues.values())
def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False):
"""Executes a SELECT query on the named table, which is expected to
@@ -467,8 +344,8 @@ class SQLBaseStore(object):
if updatevalues:
update_sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in updatevalues),
" AND ".join("%s = ?" % (k,) for k in keyvalues)
", ".join("%s = ?" % (k) for k in updatevalues),
" AND ".join("%s = ?" % (k) for k in keyvalues)
)
def func(txn):
@@ -581,25 +458,10 @@ class SQLBaseStore(object):
return [e for e in events if e]
def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False):
start_time = time.time() * 1000
update_counter = self._get_event_counters.update
try:
cache = self._get_event_cache.setdefault(event_id, {})
# Separate cache entries for each way to invoke _get_event_txn
return cache[(check_redacted, get_prev_content, allow_rejected)]
except KeyError:
pass
finally:
start_time = update_counter("event_cache", start_time)
get_prev_content=False):
sql = (
"SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
"FROM event_json as e "
"SELECT internal_metadata, json, r.event_id FROM event_json as e "
"LEFT JOIN redactions as r ON e.event_id = r.redacts "
"LEFT JOIN rejections as rej on rej.event_id = e.event_id "
"WHERE e.event_id = ? "
"LIMIT 1 "
)
@@ -611,35 +473,20 @@ class SQLBaseStore(object):
if not res:
return None
internal_metadata, js, redacted, rejected_reason = res
internal_metadata, js, redacted = res
start_time = update_counter("select_event", start_time)
if allow_rejected or not rejected_reason:
result = self._get_event_from_row_txn(
txn, internal_metadata, js, redacted,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
)
cache[(check_redacted, get_prev_content, allow_rejected)] = result
return result
else:
return None
return self._get_event_from_row_txn(
txn, internal_metadata, js, redacted,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
)
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
check_redacted=True, get_prev_content=False):
start_time = time.time() * 1000
update_counter = self._get_event_counters.update
d = json.loads(js)
start_time = update_counter("decode_json", start_time)
internal_metadata = json.loads(internal_metadata)
start_time = update_counter("decode_internal", start_time)
ev = FrozenEvent(d, internal_metadata_dict=internal_metadata)
start_time = update_counter("build_frozen_event", start_time)
if check_redacted and redacted:
ev = prune_event(ev)
@@ -655,7 +502,6 @@ class SQLBaseStore(object):
if because:
ev.unsigned["redacted_because"] = because
start_time = update_counter("redact_event", start_time)
if get_prev_content and "replaces_state" in ev.unsigned:
prev = self._get_event_txn(
@@ -665,7 +511,6 @@ class SQLBaseStore(object):
)
if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"]
start_time = update_counter("get_prev_content", start_time)
return ev

View File

@@ -55,16 +55,17 @@ class EventFederationStore(SQLBaseStore):
results = set()
base_sql = (
"SELECT auth_id FROM event_auth WHERE event_id = ?"
"SELECT auth_id FROM event_auth WHERE %s"
)
front = set(event_ids)
while front:
new_front = set()
for f in front:
txn.execute(base_sql, (f,))
new_front.update([r[0] for r in txn.fetchall()])
front = new_front
sql = base_sql % (
" OR ".join(["event_id=?"] * len(front)),
)
txn.execute(sql, list(front))
front = [r[0] for r in txn.fetchall()]
results.update(front)
return list(results)

View File

@@ -1,63 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import SQLBaseStore
import simplejson as json
class FilteringStore(SQLBaseStore):
@defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id):
def_json = yield self._simple_select_one_onecol(
table="user_filters",
keyvalues={
"user_id": user_localpart,
"filter_id": filter_id,
},
retcol="filter_json",
allow_none=False,
)
defer.returnValue(json.loads(def_json))
def add_user_filter(self, user_localpart, user_filter):
def_json = json.dumps(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
# INSERT a new one
def _do_txn(txn):
sql = (
"SELECT MAX(filter_id) FROM user_filters "
"WHERE user_id = ?"
)
txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0]
if max_id is None:
filter_id = 0
else:
filter_id = max_id + 1
sql = (
"INSERT INTO user_filters (user_id, filter_id, filter_json)"
"VALUES(?, ?, ?)"
)
txn.execute(sql, (user_localpart, filter_id, def_json))
return filter_id
return self.runInteraction("add_user_filter", _do_txn)

View File

@@ -1,218 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from ._base import SQLBaseStore, Table
from twisted.internet import defer
import logging
import copy
import simplejson as json
logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore):
@defer.inlineCallbacks
def get_push_rules_for_user_name(self, user_name):
sql = (
"SELECT "+",".join(PushRuleTable.fields)+" "
"FROM "+PushRuleTable.table_name+" "
"WHERE user_name = ? "
"ORDER BY priority_class DESC, priority DESC"
)
rows = yield self._execute(None, sql, user_name)
dicts = []
for r in rows:
d = {}
for i, f in enumerate(PushRuleTable.fields):
d[f] = r[i]
dicts.append(d)
defer.returnValue(dicts)
@defer.inlineCallbacks
def add_push_rule(self, before, after, **kwargs):
vals = copy.copy(kwargs)
if 'conditions' in vals:
vals['conditions'] = json.dumps(vals['conditions'])
if 'actions' in vals:
vals['actions'] = json.dumps(vals['actions'])
# we could check the rest of the keys are valid column names
# but sqlite will do that anyway so I think it's just pointless.
if 'id' in vals:
del vals['id']
if before or after:
ret = yield self.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
before=before,
after=after,
**vals
)
defer.returnValue(ret)
else:
ret = yield self.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
**vals
)
defer.returnValue(ret)
def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
after = None
relative_to_rule = None
if 'after' in kwargs and kwargs['after']:
after = kwargs['after']
relative_to_rule = after
if 'before' in kwargs and kwargs['before']:
relative_to_rule = kwargs['before']
# get the priority of the rule we're inserting after/before
sql = (
"SELECT priority_class, priority FROM ? "
"WHERE user_name = ? and rule_id = ?" % (PushRuleTable.table_name,)
)
txn.execute(sql, (user_name, relative_to_rule))
res = txn.fetchall()
if not res:
raise RuleNotFoundException(
"before/after rule not found: %s" % (relative_to_rule,)
)
priority_class, base_rule_priority = res[0]
if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
raise InconsistentRuleException(
"Given priority class does not match class of relative rule"
)
new_rule = copy.copy(kwargs)
if 'before' in new_rule:
del new_rule['before']
if 'after' in new_rule:
del new_rule['after']
new_rule['priority_class'] = priority_class
new_rule['user_name'] = user_name
# check if the priority before/after is free
new_rule_priority = base_rule_priority
if after:
new_rule_priority -= 1
else:
new_rule_priority += 1
new_rule['priority'] = new_rule_priority
sql = (
"SELECT COUNT(*) FROM " + PushRuleTable.table_name +
" WHERE user_name = ? AND priority_class = ? AND priority = ?"
)
txn.execute(sql, (user_name, priority_class, new_rule_priority))
res = txn.fetchall()
num_conflicting = res[0][0]
# if there are conflicting rules, bump everything
if num_conflicting:
sql = "UPDATE "+PushRuleTable.table_name+" SET priority = priority "
if after:
sql += "-1"
else:
sql += "+1"
sql += " WHERE user_name = ? AND priority_class = ? AND priority "
if after:
sql += "<= ?"
else:
sql += ">= ?"
txn.execute(sql, (user_name, priority_class, new_rule_priority))
# now insert the new rule
sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")"
txn.execute(sql, new_rule.values())
def _add_push_rule_highest_priority_txn(self, txn, user_name,
priority_class, **kwargs):
# find the highest priority rule in that class
sql = (
"SELECT COUNT(*), MAX(priority) FROM " + PushRuleTable.table_name +
" WHERE user_name = ? and priority_class = ?"
)
txn.execute(sql, (user_name, priority_class))
res = txn.fetchall()
(how_many, highest_prio) = res[0]
new_prio = 0
if how_many > 0:
new_prio = highest_prio + 1
# and insert the new rule
new_rule = copy.copy(kwargs)
if 'id' in new_rule:
del new_rule['id']
new_rule['user_name'] = user_name
new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio
sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")"
txn.execute(sql, new_rule.values())
@defer.inlineCallbacks
def delete_push_rule(self, user_name, rule_id):
"""
Delete a push rule. Args specify the row to be deleted and can be
any of the columns in the push_rule table, but below are the
standard ones
Args:
user_name (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted
"""
yield self._simple_delete_one(
PushRuleTable.table_name,
{'user_name': user_name, 'rule_id': rule_id}
)
class RuleNotFoundException(Exception):
pass
class InconsistentRuleException(Exception):
pass
class PushRuleTable(Table):
table_name = "push_rules"
fields = [
"id",
"user_name",
"rule_id",
"priority_class",
"priority",
"conditions",
"actions",
]
EntryType = collections.namedtuple("PushRuleEntry", fields)

Some files were not shown because too many files have changed in this diff Show More