Merge commit '1c9a85056' into anoa/dinsic_release_1_31_0
INSTALL.md (259 changed lines)
@@ -1,19 +1,44 @@
-- [Choosing your server name](#choosing-your-server-name)
-- [Picking a database engine](#picking-a-database-engine)
-- [Installing Synapse](#installing-synapse)
-  - [Installing from source](#installing-from-source)
-    - [Platform-Specific Instructions](#platform-specific-instructions)
-  - [Prebuilt packages](#prebuilt-packages)
-- [Setting up Synapse](#setting-up-synapse)
-  - [TLS certificates](#tls-certificates)
-  - [Client Well-Known URI](#client-well-known-uri)
-  - [Email](#email)
-  - [Registering a user](#registering-a-user)
-  - [Setting up a TURN server](#setting-up-a-turn-server)
-  - [URL previews](#url-previews)
-- [Troubleshooting Installation](#troubleshooting-installation)
-
-# Choosing your server name
+# Installation Instructions
+
+There are 3 steps to follow under **Installation Instructions**.
+
+- [Installation Instructions](#installation-instructions)
+  - [Choosing your server name](#choosing-your-server-name)
+  - [Installing Synapse](#installing-synapse)
+    - [Installing from source](#installing-from-source)
+      - [Platform-Specific Instructions](#platform-specific-instructions)
+        - [Debian/Ubuntu/Raspbian](#debianubunturaspbian)
+        - [ArchLinux](#archlinux)
+        - [CentOS/Fedora](#centosfedora)
+        - [macOS](#macos)
+        - [OpenSUSE](#opensuse)
+        - [OpenBSD](#openbsd)
+        - [Windows](#windows)
+      - [Prebuilt packages](#prebuilt-packages)
+        - [Docker images and Ansible playbooks](#docker-images-and-ansible-playbooks)
+        - [Debian/Ubuntu](#debianubuntu)
+          - [Matrix.org packages](#matrixorg-packages)
+          - [Downstream Debian packages](#downstream-debian-packages)
+          - [Downstream Ubuntu packages](#downstream-ubuntu-packages)
+        - [Fedora](#fedora)
+        - [OpenSUSE](#opensuse-1)
+        - [SUSE Linux Enterprise Server](#suse-linux-enterprise-server)
+        - [ArchLinux](#archlinux-1)
+        - [Void Linux](#void-linux)
+        - [FreeBSD](#freebsd)
+        - [OpenBSD](#openbsd-1)
+        - [NixOS](#nixos)
+  - [Setting up Synapse](#setting-up-synapse)
+    - [Using PostgreSQL](#using-postgresql)
+    - [TLS certificates](#tls-certificates)
+    - [Client Well-Known URI](#client-well-known-uri)
+    - [Email](#email)
+    - [Registering a user](#registering-a-user)
+    - [Setting up a TURN server](#setting-up-a-turn-server)
+    - [URL previews](#url-previews)
+- [Troubleshooting Installation](#troubleshooting-installation)
+
+## Choosing your server name
 
 It is important to choose the name for your server before you install Synapse,
 because it cannot be changed later.
@@ -29,28 +54,9 @@ that your email address is probably `user@example.com` rather than
 `user@email.example.com`) - but doing so may require more advanced setup: see
 [Setting up Federation](docs/federate.md).
 
-# Picking a database engine
-
-Synapse offers two database engines:
- * [PostgreSQL](https://www.postgresql.org)
- * [SQLite](https://sqlite.org/)
-
-Almost all installations should opt to use PostgreSQL. Advantages include:
-
-* significant performance improvements due to the superior threading and
-  caching model, smarter query optimiser
-* allowing the DB to be run on separate hardware
-
-For information on how to install and use PostgreSQL, please see
-[docs/postgres.md](docs/postgres.md)
-
-By default Synapse uses SQLite and in doing so trades performance for convenience.
-SQLite is only recommended in Synapse for testing purposes or for servers with
-light workloads.
-
-# Installing Synapse
-
-## Installing from source
+## Installing Synapse
+
+### Installing from source
 
 (Prebuilt packages are available for some platforms - see [Prebuilt packages](#prebuilt-packages).)
 
@@ -68,7 +74,7 @@ these on various platforms.
 
 To install the Synapse homeserver run:
 
-```
+```sh
 mkdir -p ~/synapse
 virtualenv -p python3 ~/synapse/env
 source ~/synapse/env/bin/activate
@@ -85,7 +91,7 @@ prefer.
 This Synapse installation can then be later upgraded by using pip again with the
 update flag:
 
-```
+```sh
 source ~/synapse/env/bin/activate
 pip install -U matrix-synapse
 ```
@@ -93,7 +99,7 @@ pip install -U matrix-synapse
 Before you can start Synapse, you will need to generate a configuration
 file. To do this, run (in your virtualenv, as before):
 
-```
+```sh
 cd ~/synapse
 python -m synapse.app.homeserver \
     --server-name my.domain.name \
@@ -111,45 +117,43 @@ wise to back them up somewhere safe. (If, for whatever reason, you do need to
 change your homeserver's keys, you may find that other homeserver have the
 old key cached. If you update the signing key, you should change the name of the
 key in the `<server name>.signing.key` file (the second word) to something
-different. See the
-[spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys)
-for more information on key management).
+different. See the [spec](https://matrix.org/docs/spec/server_server/latest.html#retrieving-server-keys) for more information on key management).
 
 To actually run your new homeserver, pick a working directory for Synapse to
 run (e.g. `~/synapse`), and:
 
-```
+```sh
 cd ~/synapse
 source env/bin/activate
 synctl start
 ```
 
-### Platform-Specific Instructions
+#### Platform-Specific Instructions
 
-#### Debian/Ubuntu/Raspbian
+##### Debian/Ubuntu/Raspbian
 
 Installing prerequisites on Ubuntu or Debian:
 
-```
-sudo apt-get install build-essential python3-dev libffi-dev \
+```sh
+sudo apt install build-essential python3-dev libffi-dev \
                      python3-pip python3-setuptools sqlite3 \
                      libssl-dev virtualenv libjpeg-dev libxslt1-dev
 ```
 
-#### ArchLinux
+##### ArchLinux
 
 Installing prerequisites on ArchLinux:
 
-```
+```sh
 sudo pacman -S base-devel python python-pip \
                python-setuptools python-virtualenv sqlite3
 ```
 
-#### CentOS/Fedora
+##### CentOS/Fedora
 
 Installing prerequisites on CentOS 8 or Fedora>26:
 
-```
+```sh
 sudo dnf install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
                  libwebp-devel tk-devel redhat-rpm-config \
                  python3-virtualenv libffi-devel openssl-devel
@@ -158,7 +162,7 @@ sudo dnf groupinstall "Development Tools"
 
 Installing prerequisites on CentOS 7 or Fedora<=25:
 
-```
+```sh
 sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
                  lcms2-devel libwebp-devel tcl-devel tk-devel redhat-rpm-config \
                  python3-virtualenv libffi-devel openssl-devel
@@ -170,11 +174,11 @@ uses SQLite 3.7. You may be able to work around this by installing a more
 recent SQLite version, but it is recommended that you instead use a Postgres
 database: see [docs/postgres.md](docs/postgres.md).
 
-#### macOS
+##### macOS
 
 Installing prerequisites on macOS:
 
-```
+```sh
 xcode-select --install
 sudo easy_install pip
 sudo pip install virtualenv
@@ -184,22 +188,22 @@ brew install pkg-config libffi
 On macOS Catalina (10.15) you may need to explicitly install OpenSSL
 via brew and inform `pip` about it so that `psycopg2` builds:
 
-```
+```sh
 brew install openssl@1.1
 export LDFLAGS=-L/usr/local/Cellar/openssl\@1.1/1.1.1d/lib/
 ```
 
-#### OpenSUSE
+##### OpenSUSE
 
 Installing prerequisites on openSUSE:
 
-```
+```sh
 sudo zypper in -t pattern devel_basis
 sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
                python-devel libffi-devel libopenssl-devel libjpeg62-devel
 ```
 
-#### OpenBSD
+##### OpenBSD
 
 A port of Synapse is available under `net/synapse`. The filesystem
 underlying the homeserver directory (defaults to `/var/synapse`) has to be
@@ -213,73 +217,72 @@ mounted with `wxallowed` (cf. `mount(8)`).
 Creating a `WRKOBJDIR` for building python under `/usr/local` (which on a
 default OpenBSD installation is mounted with `wxallowed`):
 
-```
+```sh
 doas mkdir /usr/local/pobj_wxallowed
 ```
 
 Assuming `PORTS_PRIVSEP=Yes` (cf. `bsd.port.mk(5)`) and `SUDO=doas` are
 configured in `/etc/mk.conf`:
 
-```
+```sh
 doas chown _pbuild:_pbuild /usr/local/pobj_wxallowed
 ```
 
 Setting the `WRKOBJDIR` for building python:
 
-```
+```sh
 echo WRKOBJDIR_lang/python/3.7=/usr/local/pobj_wxallowed \\nWRKOBJDIR_lang/python/2.7=/usr/local/pobj_wxallowed >> /etc/mk.conf
 ```
 
 Building Synapse:
 
-```
+```sh
 cd /usr/ports/net/synapse
 make install
 ```
 
-#### Windows
+##### Windows
 
 If you wish to run or develop Synapse on Windows, the Windows Subsystem For
 Linux provides a Linux environment on Windows 10 which is capable of using the
 Debian, Fedora, or source installation methods. More information about WSL can
-be found at https://docs.microsoft.com/en-us/windows/wsl/install-win10 for
-Windows 10 and https://docs.microsoft.com/en-us/windows/wsl/install-on-server
+be found at <https://docs.microsoft.com/en-us/windows/wsl/install-win10> for
+Windows 10 and <https://docs.microsoft.com/en-us/windows/wsl/install-on-server>
 for Windows Server.
 
-## Prebuilt packages
+### Prebuilt packages
 
 As an alternative to installing from source, prebuilt packages are available
 for a number of platforms.
 
-### Docker images and Ansible playbooks
+#### Docker images and Ansible playbooks
 
 There is an offical synapse image available at
-https://hub.docker.com/r/matrixdotorg/synapse which can be used with
+<https://hub.docker.com/r/matrixdotorg/synapse> which can be used with
 the docker-compose file available at [contrib/docker](contrib/docker). Further
 information on this including configuration options is available in the README
 on hub.docker.com.
 
 Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a
 Dockerfile to automate a synapse server in a single Docker image, at
-https://hub.docker.com/r/avhost/docker-matrix/tags/
+<https://hub.docker.com/r/avhost/docker-matrix/tags/>
 
 Slavi Pantaleev has created an Ansible playbook,
 which installs the offical Docker image of Matrix Synapse
 along with many other Matrix-related services (Postgres database, Element, coturn,
 ma1sd, SSL support, etc.).
 For more details, see
-https://github.com/spantaleev/matrix-docker-ansible-deploy
+<https://github.com/spantaleev/matrix-docker-ansible-deploy>
 
-### Debian/Ubuntu
+#### Debian/Ubuntu
 
-#### Matrix.org packages
+##### Matrix.org packages
 
 Matrix.org provides Debian/Ubuntu packages of the latest stable version of
-Synapse via https://packages.matrix.org/debian/. They are available for Debian
+Synapse via <https://packages.matrix.org/debian/>. They are available for Debian
 9 (Stretch), Ubuntu 16.04 (Xenial), and later. To use them:
 
-```
+```sh
 sudo apt install -y lsb-release wget apt-transport-https
 sudo wget -O /usr/share/keyrings/matrix-org-archive-keyring.gpg https://packages.matrix.org/debian/matrix-org-archive-keyring.gpg
 echo "deb [signed-by=/usr/share/keyrings/matrix-org-archive-keyring.gpg] https://packages.matrix.org/debian/ $(lsb_release -cs) main" |
@@ -299,7 +302,7 @@ The fingerprint of the repository signing key (as shown by `gpg
 /usr/share/keyrings/matrix-org-archive-keyring.gpg`) is
 `AAF9AE843A7584B5A3E4CD2BCF45A512DE2DA058`.
 
-#### Downstream Debian packages
+##### Downstream Debian packages
 
 We do not recommend using the packages from the default Debian `buster`
 repository at this time, as they are old and suffer from known security
@@ -311,49 +314,49 @@ for information on how to use backports.
 If you are using Debian `sid` or testing, Synapse is available in the default
 repositories and it should be possible to install it simply with:
 
-```
+```sh
 sudo apt install matrix-synapse
 ```
 
-#### Downstream Ubuntu packages
+##### Downstream Ubuntu packages
 
 We do not recommend using the packages in the default Ubuntu repository
 at this time, as they are old and suffer from known security vulnerabilities.
 The latest version of Synapse can be installed from [our repository](#matrixorg-packages).
 
-### Fedora
+#### Fedora
 
 Synapse is in the Fedora repositories as `matrix-synapse`:
 
-```
+```sh
 sudo dnf install matrix-synapse
 ```
 
 Oleg Girko provides Fedora RPMs at
-https://obs.infoserver.lv/project/monitor/matrix-synapse
+<https://obs.infoserver.lv/project/monitor/matrix-synapse>
 
-### OpenSUSE
+#### OpenSUSE
 
 Synapse is in the OpenSUSE repositories as `matrix-synapse`:
 
-```
+```sh
 sudo zypper install matrix-synapse
 ```
 
-### SUSE Linux Enterprise Server
+#### SUSE Linux Enterprise Server
 
 Unofficial package are built for SLES 15 in the openSUSE:Backports:SLE-15 repository at
-https://download.opensuse.org/repositories/openSUSE:/Backports:/SLE-15/standard/
+<https://download.opensuse.org/repositories/openSUSE:/Backports:/SLE-15/standard/>
 
-### ArchLinux
+#### ArchLinux
 
 The quickest way to get up and running with ArchLinux is probably with the community package
-https://www.archlinux.org/packages/community/any/matrix-synapse/, which should pull in most of
+<https://www.archlinux.org/packages/community/any/matrix-synapse/>, which should pull in most of
 the necessary dependencies.
 
 pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 ):
 
-```
+```sh
 sudo pip install --upgrade pip
 ```
 
@@ -362,28 +365,28 @@ ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly
 compile it under the right architecture. (This should not be needed if
 installing under virtualenv):
 
-```
+```sh
 sudo pip uninstall py-bcrypt
 sudo pip install py-bcrypt
 ```
 
-### Void Linux
+#### Void Linux
 
 Synapse can be found in the void repositories as 'synapse':
 
-```
+```sh
 xbps-install -Su
 xbps-install -S synapse
 ```
 
-### FreeBSD
+#### FreeBSD
 
 Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from:
 
 - Ports: `cd /usr/ports/net-im/py-matrix-synapse && make install clean`
 - Packages: `pkg install py37-matrix-synapse`
 
-### OpenBSD
+#### OpenBSD
 
 As of OpenBSD 6.7 Synapse is available as a pre-compiled binary. The filesystem
 underlying the homeserver directory (defaults to `/var/synapse`) has to be
@@ -392,20 +395,35 @@ and mounting it to `/var/synapse` should be taken into consideration.
 
 Installing Synapse:
 
-```
+```sh
 doas pkg_add synapse
 ```
 
-### NixOS
+#### NixOS
 
 Robin Lambertz has packaged Synapse for NixOS at:
-https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/misc/matrix-synapse.nix
+<https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/misc/matrix-synapse.nix>
 
-# Setting up Synapse
+## Setting up Synapse
 
 Once you have installed synapse as above, you will need to configure it.
 
-## TLS certificates
+### Using PostgreSQL
+
+By default Synapse uses [SQLite](https://sqlite.org/) and in doing so trades performance for convenience.
+SQLite is only recommended in Synapse for testing purposes or for servers with
+very light workloads.
+
+Almost all installations should opt to use [PostgreSQL](https://www.postgresql.org). Advantages include:
+
+- significant performance improvements due to the superior threading and
+  caching model, smarter query optimiser
+- allowing the DB to be run on separate hardware
+
+For information on how to install and use PostgreSQL in Synapse, please see
+[docs/postgres.md](docs/postgres.md)
+
+### TLS certificates
 
 The default configuration exposes a single HTTP port on the local
 interface: `http://localhost:8008`. It is suitable for local testing,
@@ -419,19 +437,19 @@ The recommended way to do so is to set up a reverse proxy on port
 Alternatively, you can configure Synapse to expose an HTTPS port. To do
 so, you will need to edit `homeserver.yaml`, as follows:
 
-* First, under the `listeners` section, uncomment the configuration for the
+- First, under the `listeners` section, uncomment the configuration for the
   TLS-enabled listener. (Remove the hash sign (`#`) at the start of
   each line). The relevant lines are like this:
 
-  ```
+  ```yaml
   - port: 8448
     type: http
     tls: true
     resources:
       - names: [client, federation]
   ```
 
-* You will also need to uncomment the `tls_certificate_path` and
+- You will also need to uncomment the `tls_certificate_path` and
   `tls_private_key_path` lines under the `TLS` section. You will need to manage
   provisioning of these certificates yourself — Synapse had built-in ACME
   support, but the ACMEv1 protocol Synapse implements is deprecated, not
@@ -446,7 +464,7 @@ so, you will need to edit `homeserver.yaml`, as follows:
 For a more detailed guide to configuring your server for federation, see
 [federate.md](docs/federate.md).
 
-## Client Well-Known URI
+### Client Well-Known URI
 
 Setting up the client Well-Known URI is optional but if you set it up, it will
 allow users to enter their full username (e.g. `@user:<server_name>`) into clients
@@ -457,7 +475,7 @@ about the actual homeserver URL you are using.
 The URL `https://<server_name>/.well-known/matrix/client` should return JSON in
 the following format.
 
-```
+```json
 {
   "m.homeserver": {
     "base_url": "https://<matrix.example.com>"
@@ -467,7 +485,7 @@ the following format.
 
 It can optionally contain identity server information as well.
 
-```
+```json
 {
   "m.homeserver": {
     "base_url": "https://<matrix.example.com>"
@@ -484,7 +502,8 @@ Cross-Origin Resource Sharing (CORS) headers. A recommended value would be
 view it.
 
 In nginx this would be something like:
-```
+
+```nginx
 location /.well-known/matrix/client {
     return 200 '{"m.homeserver": {"base_url": "https://<matrix.example.com>"}}';
     default_type application/json;
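Once the route is in place, the well-known document can be sanity-checked from any machine. A minimal sketch using only the Python standard library — the domain below is a placeholder, not part of the diff:

```python
import json
import urllib.request

# Placeholder domain; substitute your own server_name.
url = "https://example.com/.well-known/matrix/client"

with urllib.request.urlopen(url) as resp:
    body = json.load(resp)

# Clients read the homeserver base URL from the m.homeserver key.
print(body["m.homeserver"]["base_url"])
```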
@@ -497,11 +516,11 @@ correctly. `public_baseurl` should be set to the URL that clients will use to
 connect to your server. This is the same URL you put for the `m.homeserver`
 `base_url` above.
 
-```
+```yaml
 public_baseurl: "https://<matrix.example.com>"
 ```
 
-## Email
+### Email
 
 It is desirable for Synapse to have the capability to send email. This allows
 Synapse to send password reset emails, send verifications when an email address
@@ -516,7 +535,7 @@ and `notif_from` fields filled out. You may also need to set `smtp_user`,
 If email is not configured, password reset, registration and notifications via
 email will be disabled.
 
-## Registering a user
+### Registering a user
 
 The easiest way to create a new user is to do so from a client like [Element](https://element.io/).
 
@@ -524,7 +543,7 @@ Alternatively you can do so from the command line if you have installed via pip.
 
 This can be done as follows:
 
-```
+```sh
 $ source ~/synapse/env/bin/activate
 $ synctl start # if not already running
 $ register_new_matrix_user -c homeserver.yaml http://localhost:8008
@@ -542,12 +561,12 @@ value is generated by `--generate-config`), but it should be kept secret, as
 anyone with knowledge of it can register users, including admin accounts,
 on your server even if `enable_registration` is `false`.
 
-## Setting up a TURN server
+### Setting up a TURN server
 
 For reliable VoIP calls to be routed via this homeserver, you MUST configure
 a TURN server. See [docs/turn-howto.md](docs/turn-howto.md) for details.
 
-## URL previews
+### URL previews
 
 Synapse includes support for previewing URLs, which is disabled by default. To
 turn it on you must enable the `url_preview_enabled: True` config parameter
@@ -561,14 +580,14 @@ This also requires the optional `lxml` python dependency to be installed. This
 in turn requires the `libxml2` library to be available - on Debian/Ubuntu this
 means `apt-get install libxml2-dev`, or equivalent for your OS.
 
-# Troubleshooting Installation
+### Troubleshooting Installation
 
 `pip` seems to leak *lots* of memory during installation. For instance, a Linux
 host with 512MB of RAM may run out of memory whilst installing Twisted. If this
 happens, you will have to individually install the dependencies which are
 failing, e.g.:
 
-```
+```sh
 pip install twisted
 ```
 
README.rst

@@ -243,6 +243,8 @@ Then update the ``users`` table in the database::
 Synapse Development
 ===================
 
+Join our developer community on Matrix: [#synapse-dev:matrix.org](https://matrix.to/#/#synapse-dev:matrix.org)
+
 Before setting up a development environment for synapse, make sure you have the
 system dependencies (such as the python header files) installed - see
 `Installing from source <INSTALL.md#installing-from-source>`_.
changelog.d/8856.misc (new file)
@@ -0,0 +1 @@
+Properly store the mapping of external ID to Matrix ID for CAS users.

changelog.d/8977.bugfix (new file)
@@ -0,0 +1 @@
+Properly return 400 errors on invalid group IDs.

changelog.d/8980.misc (new file)
@@ -0,0 +1 @@
+Add type hints to the base storage code.

changelog.d/8987.doc (new file)
@@ -0,0 +1 @@
+Moved instructions for database setup, adjusted heading levels and improved syntax highlighting in [INSTALL.md](../INSTALL.md). Contributed by fossterer.

changelog.d/8998.misc (new file)
@@ -0,0 +1 @@
+Fix `tests.federation.transport.RoomDirectoryFederationTests` and ensure it runs in CI.

changelog.d/8999.misc (new file)
@@ -0,0 +1 @@
+Add type hints to the crypto module.

changelog.d/9002.doc (new file)
@@ -0,0 +1 @@
+Link the Synapse developer room to the development section in the docs.
mypy.ini (12 changed lines)

@@ -17,6 +17,7 @@ files =
   synapse/api,
   synapse/appservice,
   synapse/config,
+  synapse/crypto,
   synapse/event_auth.py,
   synapse/events/builder.py,
   synapse/events/validator.py,
@@ -70,16 +71,27 @@ files =
   synapse/server_notices,
   synapse/spam_checker_api,
   synapse/state,
+  synapse/storage/__init__.py,
+  synapse/storage/_base.py,
+  synapse/storage/background_updates.py,
   synapse/storage/databases/main/appservice.py,
   synapse/storage/databases/main/events.py,
+  synapse/storage/databases/main/keys.py,
   synapse/storage/databases/main/pusher.py,
   synapse/storage/databases/main/registration.py,
   synapse/storage/databases/main/stream.py,
   synapse/storage/databases/main/ui_auth.py,
   synapse/storage/database.py,
   synapse/storage/engines,
+  synapse/storage/keys.py,
+  synapse/storage/persist_events.py,
+  synapse/storage/prepare_database.py,
+  synapse/storage/purge_events.py,
   synapse/storage/push_rule.py,
   synapse/storage/relations.py,
   synapse/storage/roommember.py,
   synapse/storage/state.py,
+  synapse/storage/types.py,
   synapse/storage/util,
   synapse/streams,
   synapse/types.py,
synapse/crypto/context_factory.py

@@ -227,7 +227,7 @@ class ConnectionVerifier:
 
     # This code is based on twisted.internet.ssl.ClientTLSOptions.
 
-    def __init__(self, hostname: bytes, verify_certs):
+    def __init__(self, hostname: bytes, verify_certs: bool):
        self._verify_certs = verify_certs
 
        _decoded = hostname.decode("ascii")
synapse/crypto/event_signing.py

@@ -18,7 +18,7 @@
 import collections.abc
 import hashlib
 import logging
-from typing import Dict
+from typing import Any, Callable, Dict, Tuple
 
 from canonicaljson import encode_canonical_json
 from signedjson.sign import sign_json
@@ -27,13 +27,18 @@ from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersion
+from synapse.events import EventBase
 from synapse.events.utils import prune_event, prune_event_dict
+from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
 
+Hasher = Callable[[bytes], "hashlib._Hash"]
 
-def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
+
+def check_event_content_hash(
+    event: EventBase, hash_algorithm: Hasher = hashlib.sha256
+) -> bool:
     """Check whether the hash for this PDU matches the contents"""
     name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm)
     logger.debug(
@@ -67,18 +72,19 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
     return message_hash_bytes == expected_hash
 
 
-def compute_content_hash(event_dict, hash_algorithm):
+def compute_content_hash(
+    event_dict: Dict[str, Any], hash_algorithm: Hasher
+) -> Tuple[str, bytes]:
     """Compute the content hash of an event, which is the hash of the
     unredacted event.
 
     Args:
-        event_dict (dict): The unredacted event as a dict
+        event_dict: The unredacted event as a dict
         hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
             to hash the event
 
     Returns:
-        tuple[str, bytes]: A tuple of the name of hash and the hash as raw
-            bytes.
+        A tuple of the name of hash and the hash as raw bytes.
     """
     event_dict = dict(event_dict)
     event_dict.pop("age_ts", None)
@@ -94,18 +100,19 @@ def compute_content_hash(event_dict, hash_algorithm):
     return hashed.name, hashed.digest()
 
 
-def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
+def compute_event_reference_hash(
+    event, hash_algorithm: Hasher = hashlib.sha256
+) -> Tuple[str, bytes]:
     """Computes the event reference hash. This is the hash of the redacted
     event.
 
     Args:
-        event (FrozenEvent)
+        event
         hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
             to hash the event
 
     Returns:
-        tuple[str, bytes]: A tuple of the name of hash and the hash as raw
-            bytes.
+        A tuple of the name of hash and the hash as raw bytes.
     """
     tmp_event = prune_event(event)
     event_dict = tmp_event.get_pdu_json()
@@ -156,7 +163,7 @@ def add_hashes_and_signatures(
     event_dict: JsonDict,
     signature_name: str,
     signing_key: SigningKey,
-):
+) -> None:
     """Add content hash and sign the event
 
     Args:
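For reference, the `Hasher` alias introduced above simply names any callable that takes bytes and returns a hashlib-style hash object, so the stdlib constructors fit directly. A standalone toy sketch of the same shape — `content_hash` below is illustrative, not Synapse's `compute_content_hash`:

```python
import hashlib
from typing import Any, Callable, Dict, Tuple

# Same alias shape as in the diff above.
Hasher = Callable[[bytes], "hashlib._Hash"]


def content_hash(
    data: Dict[str, Any], hash_algorithm: Hasher = hashlib.sha256
) -> Tuple[str, bytes]:
    # Toy stand-in: hash a deterministic encoding of the dict.
    encoded = repr(sorted(data.items())).encode("utf-8")
    hashed = hash_algorithm(encoded)
    return hashed.name, hashed.digest()


# Any bytes -> hash-object callable satisfies Hasher:
print(content_hash({"type": "m.room.message"}))               # sha256 (default)
print(content_hash({"type": "m.room.message"}, hashlib.sha1))  # another hasher
```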
synapse/crypto/keyring.py

@@ -14,9 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import logging
 import urllib
 from collections import defaultdict
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 import attr
 from signedjson.key import (
@@ -40,6 +42,7 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
+from synapse.config.key import TrustedKeyServer
 from synapse.logging.context import (
     PreserveLoggingContext,
     make_deferred_yieldable,
@@ -47,11 +50,15 @@ from synapse.logging.context import (
     run_in_background,
 )
 from synapse.storage.keys import FetchKeyResult
+from synapse.types import JsonDict
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import yieldable_gather_results
 from synapse.util.metrics import Measure
 from synapse.util.retryutils import NotRetryingDestination
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -61,16 +68,17 @@ class VerifyJsonRequest:
     A request to verify a JSON object.
 
     Attributes:
-        server_name(str): The name of the server to verify against.
+        server_name: The name of the server to verify against.
 
-        key_ids(set[str]): The set of key_ids to that could be used to verify the
-            JSON object
+        json_object: The JSON object to verify.
 
-        json_object(dict): The JSON object to verify.
-
-        minimum_valid_until_ts (int): time at which we require the signing key to
+        minimum_valid_until_ts: time at which we require the signing key to
             be valid. (0 implies we don't care)
 
+        request_name: The name of the request.
+
+        key_ids: The set of key_ids to that could be used to verify the JSON object
+
         key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
             A deferred (server_name, key_id, verify_key) tuple that resolves when
             a verify key has been fetched. The deferreds' callbacks are run with no
@@ -80,12 +88,12 @@ class VerifyJsonRequest:
         errbacks with an M_UNAUTHORIZED SynapseError.
     """
 
-    server_name = attr.ib()
-    json_object = attr.ib()
-    minimum_valid_until_ts = attr.ib()
-    request_name = attr.ib()
-    key_ids = attr.ib(init=False)
-    key_ready = attr.ib(default=attr.Factory(defer.Deferred))
+    server_name = attr.ib(type=str)
+    json_object = attr.ib(type=JsonDict)
+    minimum_valid_until_ts = attr.ib(type=int)
+    request_name = attr.ib(type=str)
+    key_ids = attr.ib(init=False, type=List[str])
+    key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
 
     def __attrs_post_init__(self):
         self.key_ids = signature_ids(self.json_object, self.server_name)
@@ -96,7 +104,9 @@ class KeyLookupError(ValueError):
 
 
 class Keyring:
-    def __init__(self, hs, key_fetchers=None):
+    def __init__(
+        self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
+    ):
         self.clock = hs.get_clock()
 
         if key_fetchers is None:
@@ -112,22 +122,26 @@ class Keyring:
         # completes.
         #
         # These are regular, logcontext-agnostic Deferreds.
-        self.key_downloads = {}
+        self.key_downloads = {}  # type: Dict[str, defer.Deferred]
 
     def verify_json_for_server(
-        self, server_name, json_object, validity_time, request_name
-    ):
+        self,
+        server_name: str,
+        json_object: JsonDict,
+        validity_time: int,
+        request_name: str,
+    ) -> defer.Deferred:
         """Verify that a JSON object has been signed by a given server
 
         Args:
-            server_name (str): name of the server which must have signed this object
+            server_name: name of the server which must have signed this object
 
-            json_object (dict): object to be checked
+            json_object: object to be checked
 
-            validity_time (int): timestamp at which we require the signing key to
+            validity_time: timestamp at which we require the signing key to
                 be valid. (0 implies we don't care)
 
-            request_name (str): an identifier for this json object (eg, an event id)
+            request_name: an identifier for this json object (eg, an event id)
                 for logging.
 
         Returns:
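The `attr.ib(type=...)` form used above attaches the type to the attrs field metadata, which mypy's attrs plugin picks up without needing annotations. A minimal standalone sketch of the pattern — class and values here are illustrative, not the Synapse class:

```python
import attr


@attr.s
class Request:
    server_name = attr.ib(type=str)
    minimum_valid_until_ts = attr.ib(type=int)
    key_ids = attr.ib(init=False, type=list)

    def __attrs_post_init__(self):
        # init=False fields are filled in after the generated __init__ runs.
        self.key_ids = ["ed25519:auto"]


req = Request("example.org", 0)
print(attr.fields(Request).server_name.type)  # <class 'str'>
print(req.key_ids)                            # ['ed25519:auto']
```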
@@ -138,12 +152,14 @@ class Keyring:
         requests = (req,)
         return make_deferred_yieldable(self._verify_objects(requests)[0])
 
-    def verify_json_objects_for_server(self, server_and_json):
+    def verify_json_objects_for_server(
+        self, server_and_json: Iterable[Tuple[str, dict, int, str]]
+    ) -> List[defer.Deferred]:
         """Bulk verifies signatures of json objects, bulk fetching keys as
         necessary.
 
         Args:
-            server_and_json (iterable[Tuple[str, dict, int, str]):
+            server_and_json:
                 Iterable of (server_name, json_object, validity_time, request_name)
                 tuples.
 
@@ -164,13 +180,14 @@ class Keyring:
             for server_name, json_object, validity_time, request_name in server_and_json
         )
 
-    def _verify_objects(self, verify_requests):
+    def _verify_objects(
+        self, verify_requests: Iterable[VerifyJsonRequest]
+    ) -> List[defer.Deferred]:
         """Does the work of verify_json_[objects_]for_server
 
 
         Args:
-            verify_requests (iterable[VerifyJsonRequest]):
-                Iterable of verification requests.
+            verify_requests: Iterable of verification requests.
 
         Returns:
             List<Deferred[None]>: for each input item, a deferred indicating success
@@ -182,7 +199,7 @@ class Keyring:
         key_lookups = []
         handle = preserve_fn(_handle_key_deferred)
 
-        def process(verify_request):
+        def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
             """Process an entry in the request list
 
             Adds a key request to key_lookups, and returns a deferred which
@@ -222,18 +239,20 @@ class Keyring:
 
         return results
 
-    async def _start_key_lookups(self, verify_requests):
+    async def _start_key_lookups(
+        self, verify_requests: List[VerifyJsonRequest]
+    ) -> None:
         """Sets off the key fetches for each verify request
 
         Once each fetch completes, verify_request.key_ready will be resolved.
 
         Args:
-            verify_requests (List[VerifyJsonRequest]):
+            verify_requests:
         """
 
         try:
             # map from server name to a set of outstanding request ids
-            server_to_request_ids = {}
+            server_to_request_ids = {}  # type: Dict[str, Set[int]]
 
             for verify_request in verify_requests:
                 server_name = verify_request.server_name
@@ -275,11 +294,11 @@ class Keyring:
         except Exception:
             logger.exception("Error starting key lookups")
 
-    async def wait_for_previous_lookups(self, server_names) -> None:
+    async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
         """Waits for any previous key lookups for the given servers to finish.
 
         Args:
-            server_names (Iterable[str]): list of servers which we want to look up
+            server_names: list of servers which we want to look up
 
         Returns:
             Resolves once all key lookups for the given servers have
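The `# type:` comments added in this hunk are mypy's comment-style annotations, useful where an empty container would otherwise be inferred as untyped. A small standalone sketch of the idiom:

```python
from typing import Dict, Set

# Comment-style annotation: tells mypy the value type before anything is inserted.
server_to_request_ids = {}  # type: Dict[str, Set[int]]

server_to_request_ids.setdefault("example.org", set()).add(1)
server_to_request_ids.setdefault("example.org", set()).add(2)
print(server_to_request_ids)  # {'example.org': {1, 2}}
```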
@@ -304,7 +323,7 @@ class Keyring:
 
             loop_count += 1
 
-    def _get_server_verify_keys(self, verify_requests):
+    def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
         """Tries to find at least one key for each verify request
 
         For each verify_request, verify_request.key_ready is called back with
@@ -312,7 +331,7 @@ class Keyring:
         with a SynapseError if none of the keys are found.
 
         Args:
-            verify_requests (list[VerifyJsonRequest]): list of verify requests
+            verify_requests: list of verify requests
         """
 
         remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
@@ -366,17 +385,19 @@ class Keyring:
 
         run_in_background(do_iterations)
 
-    async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
+    async def _attempt_key_fetches_with_fetcher(
+        self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
+    ):
         """Use a key fetcher to attempt to satisfy some key requests
 
         Args:
-            fetcher (KeyFetcher): fetcher to use to fetch the keys
-            remaining_requests (set[VerifyJsonRequest]): outstanding key requests.
+            fetcher: fetcher to use to fetch the keys
+            remaining_requests: outstanding key requests.
                 Any successfully-completed requests will be removed from the list.
         """
-        # dict[str, dict[str, int]]: keys to fetch.
+        # The keys to fetch.
         # server_name -> key_id -> min_valid_ts
-        missing_keys = defaultdict(dict)
+        missing_keys = defaultdict(dict)  # type: Dict[str, Dict[str, int]]
 
         for verify_request in remaining_requests:
             # any completed requests should already have been removed
@@ -438,16 +459,18 @@ class Keyring:
         remaining_requests.difference_update(completed)
 
 
-class KeyFetcher:
-    async def get_keys(self, keys_to_fetch):
+class KeyFetcher(metaclass=abc.ABCMeta):
+    @abc.abstractmethod
+    async def get_keys(
+        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
-            keys_to_fetch (dict[str, dict[str, int]]):
+            keys_to_fetch:
                 the keys to be fetched. server_name -> key_id -> min_valid_ts
 
         Returns:
-            Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
-                map from server_name -> key_id -> FetchKeyResult
+            Map from server_name -> key_id -> FetchKeyResult
         """
         raise NotImplementedError
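Making `KeyFetcher` an `abc.ABCMeta` class means subclasses must override `get_keys` before they can be instantiated. A standalone sketch of the same pattern — the fetcher class and its string-valued results are toys, simplified from the `FetchKeyResult` type used above:

```python
import abc
import asyncio
from typing import Dict


class KeyFetcher(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    async def get_keys(
        self, keys_to_fetch: Dict[str, Dict[str, int]]
    ) -> Dict[str, Dict[str, str]]:
        raise NotImplementedError


class InMemoryKeyFetcher(KeyFetcher):
    """Toy fetcher backed by a dict instead of the datastore or federation."""

    def __init__(self, known: Dict[str, Dict[str, str]]):
        self._known = known

    async def get_keys(self, keys_to_fetch):
        # keys_to_fetch maps server_name -> {key_id: min_valid_ts}.
        return {
            server: {
                key_id: self._known.get(server, {}).get(key_id, "")
                for key_id in key_ids
            }
            for server, key_ids in keys_to_fetch.items()
        }


# Instantiating KeyFetcher() directly would raise TypeError (abstract method).
fetcher = InMemoryKeyFetcher({"example.org": {"ed25519:a": "base64-key"}})
print(asyncio.run(fetcher.get_keys({"example.org": {"ed25519:a": 0}})))
```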
@@ -455,31 +478,35 @@ class KeyFetcher:
 class StoreKeyFetcher(KeyFetcher):
     """KeyFetcher impl which fetches keys from our data store"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
-    async def get_keys(self, keys_to_fetch):
+    async def get_keys(
+        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """see KeyFetcher.get_keys"""
 
-        keys_to_fetch = (
+        key_ids_to_fetch = (
             (server_name, key_id)
             for server_name, keys_for_server in keys_to_fetch.items()
             for key_id in keys_for_server.keys()
         )
 
-        res = await self.store.get_server_verify_keys(keys_to_fetch)
-        keys = {}
+        res = await self.store.get_server_verify_keys(key_ids_to_fetch)
+        keys = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
         for (server_name, key_id), key in res.items():
             keys.setdefault(server_name, {})[key_id] = key
         return keys
 
 
-class BaseV2KeyFetcher:
-    def __init__(self, hs):
+class BaseV2KeyFetcher(KeyFetcher):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.config = hs.get_config()
 
-    async def process_v2_response(self, from_server, response_json, time_added_ms):
+    async def process_v2_response(
+        self, from_server: str, response_json: JsonDict, time_added_ms: int
+    ) -> Dict[str, FetchKeyResult]:
         """Parse a 'Server Keys' structure from the result of a /key request
 
         This is used to parse either the entirety of the response from
@@ -493,16 +520,16 @@ class BaseV2KeyFetcher:
         to /_matrix/key/v2/query.
 
         Args:
-            from_server (str): the name of the server producing this result: either
+            from_server: the name of the server producing this result: either
                 the origin server for a /_matrix/key/v2/server request, or the notary
                 for a /_matrix/key/v2/query.
 
-            response_json (dict): the json-decoded Server Keys response object
+            response_json: the json-decoded Server Keys response object
 
-            time_added_ms (int): the timestamp to record in server_keys_json
+            time_added_ms: the timestamp to record in server_keys_json
 
         Returns:
-            Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
+            Map from key_id to result object
         """
         ts_valid_until_ms = response_json["valid_until_ts"]
 
@@ -575,21 +602,22 @@ class BaseV2KeyFetcher:
 class PerspectivesKeyFetcher(BaseV2KeyFetcher):
     """KeyFetcher impl which fetches keys from the "perspectives" servers"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.clock = hs.get_clock()
         self.client = hs.get_federation_http_client()
         self.key_servers = self.config.key_servers
 
-    async def get_keys(self, keys_to_fetch):
+    async def get_keys(
+        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """see KeyFetcher.get_keys"""
 
-        async def get_key(key_server):
+        async def get_key(key_server: TrustedKeyServer) -> Dict:
             try:
-                result = await self.get_server_verify_key_v2_indirect(
+                return await self.get_server_verify_key_v2_indirect(
                     keys_to_fetch, key_server
                 )
-                return result
             except KeyLookupError as e:
                 logger.warning(
                     "Key lookup failed from %r: %s", key_server.server_name, e
@@ -611,25 +639,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
             ).addErrback(unwrapFirstError)
         )
 
-        union_of_keys = {}
+        union_of_keys = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
         for result in results:
             for server_name, keys in result.items():
                 union_of_keys.setdefault(server_name, {}).update(keys)
 
         return union_of_keys
 
-    async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
+    async def get_server_verify_key_v2_indirect(
+        self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
-            keys_to_fetch (dict[str, dict[str, int]]):
+            keys_to_fetch:
                 the keys to be fetched. server_name -> key_id -> min_valid_ts
 
-            key_server (synapse.config.key.TrustedKeyServer): notary server to query for
-                the keys
+            key_server: notary server to query for the keys
 
         Returns:
-            dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
-                from server_name -> key_id -> FetchKeyResult
+            Map from server_name -> key_id -> FetchKeyResult
 
         Raises:
             KeyLookupError if there was an error processing the entire response from
@@ -662,11 +690,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         except HttpResponseException as e:
             raise KeyLookupError("Remote server returned an error: %s" % (e,))
 
-        keys = {}
-        added_keys = []
+        keys = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
+        added_keys = []  # type: List[Tuple[str, str, FetchKeyResult]]
 
         time_now_ms = self.clock.time_msec()
 
+        assert isinstance(query_response, dict)
         for response in query_response["server_keys"]:
             # do this first, so that we can give useful errors thereafter
             server_name = response.get("server_name")
@@ -704,14 +733,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
 
         return keys
 
-    def _validate_perspectives_response(self, key_server, response):
+    def _validate_perspectives_response(
+        self, key_server: TrustedKeyServer, response: JsonDict
+    ) -> None:
         """Optionally check the signature on the result of a /key/query request
 
         Args:
-            key_server (synapse.config.key.TrustedKeyServer): the notary server that
-                produced this result
+            key_server: the notary server that produced this result
 
-            response (dict): the json-decoded Server Keys response object
+            response: the json-decoded Server Keys response object
         """
         perspective_name = key_server.server_name
         perspective_keys = key_server.verify_keys
@@ -745,25 +775,26 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
 class ServerKeyFetcher(BaseV2KeyFetcher):
     """KeyFetcher impl which fetches keys from the origin servers"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.clock = hs.get_clock()
         self.client = hs.get_federation_http_client()
 
-    async def get_keys(self, keys_to_fetch):
+    async def get_keys(
+        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
-            keys_to_fetch (dict[str, iterable[str]]):
+            keys_to_fetch:
                 the keys to be fetched. server_name -> key_ids
 
         Returns:
-            dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
-                map from server_name -> key_id -> FetchKeyResult
+            Map from server_name -> key_id -> FetchKeyResult
         """
 
         results = {}
 
-        async def get_key(key_to_fetch_item):
+        async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
             server_name, key_ids = key_to_fetch_item
             try:
                 keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
@@ -778,20 +809,22 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         await yieldable_gather_results(get_key, keys_to_fetch.items())
         return results
 
-    async def get_server_verify_key_v2_direct(self, server_name, key_ids):
+    async def get_server_verify_key_v2_direct(
+        self, server_name: str, key_ids: Iterable[str]
+    ) -> Dict[str, FetchKeyResult]:
         """
 
         Args:
-            server_name (str):
-            key_ids (iterable[str]):
+            server_name:
+            key_ids:
 
         Returns:
-            dict[str, FetchKeyResult]: map from key ID to lookup result
+            Map from key ID to lookup result
 
         Raises:
             KeyLookupError if there was a problem making the lookup
         """
-        keys = {}  # type: dict[str, FetchKeyResult]
+        keys = {}  # type: Dict[str, FetchKeyResult]
 
         for requested_key_id in key_ids:
             # we may have found this key as a side-effect of asking for another.
@@ -825,6 +858,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         except HttpResponseException as e:
             raise KeyLookupError("Remote server returned an error: %s" % (e,))
 
+        assert isinstance(response, dict)
         if response["server_name"] != server_name:
             raise KeyLookupError(
                 "Expected a response for server %r not %r"
@@ -846,11 +880,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         return keys
 
 
-async def _handle_key_deferred(verify_request) -> None:
+async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
     """Waits for the key to become available, and then performs a verification
 
     Args:
-        verify_request (VerifyJsonRequest):
+        verify_request:
 
     Raises:
         SynapseError if there was a problem performing the verification
synapse/federation/transport/server.py

@@ -146,7 +146,7 @@ class Authenticator:
         ):
             raise FederationDeniedError(origin)
 
-        if not json_request["signatures"]:
+        if origin is None or not json_request["signatures"]:
             raise NoAuthenticationError(
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED
             )
synapse/handlers/cas_handler.py

@@ -22,6 +22,7 @@ import attr
 from twisted.web.client import PartialDownloadError
 
 from synapse.api.errors import HttpResponseException
+from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.types import UserID, map_username_to_mxid_localpart
 
@@ -62,6 +63,7 @@ class CasHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self._hostname = hs.hostname
+        self._store = hs.get_datastore()
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
 
@@ -72,6 +74,9 @@ class CasHandler:
 
         self._http_client = hs.get_proxied_http_client()
 
+        # identifier for the external_ids table
+        self._auth_provider_id = "cas"
+
         self._sso_handler = hs.get_sso_handler()
 
     def _build_service_param(self, args: Dict[str, str]) -> str:
@@ -267,6 +272,14 @@ class CasHandler:
                 This should be the UI Auth session id.
         """
 
+        # first check if we're doing a UIA
+        if session:
+            return await self._sso_handler.complete_sso_ui_auth_request(
+                self._auth_provider_id, cas_response.username, session, request,
+            )
+
+        # otherwise, we're handling a login request.
+
         # Ensure that the attributes of the logged in user meet the required
         # attributes.
         for required_attribute, required_value in self._cas_required_attributes.items():
@@ -293,54 +306,79 @@ class CasHandler:
             )
             return
 
-        # Pull out the user-agent and IP from the request.
-        user_agent = request.get_user_agent("")
-        ip_address = self.hs.get_ip_from_request(request)
+        # Call the mapper to register/login the user
 
-        # Get the matrix ID from the CAS username.
-        user_id = await self._map_cas_user_to_matrix_user(
-            cas_response, user_agent, ip_address
-        )
+        # If this not a UI auth request than there must be a redirect URL.
+        assert client_redirect_url is not None
 
-        if session:
-            await self._auth_handler.complete_sso_ui_auth(
-                user_id, session, request,
-            )
-        else:
-            # If this not a UI auth request than there must be a redirect URL.
-            assert client_redirect_url
+        try:
+            await self._complete_cas_login(cas_response, request, client_redirect_url)
+        except MappingException as e:
+            logger.exception("Could not map user")
+            self._sso_handler.render_error(request, "mapping_error", str(e))
 
-            await self._auth_handler.complete_sso_login(
-                user_id, request, client_redirect_url
-            )
-
-    async def _map_cas_user_to_matrix_user(
-        self, cas_response: CasResponse, user_agent: str, ip_address: str,
-    ) -> str:
+    async def _complete_cas_login(
+        self,
+        cas_response: CasResponse,
+        request: SynapseRequest,
+        client_redirect_url: str,
+    ) -> None:
         """
-        Given a CAS username, retrieve the user ID for it and possibly register the user.
+        Given a CAS response, complete the login flow
+
+        Retrieves the remote user ID, registers the user if necessary, and serves
+        a redirect back to the client with a login-token.
 
         Args:
             cas_response: The parsed CAS response.
-            user_agent: The user agent of the client making the request.
-            ip_address: The IP address of the client making the request.
+            request: The request to respond to
+            client_redirect_url: The redirect URL passed in by the client.
 
-        Returns:
-            The user ID associated with this response.
+        Raises:
+            MappingException if there was a problem mapping the response to a user.
+            RedirectException: some mapping providers may raise this if they need
+                to redirect to an interstitial page.
         """
         # Note that CAS does not support a mapping provider, so the logic is hard-coded.
         localpart = map_username_to_mxid_localpart(cas_response.username)
-        user_id = UserID(localpart, self._hostname).to_string()
-        registered_user_id = await self._auth_handler.check_user_exists(user_id)
 
-        displayname = cas_response.attributes.get(self._cas_displayname_attribute, None)
+        async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
+            """
+            Map from CAS attributes to user attributes.
+            """
+            # Due to the grandfathering logic matching any previously registered
+            # mxids it isn't expected for there to be any failures.
+            if failures:
+                raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs")
 
-        # If the user does not exist, register it.
-        if not registered_user_id:
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart,
-                default_display_name=displayname,
-                user_agent_ips=[(user_agent, ip_address)],
+            display_name = cas_response.attributes.get(
+                self._cas_displayname_attribute, None
             )
 
-        return registered_user_id
+            return UserAttributes(localpart=localpart, display_name=display_name)
+
+        async def grandfather_existing_users() -> Optional[str]:
+            # Since CAS did not always use the user_external_ids table, always
+            # to attempt to map to existing users.
+            user_id = UserID(localpart, self._hostname).to_string()
+
+            logger.debug(
+                "Looking for existing account based on mapped %s", user_id,
+            )
+
+            users = await self._store.get_users_by_id_case_insensitive(user_id)
+            if users:
+                registered_user_id = list(users.keys())[0]
+                logger.info("Grandfathering mapping to %s", registered_user_id)
+                return registered_user_id
+
+            return None
+
+        await self._sso_handler.complete_sso_login_request(
+            self._auth_provider_id,
+            cas_response.username,
+            request,
+            client_redirect_url,
+            cas_response_to_user_attributes,
+            grandfather_existing_users,
+        )
synapse/handlers/groups_local.py

@@ -29,7 +29,7 @@ def _create_rerouter(func_name):
 
     async def f(self, group_id, *args, **kwargs):
         if not GroupID.is_valid(group_id):
-            raise SynapseError(400, "%s was not legal group ID" % (group_id,))
+            raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
 
         if self.is_mine_id(group_id):
             return await getattr(self.groups_server_handler, func_name)(
synapse/handlers/initial_sync.py

@@ -323,9 +323,7 @@ class InitialSyncHandler(BaseHandler):
         member_event_id: str,
         is_peeking: bool,
     ) -> JsonDict:
-        room_state = await self.state_store.get_state_for_events([member_event_id])
-
-        room_state = room_state[member_event_id]
+        room_state = await self.state_store.get_state_for_event(member_event_id)
 
         limit = pagin_config.limit if pagin_config else None
         if limit is None:
@@ -173,7 +173,7 @@ class SsoHandler:
        request: SynapseRequest,
        client_redirect_url: str,
        sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
        grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
        grandfather_existing_users: Callable[[], Awaitable[Optional[str]]],
        extra_login_attributes: Optional[JsonDict] = None,
    ) -> None:
        """
@@ -241,7 +241,7 @@ class SsoHandler:
        )

        # Check for grandfathering of users.
        if not user_id and grandfather_existing_users:
        if not user_id:
            user_id = await grandfather_existing_users()
            if user_id:
                # Future logins should also match this user ID.
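The signature change above makes the grandfathering callback mandatory rather than optional, which simplifies the call site. Under that contract, a login provider with no legacy accounts would presumably pass a trivial coroutine instead of `None`; a hypothetical sketch:

```python
from typing import Optional


async def no_grandfathering() -> Optional[str]:
    # No pre-existing accounts to match against, so never return a user ID.
    return None
```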
@@ -568,7 +568,7 @@ class SyncHandler:
                event.event_id, state_filter=state_filter
            )
            if event.is_state():
                state_ids = state_ids.copy()
                state_ids = dict(state_ids)
                state_ids[(event.type, event.state_key)] = event.event_id
        return state_ids
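Swapping `.copy()` for `dict(...)` matters when `state_ids` may be an immutable mapping (such as a frozendict) whose `copy()` either does not exist or returns another immutable object; `dict()` always builds a fresh mutable dict from any mapping. A small illustration using a read-only mapping from the standard library as a stand-in:

```python
from types import MappingProxyType

# A read-only view standing in for an immutable state map.
state_ids = MappingProxyType({("m.room.name", ""): "$event1"})

mutable = dict(state_ids)  # works for any Mapping, always mutable
mutable[("m.room.topic", "")] = "$event2"
assert len(mutable) == 2
```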
@@ -15,6 +15,7 @@
# limitations under the License.

import logging
from functools import wraps

from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -25,6 +26,22 @@ from ._base import client_patterns
logger = logging.getLogger(__name__)


def _validate_group_id(f):
    """Wrapper to validate the form of the group ID.

    Can be applied to any on_FOO method that accepts a group ID as a URL parameter.
    """

    @wraps(f)
    def wrapper(self, request, group_id, *args, **kwargs):
        if not GroupID.is_valid(group_id):
            raise SynapseError(400, "%s is not a legal group ID" % (group_id,))

        return f(self, request, group_id, *args, **kwargs)

    return wrapper


class GroupServlet(RestServlet):
    """Get the group profile
    """
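The decorator centralises a check that was previously copy-pasted into individual handlers (see the `GroupRoomServlet` hunk below, where the inline check is deleted). A self-contained sketch of the same wrap-and-validate pattern, using plain placeholders instead of Synapse's `GroupID` and `SynapseError`:

```python
from functools import wraps


def validate_group_id(f):
    @wraps(f)
    def wrapper(self, request, group_id, *args, **kwargs):
        if not group_id.startswith("+"):  # stand-in for GroupID.is_valid
            raise ValueError("%s is not a legal group ID" % (group_id,))
        return f(self, request, group_id, *args, **kwargs)

    return wrapper


class ExampleServlet:
    @validate_group_id
    def on_GET(self, request, group_id):
        return 200, {"group_id": group_id}


assert ExampleServlet().on_GET(None, "+demo:example.com")[0] == 200
```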
@@ -37,6 +54,7 @@ class GroupServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_GET(self, request, group_id):
        requester = await self.auth.get_user_by_req(request, allow_guest=True)
        requester_user_id = requester.user.to_string()
@@ -47,6 +65,7 @@ class GroupServlet(RestServlet):

        return 200, group_description

    @_validate_group_id
    async def on_POST(self, request, group_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -71,6 +90,7 @@ class GroupSummaryServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_GET(self, request, group_id):
        requester = await self.auth.get_user_by_req(request, allow_guest=True)
        requester_user_id = requester.user.to_string()
@@ -102,6 +122,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_PUT(self, request, group_id, category_id, room_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -117,6 +138,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):

        return 200, resp

    @_validate_group_id
    async def on_DELETE(self, request, group_id, category_id, room_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -142,6 +164,7 @@ class GroupCategoryServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_GET(self, request, group_id, category_id):
        requester = await self.auth.get_user_by_req(request, allow_guest=True)
        requester_user_id = requester.user.to_string()
@@ -152,6 +175,7 @@ class GroupCategoryServlet(RestServlet):

        return 200, category

    @_validate_group_id
    async def on_PUT(self, request, group_id, category_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -163,6 +187,7 @@ class GroupCategoryServlet(RestServlet):

        return 200, resp

    @_validate_group_id
    async def on_DELETE(self, request, group_id, category_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -186,6 +211,7 @@ class GroupCategoriesServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_GET(self, request, group_id):
        requester = await self.auth.get_user_by_req(request, allow_guest=True)
        requester_user_id = requester.user.to_string()
@@ -209,6 +235,7 @@ class GroupRoleServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_GET(self, request, group_id, role_id):
        requester = await self.auth.get_user_by_req(request, allow_guest=True)
        requester_user_id = requester.user.to_string()
@@ -219,6 +246,7 @@ class GroupRoleServlet(RestServlet):

        return 200, category

    @_validate_group_id
    async def on_PUT(self, request, group_id, role_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -230,6 +258,7 @@ class GroupRoleServlet(RestServlet):

        return 200, resp

    @_validate_group_id
    async def on_DELETE(self, request, group_id, role_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -253,6 +282,7 @@ class GroupRolesServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_GET(self, request, group_id):
        requester = await self.auth.get_user_by_req(request, allow_guest=True)
        requester_user_id = requester.user.to_string()
@@ -284,6 +314,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_PUT(self, request, group_id, role_id, user_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -299,6 +330,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):

        return 200, resp

    @_validate_group_id
    async def on_DELETE(self, request, group_id, role_id, user_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -322,13 +354,11 @@ class GroupRoomServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_GET(self, request, group_id):
        requester = await self.auth.get_user_by_req(request, allow_guest=True)
        requester_user_id = requester.user.to_string()

        if not GroupID.is_valid(group_id):
            raise SynapseError(400, "%s was not legal group ID" % (group_id,))

        result = await self.groups_handler.get_rooms_in_group(
            group_id, requester_user_id
        )
@@ -348,6 +378,7 @@ class GroupUsersServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_GET(self, request, group_id):
        requester = await self.auth.get_user_by_req(request, allow_guest=True)
        requester_user_id = requester.user.to_string()
@@ -371,6 +402,7 @@ class GroupInvitedUsersServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_GET(self, request, group_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -393,6 +425,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
        self.auth = hs.get_auth()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_PUT(self, request, group_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -449,6 +482,7 @@ class GroupAdminRoomsServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_PUT(self, request, group_id, room_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -460,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet):

        return 200, result

    @_validate_group_id
    async def on_DELETE(self, request, group_id, room_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -486,6 +521,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_PUT(self, request, group_id, room_id, config_key):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -514,6 +550,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
        self.store = hs.get_datastore()
        self.is_mine_id = hs.is_mine_id

    @_validate_group_id
    async def on_PUT(self, request, group_id, user_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -541,6 +578,7 @@ class GroupAdminUsersKickServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_PUT(self, request, group_id, user_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -565,6 +603,7 @@ class GroupSelfLeaveServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_PUT(self, request, group_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -589,6 +628,7 @@ class GroupSelfJoinServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_PUT(self, request, group_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -613,6 +653,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
        self.clock = hs.get_clock()
        self.groups_handler = hs.get_groups_local_handler()

    @_validate_group_id
    async def on_PUT(self, request, group_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -637,6 +678,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
        self.clock = hs.get_clock()
        self.store = hs.get_datastore()

    @_validate_group_id
    async def on_PUT(self, request, group_id):
        requester = await self.auth.get_user_by_req(request)
        requester_user_id = requester.user.to_string()
@@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Dict, Set
from typing import Dict

from signedjson.sign import sign_json

@@ -142,12 +142,13 @@ class RemoteKey(DirectServeJsonResource):

        time_now_ms = self.clock.time_msec()

        cache_misses = {}  # type: Dict[str, Set[str]]
        # Note that the value is unused.
        cache_misses = {}  # type: Dict[str, Dict[str, int]]
        for (server_name, key_id, from_server), results in cached.items():
            results = [(result["ts_added_ms"], result) for result in results]

            if not results and key_id is not None:
                cache_misses.setdefault(server_name, set()).add(key_id)
                cache_misses.setdefault(server_name, {})[key_id] = 0
                continue

            if key_id is not None:
@@ -201,7 +202,7 @@ class RemoteKey(DirectServeJsonResource):
                )

                if miss:
                    cache_misses.setdefault(server_name, set()).add(key_id)
                    cache_misses.setdefault(server_name, {})[key_id] = 0
                # Cast to bytes since postgresql returns a memoryview.
                json_results.add(bytes(most_recent_result["key_json"]))
            else:
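The sets of missing key IDs become dicts whose values are unused placeholders, so the structure can later carry per-key data without another type change. The `setdefault` idiom above in miniature:

```python
# Build a nested map of server_name -> {key_id: 0} without checking for
# the outer key first; the 0 values are deliberate placeholders.
cache_misses = {}
cache_misses.setdefault("example.com", {})["ed25519:abc"] = 0
cache_misses.setdefault("example.com", {})["ed25519:def"] = 0
assert cache_misses == {"example.com": {"ed25519:abc": 0, "ed25519:def": 0}}
```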
@@ -27,6 +27,7 @@ There are also schemas that get applied to every database, regardless of the
data stores associated with them (e.g. the schema version tables), which are
stored in `synapse.storage.schema`.
"""
from typing import TYPE_CHECKING

from synapse.storage.databases import Databases
from synapse.storage.databases.main import DataStore
@@ -34,14 +35,18 @@ from synapse.storage.persist_events import EventsPersistenceStorage
from synapse.storage.purge_events import PurgeEventsStorage
from synapse.storage.state import StateGroupStorage

__all__ = ["DataStores", "DataStore"]
if TYPE_CHECKING:
    from synapse.app.homeserver import HomeServer


__all__ = ["Databases", "DataStore"]


class Storage:
    """The high level interfaces for talking to various storage layers.
    """

    def __init__(self, hs, stores: Databases):
    def __init__(self, hs: "HomeServer", stores: Databases):
        # We include the main data store here mainly so that we don't have to
        # rewrite all the existing code to split it into high vs low level
        # interfaces.
@@ -17,14 +17,18 @@
import logging
import random
from abc import ABCMeta
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Iterable, Optional, Union

from synapse.storage.database import LoggingTransaction  # noqa: F401
from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
from synapse.storage.database import DatabasePool
from synapse.types import Collection, get_domain_from_id
from synapse.storage.types import Connection
from synapse.types import Collection, StreamToken, get_domain_from_id
from synapse.util import json_decoder

if TYPE_CHECKING:
    from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)


@@ -36,24 +40,31 @@ class SQLBaseStore(metaclass=ABCMeta):
    per data store (and not one per physical database).
    """

    def __init__(self, database: DatabasePool, db_conn, hs):
    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
        self.hs = hs
        self._clock = hs.get_clock()
        self.database_engine = database.engine
        self.db_pool = database
        self.rand = random.SystemRandom()

    def process_replication_rows(self, stream_name, instance_name, token, rows):
    def process_replication_rows(
        self,
        stream_name: str,
        instance_name: str,
        token: StreamToken,
        rows: Iterable[Any],
    ) -> None:
        pass

    def _invalidate_state_caches(self, room_id, members_changed):
    def _invalidate_state_caches(
        self, room_id: str, members_changed: Iterable[str]
    ) -> None:
        """Invalidates caches that are based on the current state, but does
        not stream invalidations down replication.

        Args:
            room_id (str): Room where state changed
            members_changed (iterable[str]): The user_ids of members that have
                changed
            room_id: Room where state changed
            members_changed: The user_ids of members that have changed
        """
        for host in {get_domain_from_id(u) for u in members_changed}:
            self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
@@ -64,7 +75,7 @@ class SQLBaseStore(metaclass=ABCMeta):

    def _attempt_to_invalidate_cache(
        self, cache_name: str, key: Optional[Collection[Any]]
    ):
    ) -> None:
        """Attempts to invalidate the cache of the given name, ignoring if the
        cache doesn't exist. Mainly used for invalidating caches on workers,
        where they may not have the cache.
@@ -88,12 +99,15 @@ class SQLBaseStore(metaclass=ABCMeta):
            cache.invalidate(tuple(key))


def db_to_json(db_content):
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
    """
    Take some data from a database row and return a JSON-decoded object.

    Args:
        db_content (memoryview|buffer|bytes|bytearray|unicode)
        db_content: The JSON-encoded contents from the database.

    Returns:
        The object decoded from JSON.
    """
    # psycopg2 on Python 3 returns memoryview objects, which we need to
    # cast to bytes to decode
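A minimal stand-in for the behaviour that `db_to_json`'s new annotation documents: psycopg2 hands back `memoryview` objects on Python 3, which need converting to bytes before JSON decoding. The real function uses Synapse's shared `json_decoder`; plain `json` is used here purely for illustration:

```python
import json
from typing import Any, Union


def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
    if isinstance(db_content, memoryview):
        db_content = db_content.tobytes()  # cast before decoding
    return json.loads(db_content)


assert db_to_json(memoryview(b'{"a": 1}')) == {"a": 1}
```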
@@ -12,29 +12,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Optional
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional

from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.types import Connection
from synapse.types import JsonDict
from synapse.util import json_encoder

from . import engines

if TYPE_CHECKING:
    from synapse.app.homeserver import HomeServer
    from synapse.storage.database import DatabasePool, LoggingTransaction

logger = logging.getLogger(__name__)


class BackgroundUpdatePerformance:
    """Tracks how long a background update is taking to update its items"""

    def __init__(self, name):
    def __init__(self, name: str):
        self.name = name
        self.total_item_count = 0
        self.total_duration_ms = 0
        self.avg_item_count = 0
        self.avg_duration_ms = 0
        self.total_duration_ms = 0.0
        self.avg_item_count = 0.0
        self.avg_duration_ms = 0.0

    def update(self, item_count, duration_ms):
    def update(self, item_count: int, duration_ms: float) -> None:
        """Update the stats after doing an update"""
        self.total_item_count += item_count
        self.total_duration_ms += duration_ms
@@ -44,7 +49,7 @@ class BackgroundUpdatePerformance:
        self.avg_item_count += 0.1 * (item_count - self.avg_item_count)
        self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms)
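The two running averages above are exponential moving averages with a smoothing factor of 0.1: each sample pulls the average 10% of the way towards itself, so recent batches dominate while old ones decay geometrically. A quick standalone illustration:

```python
avg = 0.0
for sample in [100, 100, 50, 50]:
    avg += 0.1 * (sample - avg)  # same update rule as above

# 10.0 -> 19.0 -> 22.1 -> 24.89: the estimate warms up slowly and
# weights the latest samples most heavily.
assert round(avg, 2) == 24.89
```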
    def average_items_per_ms(self):
    def average_items_per_ms(self) -> Optional[float]:
        """An estimate of how long it takes to do a single update.
        Returns:
            A duration in ms as a float
@@ -58,7 +63,7 @@ class BackgroundUpdatePerformance:
            # changes in how long the update process takes.
            return float(self.avg_item_count) / float(self.avg_duration_ms)

    def total_items_per_ms(self):
    def total_items_per_ms(self) -> Optional[float]:
        """An estimate of how long it takes to do a single update.
        Returns:
            A duration in ms as a float
@@ -83,21 +88,25 @@ class BackgroundUpdater:
    BACKGROUND_UPDATE_INTERVAL_MS = 1000
    BACKGROUND_UPDATE_DURATION_MS = 100

    def __init__(self, hs, database):
    def __init__(self, hs: "HomeServer", database: "DatabasePool"):
        self._clock = hs.get_clock()
        self.db_pool = database

        # if a background update is currently running, its name.
        self._current_background_update = None  # type: Optional[str]

        self._background_update_performance = {}
        self._background_update_handlers = {}
        self._background_update_performance = (
            {}
        )  # type: Dict[str, BackgroundUpdatePerformance]
        self._background_update_handlers = (
            {}
        )  # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]]
        self._all_done = False

    def start_doing_background_updates(self):
    def start_doing_background_updates(self) -> None:
        run_as_background_process("background_updates", self.run_background_updates)

    async def run_background_updates(self, sleep=True):
    async def run_background_updates(self, sleep: bool = True) -> None:
        logger.info("Starting background schema updates")
        while True:
            if sleep:
@@ -148,7 +157,7 @@ class BackgroundUpdater:

        return False

    async def has_completed_background_update(self, update_name) -> bool:
    async def has_completed_background_update(self, update_name: str) -> bool:
        """Check if the given background update has finished running.
        """
        if self._all_done:
@@ -173,8 +182,7 @@ class BackgroundUpdater:
        Returns once some amount of work is done.

        Args:
            desired_duration_ms(float): How long we want to spend
                updating.
            desired_duration_ms: How long we want to spend updating.
        Returns:
            True if we have finished running all the background updates, otherwise False
        """
@@ -220,6 +228,7 @@ class BackgroundUpdater:
            return False

    async def _do_background_update(self, desired_duration_ms: float) -> int:
        assert self._current_background_update is not None
        update_name = self._current_background_update
        logger.info("Starting update batch on background update '%s'", update_name)
@@ -273,7 +282,11 @@ class BackgroundUpdater:

        return len(self._background_update_performance)

    def register_background_update_handler(self, update_name, update_handler):
    def register_background_update_handler(
        self,
        update_name: str,
        update_handler: Callable[[JsonDict, int], Awaitable[int]],
    ):
        """Register a handler for doing a background update.

        The handler should take two arguments:

@@ -287,12 +300,12 @@ class BackgroundUpdater:
        The handler is responsible for updating the progress of the update.

        Args:
            update_name(str): The name of the update that this code handles.
            update_handler(function): The function that does the update.
            update_name: The name of the update that this code handles.
            update_handler: The function that does the update.
        """
        self._background_update_handlers[update_name] = update_handler

    def register_noop_background_update(self, update_name):
    def register_noop_background_update(self, update_name: str) -> None:
        """Register a noop handler for a background update.

        This is useful when we previously did a background update, but no
@@ -302,10 +315,10 @@ class BackgroundUpdater:
        also be called to clear the update.

        Args:
            update_name (str): Name of update
            update_name: Name of update
        """

        async def noop_update(progress, batch_size):
        async def noop_update(progress: JsonDict, batch_size: int) -> int:
            await self._end_background_update(update_name)
            return 1
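The new `Callable[[JsonDict, int], Awaitable[int]]` annotation pins down the handler contract: it takes the stored progress dict and a batch size, returns how many items it processed, and calls `_end_background_update` itself when done (exactly what `noop_update` above does). A hedged usage sketch with a made-up update name:

```python
async def populate_example_index(progress, batch_size):
    last_seen = progress.get("last_seen", 0)
    # ... process up to batch_size rows after last_seen and persist the new
    # progress; call _end_background_update once nothing is left ...
    return batch_size


# Typically registered from a store constructor, e.g.:
# self.db_pool.updates.register_background_update_handler(
#     "populate_example_index", populate_example_index
# )
```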
@@ -313,14 +326,14 @@ class BackgroundUpdater:

    def register_background_index_update(
        self,
        update_name,
        index_name,
        table,
        columns,
        where_clause=None,
        unique=False,
        psql_only=False,
    ):
        update_name: str,
        index_name: str,
        table: str,
        columns: Iterable[str],
        where_clause: Optional[str] = None,
        unique: bool = False,
        psql_only: bool = False,
    ) -> None:
        """Helper for store classes to do a background index addition

        To use:
@@ -332,19 +345,19 @@ class BackgroundUpdater:
        2. In the Store constructor, call this method

        Args:
            update_name (str): update_name to register for
            index_name (str): name of index to add
            table (str): table to add index to
            columns (list[str]): columns/expressions to include in index
            unique (bool): true to make a UNIQUE index
            update_name: update_name to register for
            index_name: name of index to add
            table: table to add index to
            columns: columns/expressions to include in index
            unique: true to make a UNIQUE index
            psql_only: true to only create this index on psql databases (useful
                for virtual sqlite tables)
        """

        def create_index_psql(conn):
        def create_index_psql(conn: Connection) -> None:
            conn.rollback()
            # postgres insists on autocommit for the index
            conn.set_session(autocommit=True)
            conn.set_session(autocommit=True)  # type: ignore

            try:
                c = conn.cursor()
@@ -371,9 +384,9 @@ class BackgroundUpdater:
                logger.debug("[SQL] %s", sql)
                c.execute(sql)
            finally:
                conn.set_session(autocommit=False)
                conn.set_session(autocommit=False)  # type: ignore

        def create_index_sqlite(conn):
        def create_index_sqlite(conn: Connection) -> None:
            # Sqlite doesn't support concurrent creation of indexes.
            #
            # We don't use partial indices on SQLite as it wasn't introduced
@@ -399,7 +412,7 @@ class BackgroundUpdater:
            c.execute(sql)

        if isinstance(self.db_pool.engine, engines.PostgresEngine):
            runner = create_index_psql
            runner = create_index_psql  # type: Optional[Callable[[Connection], None]]
        elif psql_only:
            runner = None
        else:
@@ -433,7 +446,9 @@ class BackgroundUpdater:
            "background_updates", keyvalues={"update_name": update_name}
        )

    async def _background_update_progress(self, update_name: str, progress: dict):
    async def _background_update_progress(
        self, update_name: str, progress: dict
    ) -> None:
        """Update the progress of a background update

        Args:
@@ -441,20 +456,22 @@ class BackgroundUpdater:
            progress: The progress of the update.
        """

        return await self.db_pool.runInteraction(
        await self.db_pool.runInteraction(
            "background_update_progress",
            self._background_update_progress_txn,
            update_name,
            progress,
        )

    def _background_update_progress_txn(self, txn, update_name, progress):
    def _background_update_progress_txn(
        self, txn: "LoggingTransaction", update_name: str, progress: JsonDict
    ) -> None:
        """Update the progress of a background update

        Args:
            txn(cursor): The transaction.
            update_name(str): The name of the background update task
            progress(dict): The progress of the update.
            txn: The transaction.
            update_name: The name of the background update task
            progress: The progress of the update.
        """

        progress_json = json_encoder.encode(progress)
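A hedged sketch of step 2 from the `register_background_index_update` docstring, as it might appear in a store constructor; the store class, update name, table, and columns are invented for illustration:

```python
class ExampleStore(SQLBaseStore):  # hypothetical store class
    def __init__(self, database, db_conn, hs):
        super().__init__(database, db_conn, hs)
        self.db_pool.updates.register_background_index_update(
            "example_table_index",
            index_name="example_table_idx",
            table="example_table",
            columns=["room_id", "event_id"],
        )
```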
@@ -22,6 +22,7 @@ from signedjson.key import decode_verify_key_bytes

from synapse.storage._base import SQLBaseStore
from synapse.storage.keys import FetchKeyResult
from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter

@@ -44,7 +45,7 @@ class KeyStore(SQLBaseStore):
    )
    async def get_server_verify_keys(
        self, server_name_and_key_ids: Iterable[Tuple[str, str]]
    ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
    ) -> Dict[Tuple[str, str], FetchKeyResult]:
        """
        Args:
            server_name_and_key_ids:
@@ -56,7 +57,7 @@ class KeyStore(SQLBaseStore):
        """
        keys = {}

        def _get_keys(txn, batch):
        def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None:
            """Processes a batch of keys to fetch, and adds the result to `keys`."""

            # batch_iter always returns tuples so it's safe to do len(batch)
@@ -77,13 +78,12 @@ class KeyStore(SQLBaseStore):
                # `ts_valid_until_ms`.
                ts_valid_until_ms = 0

                res = FetchKeyResult(
                keys[(server_name, key_id)] = FetchKeyResult(
                    verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
                    valid_until_ts=ts_valid_until_ms,
                )
                keys[(server_name, key_id)] = res

        def _txn(txn):
        def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
            for batch in batch_iter(server_name_and_key_ids, 50):
                _get_keys(txn, batch)
            return keys
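`batch_iter` (used in `_txn` above to fetch keys 50 at a time) chunks any iterable into fixed-size tuples, which is why the comment in `_get_keys` can rely on `len(batch)`. A minimal stand-in with the same semantics:

```python
from itertools import islice


def batch_iter(iterable, size):
    sourceiter = iter(iterable)
    while True:
        batch = tuple(islice(sourceiter, size))
        if not batch:
            return
        yield batch  # always a tuple, so len() is safe


assert list(batch_iter(range(5), 2)) == [(0, 1), (2, 3), (4,)]
```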
@@ -17,11 +17,12 @@
import logging

import attr
from signedjson.types import VerifyKey

logger = logging.getLogger(__name__)


@attr.s(slots=True, frozen=True)
class FetchKeyResult:
    verify_key = attr.ib()  # VerifyKey: the key itself
    valid_until_ts = attr.ib()  # int: how long we can use this key for
    verify_key = attr.ib(type=VerifyKey)  # the key itself
    valid_until_ts = attr.ib(type=int)  # how long we can use this key for
@@ -18,9 +18,10 @@ import logging
import os
import re
from collections import Counter
from typing import Optional, TextIO
from typing import Generator, Iterable, List, Optional, TextIO, Tuple

import attr
from typing_extensions import Counter as CounterType

from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection
@@ -70,7 +71,7 @@ def prepare_database(
    db_conn: LoggingDatabaseConnection,
    database_engine: BaseDatabaseEngine,
    config: Optional[HomeServerConfig],
    databases: Collection[str] = ["main", "state"],
    databases: Collection[str] = ("main", "state"),
):
    """Prepares a physical database for usage. Will either create all necessary tables
    or upgrade from an older schema version.
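Replacing the list default with a tuple sidesteps Python's shared-mutable-default pitfall: a list default is created once and shared across calls, so any mutation leaks into later invocations, while a tuple cannot be mutated. A standalone demonstration of the hazard:

```python
def fragile(databases=["main", "state"]):  # one shared list for every call
    databases.append("oops")
    return databases


fragile()
assert fragile() == ["main", "state", "oops", "oops"]  # the default leaked
```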
@@ -155,7 +156,9 @@ def prepare_database(
        raise


def _setup_new_database(cur, database_engine, databases):
def _setup_new_database(
    cur: Cursor, database_engine: BaseDatabaseEngine, databases: Collection[str]
) -> None:
    """Sets up the physical database by finding a base set of "full schemas" and
    then applying any necessary deltas, including schemas from the given data
    stores.
@@ -188,10 +191,9 @@ def _setup_new_database(cur, database_engine, databases):
    folder as well those in the data stores specified.

    Args:
        cur (Cursor): a database cursor
        database_engine (DatabaseEngine)
        databases (list[str]): The names of the databases to instantiate
            on the given physical database.
        cur: a database cursor
        database_engine
        databases: The names of the databases to instantiate on the given physical database.
    """

    # We're about to set up a brand new database so we check that its
@@ -199,12 +201,11 @@ def _setup_new_database(cur, database_engine, databases):
    database_engine.check_new_database(cur)

    current_dir = os.path.join(dir_path, "schema", "full_schemas")
    directory_entries = os.listdir(current_dir)

    # First we find the highest full schema version we have
    valid_versions = []

    for filename in directory_entries:
    for filename in os.listdir(current_dir):
        try:
            ver = int(filename)
        except ValueError:
@@ -237,7 +238,7 @@ def _setup_new_database(cur, database_engine, databases):
        for database in databases
    )

    directory_entries = []
    directory_entries = []  # type: List[_DirectoryListing]
    for directory in directories:
        directory_entries.extend(
            _DirectoryListing(file_name, os.path.join(directory, file_name))
@@ -275,15 +276,15 @@ def _setup_new_database(cur, database_engine, databases):


def _upgrade_existing_database(
    cur,
    current_version,
    applied_delta_files,
    upgraded,
    database_engine,
    config,
    databases,
    is_empty=False,
):
    cur: Cursor,
    current_version: int,
    applied_delta_files: List[str],
    upgraded: bool,
    database_engine: BaseDatabaseEngine,
    config: Optional[HomeServerConfig],
    databases: Collection[str],
    is_empty: bool = False,
) -> None:
    """Upgrades an existing physical database.

    Delta files can either be SQL stored in *.sql files, or python modules
@@ -323,21 +324,20 @@ def _upgrade_existing_database(
    for a version before applying those in the next version.

    Args:
        cur (Cursor)
        current_version (int): The current version of the schema.
        applied_delta_files (list): A list of deltas that have already been
            applied.
        upgraded (bool): Whether the current version was generated by having
        cur
        current_version: The current version of the schema.
        applied_delta_files: A list of deltas that have already been applied.
        upgraded: Whether the current version was generated by having
            applied deltas or from full schema file. If `True` the function
            will never apply delta files for the given `current_version`, since
            the current_version wasn't generated by applying those delta files.
        database_engine (DatabaseEngine)
        config (synapse.config.homeserver.HomeServerConfig|None):
        database_engine
        config:
            None if we are initialising a blank database, otherwise the application
            config
        databases (list[str]): The names of the databases to instantiate
        databases: The names of the databases to instantiate
            on the given physical database.
        is_empty (bool): Is this a blank database? I.e. do we need to run the
        is_empty: Is this a blank database? I.e. do we need to run the
            upgrade portions of the delta scripts.
    """
    if is_empty:
@@ -358,6 +358,7 @@ def _upgrade_existing_database(
    if not is_empty and "main" in databases:
        from synapse.storage.databases.main import check_database_before_upgrade

        assert config is not None
        check_database_before_upgrade(cur, database_engine, config)

    start_ver = current_version
@@ -388,10 +389,10 @@ def _upgrade_existing_database(
    )

    # Used to check if we have any duplicate file names
    file_name_counter = Counter()
    file_name_counter = Counter()  # type: CounterType[str]

    # Now find which directories have anything of interest.
    directory_entries = []
    directory_entries = []  # type: List[_DirectoryListing]
    for directory in directories:
        logger.debug("Looking for schema deltas in %s", directory)
        try:
@@ -445,11 +446,11 @@ def _upgrade_existing_database(

                module_name = "synapse.storage.v%d_%s" % (v, root_name)
                with open(absolute_path) as python_file:
                    module = imp.load_source(module_name, absolute_path, python_file)
                    module = imp.load_source(module_name, absolute_path, python_file)  # type: ignore
                logger.info("Running script %s", relative_path)
                module.run_create(cur, database_engine)
                module.run_create(cur, database_engine)  # type: ignore
                if not is_empty:
                    module.run_upgrade(cur, database_engine, config=config)
                    module.run_upgrade(cur, database_engine, config=config)  # type: ignore
            elif ext == ".pyc" or file_name == "__pycache__":
                # Sometimes .pyc files turn up anyway even though we've
                # disabled their generation; e.g. from distribution package
@@ -497,14 +498,15 @@ def _upgrade_existing_database(
    logger.info("Schema now up to date")


def _apply_module_schemas(txn, database_engine, config):
def _apply_module_schemas(
    txn: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
) -> None:
    """Apply the module schemas for the dynamic modules, if any

    Args:
        cur: database cursor
        database_engine: synapse database engine class
        config (synapse.config.homeserver.HomeServerConfig):
            application config
        database_engine:
        config: application config
    """
    for (mod, _config) in config.password_providers:
        if not hasattr(mod, "get_db_schema_files"):
@@ -515,15 +517,19 @@ def _apply_module_schemas(txn, database_engine, config):
    )


def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
def _apply_module_schema_files(
    cur: Cursor,
    database_engine: BaseDatabaseEngine,
    modname: str,
    names_and_streams: Iterable[Tuple[str, TextIO]],
) -> None:
    """Apply the module schemas for a single module

    Args:
        cur: database cursor
        database_engine: synapse database engine class
        modname (str): fully qualified name of the module
        names_and_streams (Iterable[(str, file)]): the names and streams of
            schemas to be applied
        modname: fully qualified name of the module
        names_and_streams: the names and streams of schemas to be applied
    """
    cur.execute(
        "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
@@ -549,7 +555,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
    )


def get_statements(f):
def get_statements(f: Iterable[str]) -> Generator[str, None, None]:
    statement_buffer = ""
    in_comment = False  # If we're in a /* ... */ style comment

@@ -594,17 +600,19 @@ def get_statements(f):
    statement_buffer = statements[-1].strip()


def executescript(txn, schema_path):
def executescript(txn: Cursor, schema_path: str) -> None:
    with open(schema_path, "r") as f:
        execute_statements_from_stream(txn, f)


def execute_statements_from_stream(cur: Cursor, f: TextIO):
def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None:
    for statement in get_statements(f):
        cur.execute(statement)


def _get_or_create_schema_state(txn, database_engine):
def _get_or_create_schema_state(
    txn: Cursor, database_engine: BaseDatabaseEngine
) -> Optional[Tuple[int, List[str], bool]]:
    # Bluntly try creating the schema_version tables.
    schema_path = os.path.join(dir_path, "schema", "schema_version.sql")
    executescript(txn, schema_path)
@@ -612,7 +620,6 @@ def _get_or_create_schema_state(txn, database_engine):
    txn.execute("SELECT version, upgraded FROM schema_version")
    row = txn.fetchone()
    current_version = int(row[0]) if row else None
    upgraded = bool(row[1]) if row else None

    if current_version:
        txn.execute(
@@ -620,6 +627,7 @@ def _get_or_create_schema_state(txn, database_engine):
            (current_version,),
        )
        applied_deltas = [d for d, in txn]
        upgraded = bool(row[1])
        return current_version, applied_deltas, upgraded

    return None
@@ -634,5 +642,5 @@ class _DirectoryListing:
    `file_name` attr is kept first.
    """

    file_name = attr.ib()
    absolute_path = attr.ib()
    file_name = attr.ib(type=str)
    absolute_path = attr.ib(type=str)
@@ -15,7 +15,12 @@

import itertools
import logging
from typing import Set
from typing import TYPE_CHECKING, Set

from synapse.storage.databases import Databases

if TYPE_CHECKING:
    from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)

@@ -24,10 +29,10 @@ class PurgeEventsStorage:
    """High level interface for purging rooms and event history.
    """

    def __init__(self, hs, stores):
    def __init__(self, hs: "HomeServer", stores: Databases):
        self.stores = stores

    async def purge_room(self, room_id: str):
    async def purge_room(self, room_id: str) -> None:
        """Deletes all record of a room
        """
@@ -14,10 +14,12 @@
# limitations under the License.

import logging
from typing import Any, Dict, List, Optional, Tuple

import attr

from synapse.api.errors import SynapseError
from synapse.types import JsonDict

logger = logging.getLogger(__name__)

@@ -27,18 +29,18 @@ class PaginationChunk:
    """Returned by relation pagination APIs.

    Attributes:
        chunk (list): The rows returned by pagination
        next_batch (Any|None): Token to fetch next set of results with, if
        chunk: The rows returned by pagination
        next_batch: Token to fetch next set of results with, if
            None then there are no more results.
        prev_batch (Any|None): Token to fetch previous set of results with, if
        prev_batch: Token to fetch previous set of results with, if
            None then there are no previous results.
    """

    chunk = attr.ib()
    next_batch = attr.ib(default=None)
    prev_batch = attr.ib(default=None)
    chunk = attr.ib(type=List[JsonDict])
    next_batch = attr.ib(type=Optional[Any], default=None)
    prev_batch = attr.ib(type=Optional[Any], default=None)

    def to_dict(self):
    def to_dict(self) -> Dict[str, Any]:
        d = {"chunk": self.chunk}

        if self.next_batch:
@@ -59,25 +61,25 @@ class RelationPaginationToken:
    boundaries of the chunk as pagination tokens.

    Attributes:
        topological (int): The topological ordering of the boundary event
        stream (int): The stream ordering of the boundary event.
        topological: The topological ordering of the boundary event
        stream: The stream ordering of the boundary event.
    """

    topological = attr.ib()
    stream = attr.ib()
    topological = attr.ib(type=int)
    stream = attr.ib(type=int)

    @staticmethod
    def from_string(string):
    def from_string(string: str) -> "RelationPaginationToken":
        try:
            t, s = string.split("-")
            return RelationPaginationToken(int(t), int(s))
        except ValueError:
            raise SynapseError(400, "Invalid token")

    def to_string(self):
    def to_string(self) -> str:
        return "%d-%d" % (self.topological, self.stream)

    def as_tuple(self):
    def as_tuple(self) -> Tuple[Any, ...]:
        return attr.astuple(self)
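A quick usage example of the token format whose contract the new annotations make explicit (this assumes the `RelationPaginationToken` class from the hunk above): tokens serialise as two dash-separated integers and round-trip losslessly.

```python
token = RelationPaginationToken.from_string("12-34")
assert (token.topological, token.stream) == (12, 34)
assert token.to_string() == "12-34"
assert token.as_tuple() == (12, 34)
```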
@@ -89,23 +91,23 @@ class AggregationPaginationToken:
    aggregation groups, we can just use them as our pagination token.

    Attributes:
        count (int): The count of relations in the boundar group.
        stream (int): The MAX stream ordering in the boundary group.
        count: The count of relations in the boundary group.
        stream: The MAX stream ordering in the boundary group.
    """

    count = attr.ib()
    stream = attr.ib()
    count = attr.ib(type=int)
    stream = attr.ib(type=int)

    @staticmethod
    def from_string(string):
    def from_string(string: str) -> "AggregationPaginationToken":
        try:
            c, s = string.split("-")
            return AggregationPaginationToken(int(c), int(s))
        except ValueError:
            raise SynapseError(400, "Invalid token")

    def to_string(self):
    def to_string(self) -> str:
        return "%d-%d" % (self.count, self.stream)

    def as_tuple(self):
    def as_tuple(self) -> Tuple[Any, ...]:
        return attr.astuple(self)
@@ -12,9 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
from typing import (
    TYPE_CHECKING,
    Awaitable,
    Dict,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    TypeVar,
)

import attr

@@ -22,6 +31,10 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap

if TYPE_CHECKING:
    from synapse.app.homeserver import HomeServer
    from synapse.storage.databases import Databases

logger = logging.getLogger(__name__)

# Used for generic functions below
@@ -330,10 +343,12 @@ class StateGroupStorage:
    """High level interface to fetching state for event.
    """

    def __init__(self, hs, stores):
    def __init__(self, hs: "HomeServer", stores: "Databases"):
        self.stores = stores

    async def get_state_group_delta(self, state_group: int):
    async def get_state_group_delta(
        self, state_group: int
    ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
        """Given a state group try to return a previous group and a delta between
        the old and the new.

@@ -341,8 +356,8 @@ class StateGroupStorage:
            state_group: The state group used to retrieve state deltas.

        Returns:
            Tuple[Optional[int], Optional[StateMap[str]]]:
                (prev_group, delta_ids)
            A tuple of the previous group and a state map of the event IDs which
            make up the delta between the old and new state groups.
        """

        return await self.stores.state.get_state_group_delta(state_group)
@@ -436,7 +451,7 @@ class StateGroupStorage:

    async def get_state_for_events(
        self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
    ):
    ) -> Dict[str, StateMap[EventBase]]:
        """Given a list of event_ids and type tuples, return a list of state
        dicts for each event.

@@ -472,7 +487,7 @@ class StateGroupStorage:

    async def get_state_ids_for_events(
        self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
    ):
    ) -> Dict[str, StateMap[str]]:
        """
        Get the state dicts corresponding to a list of events, containing the event_ids
        of the state events (as opposed to the events themselves)
@@ -500,7 +515,7 @@ class StateGroupStorage:

    async def get_state_for_event(
        self, event_id: str, state_filter: StateFilter = StateFilter.all()
    ):
    ) -> StateMap[EventBase]:
        """
        Get the state dict corresponding to a particular event

@@ -516,7 +531,7 @@ class StateGroupStorage:

    async def get_state_ids_for_event(
        self, event_id: str, state_filter: StateFilter = StateFilter.all()
    ):
    ) -> StateMap[str]:
        """
        Get the state dict corresponding to a particular event
@@ -75,7 +75,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
        return val

    def test_verify_json_objects_for_server_awaits_previous_requests(self):
        mock_fetcher = keyring.KeyFetcher()
        mock_fetcher = Mock()
        mock_fetcher.get_keys = Mock()
        kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))

@@ -195,7 +195,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
        """Tests that we correctly handle key requests for keys we've stored
        with a null `ts_valid_until_ms`
        """
        mock_fetcher = keyring.KeyFetcher()
        mock_fetcher = Mock()
        mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))

        kr = keyring.Keyring(
@@ -249,7 +249,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
            }
        }

        mock_fetcher = keyring.KeyFetcher()
        mock_fetcher = Mock()
        mock_fetcher.get_keys = Mock(side_effect=get_keys)
        kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))

@@ -288,9 +288,9 @@ class KeyringTestCase(unittest.HomeserverTestCase):
            }
        }

        mock_fetcher1 = keyring.KeyFetcher()
        mock_fetcher1 = Mock()
        mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
        mock_fetcher2 = keyring.KeyFetcher()
        mock_fetcher2 = Mock()
        mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
        kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
0	tests/federation/transport/__init__.py	Normal file
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,34 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from twisted.internet import defer

from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.federation.transport import server
from synapse.util.ratelimitutils import FederationRateLimiter

from tests import unittest
from tests.unittest import override_config


class RoomDirectoryFederationTests(unittest.HomeserverTestCase):
    def prepare(self, reactor, clock, homeserver):
        class Authenticator:
            def authenticate_request(self, request, content):
                return defer.succeed("otherserver.nottld")

        ratelimiter = FederationRateLimiter(clock, FederationRateLimitConfig())
        server.register_servlets(
            homeserver, self.resource, Authenticator(), ratelimiter
        )

class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
    @override_config({"allow_public_rooms_over_federation": False})
    def test_blocked_public_room_list_over_federation(self):
        channel = self.make_request("GET", "/_matrix/federation/v1/publicRooms")
        """Test that unauthenticated requests to the public rooms directory 403 when
        allow_public_rooms_over_federation is False.
        """
        channel = self.make_request(
            "GET",
            "/_matrix/federation/v1/publicRooms",
            federation_auth_origin=b"example.com",
        )
        self.assertEquals(403, channel.code)

    @override_config({"allow_public_rooms_over_federation": True})
    def test_open_public_room_list_over_federation(self):
        channel = self.make_request("GET", "/_matrix/federation/v1/publicRooms")
        """Test that unauthenticated requests to the public rooms directory 200 when
        allow_public_rooms_over_federation is True.
        """
        channel = self.make_request(
            "GET",
            "/_matrix/federation/v1/publicRooms",
            federation_auth_origin=b"example.com",
        )
        self.assertEquals(200, channel.code)
121	tests/handlers/test_cas.py	Normal file
@@ -0,0 +1,121 @@
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock

from synapse.handlers.cas_handler import CasResponse

from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase

# These are a few constants that are used as config parameters in the tests.
BASE_URL = "https://synapse/"
SERVER_URL = "https://issuer/"


class CasHandlerTestCase(HomeserverTestCase):
    def default_config(self):
        config = super().default_config()
        config["public_baseurl"] = BASE_URL
        cas_config = {
            "enabled": True,
            "server_url": SERVER_URL,
            "service_url": BASE_URL,
        }
        config["cas_config"] = cas_config

        return config

    def make_homeserver(self, reactor, clock):
        hs = self.setup_test_homeserver()

        self.handler = hs.get_cas_handler()

        # Reduce the number of attempts when generating MXIDs.
        sso_handler = hs.get_sso_handler()
        sso_handler._MAP_USERNAME_RETRIES = 3

        return hs
    def test_map_cas_user_to_user(self):
        """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""

        # stub out the auth handler
        auth_handler = self.hs.get_auth_handler()
        auth_handler.complete_sso_login = simple_async_mock()

        cas_response = CasResponse("test_user", {})
        request = _mock_request()
        self.get_success(
            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
        )

        # check that the auth handler got called as expected
        auth_handler.complete_sso_login.assert_called_once_with(
            "@test_user:test", request, "redirect_uri", None
        )

    def test_map_cas_user_to_existing_user(self):
        """Existing users can log in with CAS account."""
        store = self.hs.get_datastore()
        self.get_success(
            store.register_user(user_id="@test_user:test", password_hash=None)
        )

        # stub out the auth handler
        auth_handler = self.hs.get_auth_handler()
        auth_handler.complete_sso_login = simple_async_mock()

        # Map a user via SSO.
        cas_response = CasResponse("test_user", {})
        request = _mock_request()
        self.get_success(
            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
        )

        # check that the auth handler got called as expected
        auth_handler.complete_sso_login.assert_called_once_with(
            "@test_user:test", request, "redirect_uri", None
        )

        # Subsequent calls should map to the same mxid.
        auth_handler.complete_sso_login.reset_mock()
        self.get_success(
            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
        )
        auth_handler.complete_sso_login.assert_called_once_with(
            "@test_user:test", request, "redirect_uri", None
        )

    def test_map_cas_user_to_invalid_localpart(self):
        """CAS automaps invalid characters to base-64 encoding."""

        # stub out the auth handler
        auth_handler = self.hs.get_auth_handler()
        auth_handler.complete_sso_login = simple_async_mock()

        cas_response = CasResponse("föö", {})
        request = _mock_request()
        self.get_success(
            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
        )

        # check that the auth handler got called as expected
        auth_handler.complete_sso_login.assert_called_once_with(
            "@f=c3=b6=c3=b6:test", request, "redirect_uri", None
        )


def _mock_request():
    """Returns a mock which will stand in as a SynapseRequest"""
    return Mock(spec=["getClientIP", "get_user_agent"])
Block a user