Merge commit '22db45bd4' into anoa/dinsic_release_1_31_0
This commit is contained in:
@@ -1,9 +1,18 @@
|
||||
Removal warning
|
||||
---------------
|
||||
|
||||
Note that this release deprecates the ability for appservices to call `POST /_matrix/client/r0/register` without the body parameter `type`. Appservice developers should use a `type` value of `m.login.application_service` as per the spec. In future releases, calling this endpoint with an access token but
|
||||
without a valid type will fail.
|
||||
|
||||
Synapse 1.29.0 (2021-03-08)
|
||||
===========================
|
||||
|
||||
Note that synapse now expects an `X-Forwarded-Proto` header when used with a reverse proxy. Please see [UPGRADE.rst](UPGRADE.rst#upgrading-to-v1290) for more details on this change.
|
||||
|
||||
|
||||
No significant changes.
|
||||
|
||||
|
||||
Synapse 1.29.0rc1 (2021-03-04)
|
||||
==============================
|
||||
|
||||
|
||||
+3
-2
@@ -183,8 +183,9 @@ Using a reverse proxy with Synapse
|
||||
It is recommended to put a reverse proxy such as
|
||||
`nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
|
||||
`Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_,
|
||||
`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_ or
|
||||
`HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of
|
||||
`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_,
|
||||
`HAProxy <https://www.haproxy.org/>`_ or
|
||||
`relayd <https://man.openbsd.org/relayd.8>`_ in front of Synapse. One advantage of
|
||||
doing so is that it means that you can expose the default https port (443) to
|
||||
Matrix clients without needing to run Synapse with root privileges.
|
||||
|
||||
|
||||
@@ -124,6 +124,13 @@ This version changes the URI used for callbacks from OAuth2 and SAML2 identity p
|
||||
need to add ``[synapse public baseurl]/_synapse/client/saml2/authn_response`` as a permitted
|
||||
"ACS location" (also known as "allowed callback URLs") at the identity provider.
|
||||
|
||||
The "Issuer" in the "AuthnRequest" to the SAML2 identity provider is also updated to
|
||||
``[synapse public baseurl]/_synapse/client/saml2/metadata.xml``. If your SAML2 identity
|
||||
provider uses this property to validate or otherwise identify Synapse, its configuration
|
||||
will need to be updated to use the new URL. Alternatively you could create a new, separate
|
||||
"EntityDescriptor" in your SAML2 identity provider with the new URLs and leave the URLs in
|
||||
the existing "EntityDescriptor" as they were.
|
||||
|
||||
Changes to HTML templates
|
||||
-------------------------
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
Add tests to ResponseCache.
|
||||
@@ -0,0 +1 @@
|
||||
Add relayd entry to reverse proxy example configurations.
|
||||
@@ -0,0 +1 @@
|
||||
Add prometheus metrics for number of users successfully registering and logging in.
|
||||
@@ -0,0 +1 @@
|
||||
Add prometheus metrics for number of users successfully registering and logging in.
|
||||
@@ -0,0 +1 @@
|
||||
Add type hints to purge room and server notice admin API.
|
||||
@@ -0,0 +1 @@
|
||||
Add extra logging to ObservableDeferred when callbacks throw exceptions.
|
||||
@@ -0,0 +1 @@
|
||||
Fix incorrect type hints.
|
||||
@@ -0,0 +1 @@
|
||||
Add `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time` prometheus metrics, which monitor federation delays by reporting the timestamps of messages sent and received to a set of remote servers.
|
||||
@@ -0,0 +1 @@
|
||||
The `synapse_federation_last_sent_pdu_age` and `synapse_federation_last_received_pdu_age` prometheus metrics have been removed. They are replaced by `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time`.
|
||||
@@ -0,0 +1 @@
|
||||
Add an additional test for purging a room.
|
||||
@@ -0,0 +1 @@
|
||||
Improve the SAML2 upgrade notes for 1.27.0.
|
||||
@@ -0,0 +1 @@
|
||||
Registering an Application Service user without using the `m.login.application_service` login type will be unsupported in an upcoming Synapse release.
|
||||
@@ -0,0 +1 @@
|
||||
Fix spurious errors reported by the `config-lint.sh` script.
|
||||
+49
-2
@@ -3,8 +3,9 @@
|
||||
It is recommended to put a reverse proxy such as
|
||||
[nginx](https://nginx.org/en/docs/http/ngx_http_proxy_module.html),
|
||||
[Apache](https://httpd.apache.org/docs/current/mod/mod_proxy_http.html),
|
||||
[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy) or
|
||||
[HAProxy](https://www.haproxy.org/) in front of Synapse. One advantage
|
||||
[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy),
|
||||
[HAProxy](https://www.haproxy.org/) or
|
||||
[relayd](https://man.openbsd.org/relayd.8) in front of Synapse. One advantage
|
||||
of doing so is that it means that you can expose the default https port
|
||||
(443) to Matrix clients without needing to run Synapse with root
|
||||
privileges.
|
||||
@@ -162,6 +163,52 @@ backend matrix
|
||||
server matrix 127.0.0.1:8008
|
||||
```
|
||||
|
||||
### Relayd
|
||||
|
||||
```
|
||||
table <webserver> { 127.0.0.1 }
|
||||
table <matrixserver> { 127.0.0.1 }
|
||||
|
||||
http protocol "https" {
|
||||
tls { no tlsv1.0, ciphers "HIGH" }
|
||||
tls keypair "example.com"
|
||||
match header set "X-Forwarded-For" value "$REMOTE_ADDR"
|
||||
match header set "X-Forwarded-Proto" value "https"
|
||||
|
||||
# set CORS header for .well-known/matrix/server, .well-known/matrix/client
|
||||
# httpd does not support setting headers, so do it here
|
||||
match request path "/.well-known/matrix/*" tag "matrix-cors"
|
||||
match response tagged "matrix-cors" header set "Access-Control-Allow-Origin" value "*"
|
||||
|
||||
pass quick path "/_matrix/*" forward to <matrixserver>
|
||||
pass quick path "/_synapse/client/*" forward to <matrixserver>
|
||||
|
||||
# pass on non-matrix traffic to webserver
|
||||
pass forward to <webserver>
|
||||
}
|
||||
|
||||
relay "https_traffic" {
|
||||
listen on egress port 443 tls
|
||||
protocol "https"
|
||||
forward to <matrixserver> port 8008 check tcp
|
||||
forward to <webserver> port 8080 check tcp
|
||||
}
|
||||
|
||||
http protocol "matrix" {
|
||||
tls { no tlsv1.0, ciphers "HIGH" }
|
||||
tls keypair "example.com"
|
||||
block
|
||||
pass quick path "/_matrix/*" forward to <matrixserver>
|
||||
pass quick path "/_synapse/client/*" forward to <matrixserver>
|
||||
}
|
||||
|
||||
relay "matrix_federation" {
|
||||
listen on egress port 8448 tls
|
||||
protocol "matrix"
|
||||
forward to <matrixserver> port 8008 check tcp
|
||||
}
|
||||
```
|
||||
|
||||
## Homeserver Configuration
|
||||
|
||||
You will also want to set `bind_addresses: ['127.0.0.1']` and
|
||||
|
||||
@@ -69,6 +69,7 @@ files =
|
||||
synapse/util/async_helpers.py,
|
||||
synapse/util/caches,
|
||||
synapse/util/metrics.py,
|
||||
synapse/util/macaroons.py,
|
||||
synapse/util/stringutils.py,
|
||||
tests/replication,
|
||||
tests/test_utils,
|
||||
|
||||
@@ -2,9 +2,14 @@
|
||||
# Find linting errors in Synapse's default config file.
|
||||
# Exits with 0 if there are no problems, or another code otherwise.
|
||||
|
||||
# cd to the root of the repository
|
||||
cd `dirname $0`/..
|
||||
|
||||
# Restore backup of sample config upon script exit
|
||||
trap "mv docs/sample_config.yaml.bak docs/sample_config.yaml" EXIT
|
||||
|
||||
# Fix non-lowercase true/false values
|
||||
sed -i.bak -E "s/: +True/: true/g; s/: +False/: false/g;" docs/sample_config.yaml
|
||||
rm docs/sample_config.yaml.bak
|
||||
|
||||
# Check if anything changed
|
||||
git diff --exit-code docs/sample_config.yaml
|
||||
diff docs/sample_config.yaml docs/sample_config.yaml.bak
|
||||
|
||||
+9
-32
@@ -39,6 +39,7 @@ from synapse.logging import opentracing as opentracing
|
||||
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||
from synapse.types import StateMap, UserID
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -413,7 +414,7 @@ class Auth:
|
||||
raise _InvalidMacaroonException()
|
||||
|
||||
try:
|
||||
user_id = self.get_user_id_from_macaroon(macaroon)
|
||||
user_id = get_value_from_macaroon(macaroon, "user_id")
|
||||
|
||||
guest = False
|
||||
for caveat in macaroon.caveats:
|
||||
@@ -421,7 +422,12 @@ class Auth:
|
||||
guest = True
|
||||
|
||||
self.validate_macaroon(macaroon, rights, user_id=user_id)
|
||||
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
|
||||
except (
|
||||
pymacaroons.exceptions.MacaroonException,
|
||||
KeyError,
|
||||
TypeError,
|
||||
ValueError,
|
||||
):
|
||||
raise InvalidClientTokenError("Invalid macaroon passed.")
|
||||
|
||||
if rights == "access":
|
||||
@@ -429,27 +435,6 @@ class Auth:
|
||||
|
||||
return user_id, guest
|
||||
|
||||
def get_user_id_from_macaroon(self, macaroon):
|
||||
"""Retrieve the user_id given by the caveats on the macaroon.
|
||||
|
||||
Does *not* validate the macaroon.
|
||||
|
||||
Args:
|
||||
macaroon (pymacaroons.Macaroon): The macaroon to validate
|
||||
|
||||
Returns:
|
||||
(str) user id
|
||||
|
||||
Raises:
|
||||
InvalidClientCredentialsError if there is no user_id caveat in the
|
||||
macaroon
|
||||
"""
|
||||
user_prefix = "user_id = "
|
||||
for caveat in macaroon.caveats:
|
||||
if caveat.caveat_id.startswith(user_prefix):
|
||||
return caveat.caveat_id[len(user_prefix) :]
|
||||
raise InvalidClientTokenError("No user caveat in macaroon")
|
||||
|
||||
def validate_macaroon(self, macaroon, type_string, user_id):
|
||||
"""
|
||||
validate that a Macaroon is understood by and was signed by this server.
|
||||
@@ -470,21 +455,13 @@ class Auth:
|
||||
v.satisfy_exact("type = " + type_string)
|
||||
v.satisfy_exact("user_id = %s" % user_id)
|
||||
v.satisfy_exact("guest = true")
|
||||
v.satisfy_general(self._verify_expiry)
|
||||
satisfy_expiry(v, self.clock.time_msec)
|
||||
|
||||
# access_tokens include a nonce for uniqueness: any value is acceptable
|
||||
v.satisfy_general(lambda c: c.startswith("nonce = "))
|
||||
|
||||
v.verify(macaroon, self._macaroon_secret_key)
|
||||
|
||||
def _verify_expiry(self, caveat):
|
||||
prefix = "time < "
|
||||
if not caveat.startswith(prefix):
|
||||
return False
|
||||
expiry = int(caveat[len(prefix) :])
|
||||
now = self.hs.get_clock().time_msec()
|
||||
return now < expiry
|
||||
|
||||
def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
|
||||
token = self.get_access_token_from_request(request)
|
||||
service = self.store.get_app_service_by_token(token)
|
||||
|
||||
@@ -90,7 +90,7 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
self.protocol_meta_cache = ResponseCache(
|
||||
hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
|
||||
hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
|
||||
) # type: ResponseCache[Tuple[str, str]]
|
||||
|
||||
async def query_user(self, service, user_id):
|
||||
|
||||
@@ -847,8 +847,7 @@ class ServerConfig(Config):
|
||||
# Whether to require authentication to retrieve profile data (avatars,
|
||||
# display names) of other users through the client API. Defaults to
|
||||
# 'false'. Note that profile data is also available via the federation
|
||||
# API, so this setting is of limited value if federation is enabled on
|
||||
# the server.
|
||||
# API, unless allow_profile_lookup_over_federation is set to false.
|
||||
#
|
||||
#require_auth_for_profile_requests: true
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from typing import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
@@ -91,16 +92,15 @@ pdu_process_time = Histogram(
|
||||
"Time taken to process an event",
|
||||
)
|
||||
|
||||
|
||||
last_pdu_age_metric = Gauge(
|
||||
"synapse_federation_last_received_pdu_age",
|
||||
"The age (in seconds) of the last PDU successfully received from the given domain",
|
||||
last_pdu_ts_metric = Gauge(
|
||||
"synapse_federation_last_received_pdu_time",
|
||||
"The timestamp of the last PDU which was successfully received from the given domain",
|
||||
labelnames=("server_name",),
|
||||
)
|
||||
|
||||
|
||||
class FederationServer(FederationBase):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
@@ -120,7 +120,7 @@ class FederationServer(FederationBase):
|
||||
|
||||
# We cache results for transaction with the same ID
|
||||
self._transaction_resp_cache = ResponseCache(
|
||||
hs, "fed_txn_handler", timeout_ms=30000
|
||||
hs.get_clock(), "fed_txn_handler", timeout_ms=30000
|
||||
) # type: ResponseCache[Tuple[str, str]]
|
||||
|
||||
self.transaction_actions = TransactionActions(self.store)
|
||||
@@ -130,10 +130,10 @@ class FederationServer(FederationBase):
|
||||
# We cache responses to state queries, as they take a while and often
|
||||
# come in waves.
|
||||
self._state_resp_cache = ResponseCache(
|
||||
hs, "state_resp", timeout_ms=30000
|
||||
hs.get_clock(), "state_resp", timeout_ms=30000
|
||||
) # type: ResponseCache[Tuple[str, str]]
|
||||
self._state_ids_resp_cache = ResponseCache(
|
||||
hs, "state_ids_resp", timeout_ms=30000
|
||||
hs.get_clock(), "state_ids_resp", timeout_ms=30000
|
||||
) # type: ResponseCache[Tuple[str, str]]
|
||||
|
||||
self._federation_metrics_domains = (
|
||||
@@ -370,8 +370,7 @@ class FederationServer(FederationBase):
|
||||
)
|
||||
|
||||
if newest_pdu_ts and origin in self._federation_metrics_domains:
|
||||
newest_pdu_age = self._clock.time_msec() - newest_pdu_ts
|
||||
last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000)
|
||||
last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000)
|
||||
|
||||
return pdu_results
|
||||
|
||||
@@ -456,7 +455,9 @@ class FederationServer(FederationBase):
|
||||
self, room_id: str, event_id: str
|
||||
) -> Dict[str, list]:
|
||||
if event_id:
|
||||
pdus = await self.handler.get_state_for_pdu(room_id, event_id)
|
||||
pdus = await self.handler.get_state_for_pdu(
|
||||
room_id, event_id
|
||||
) # type: Iterable[EventBase]
|
||||
else:
|
||||
pdus = (await self.state.get_current_state(room_id)).values()
|
||||
|
||||
|
||||
@@ -36,9 +36,9 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
last_pdu_age_metric = Gauge(
|
||||
"synapse_federation_last_sent_pdu_age",
|
||||
"The age (in seconds) of the last PDU successfully sent to the given domain",
|
||||
last_pdu_ts_metric = Gauge(
|
||||
"synapse_federation_last_sent_pdu_time",
|
||||
"The timestamp of the last PDU which was successfully sent to the given domain",
|
||||
labelnames=("server_name",),
|
||||
)
|
||||
|
||||
@@ -187,9 +187,8 @@ class TransactionManager:
|
||||
|
||||
if success and pdus and destination in self._federation_metrics_domains:
|
||||
last_pdu = pdus[-1]
|
||||
last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts
|
||||
last_pdu_age_metric.labels(server_name=destination).set(
|
||||
last_pdu_age / 1000
|
||||
last_pdu_ts_metric.labels(server_name=destination).set(
|
||||
last_pdu.origin_server_ts / 1000
|
||||
)
|
||||
|
||||
set_tag(tags.ERROR, not success)
|
||||
|
||||
@@ -73,7 +73,9 @@ class AcmeHandler:
|
||||
"Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
|
||||
)
|
||||
try:
|
||||
self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host)
|
||||
self.reactor.listenTCP(
|
||||
self.hs.config.acme_port, srv, backlog=50, interface=host
|
||||
)
|
||||
except twisted.internet.error.CannotListenError as e:
|
||||
check_bind_error(e, host, bind_addresses)
|
||||
|
||||
|
||||
+58
-10
@@ -65,6 +65,7 @@ from synapse.storage.roommember import ProfileInfo
|
||||
from synapse.types import JsonDict, Requester, UserID
|
||||
from synapse.util import stringutils as stringutils
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.threepids import canonicalise_email
|
||||
|
||||
@@ -170,6 +171,16 @@ class SsoLoginExtraAttributes:
|
||||
extra_attributes = attr.ib(type=JsonDict)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class LoginTokenAttributes:
|
||||
"""Data we store in a short-term login token"""
|
||||
|
||||
user_id = attr.ib(type=str)
|
||||
|
||||
# the SSO Identity Provider that the user authenticated with, to get this token
|
||||
auth_provider_id = attr.ib(type=str)
|
||||
|
||||
|
||||
class AuthHandler(BaseHandler):
|
||||
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
||||
|
||||
@@ -1164,18 +1175,16 @@ class AuthHandler(BaseHandler):
|
||||
return None
|
||||
return user_id
|
||||
|
||||
async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
|
||||
auth_api = self.hs.get_auth()
|
||||
user_id = None
|
||||
async def validate_short_term_login_token(
|
||||
self, login_token: str
|
||||
) -> LoginTokenAttributes:
|
||||
try:
|
||||
macaroon = pymacaroons.Macaroon.deserialize(login_token)
|
||||
user_id = auth_api.get_user_id_from_macaroon(macaroon)
|
||||
auth_api.validate_macaroon(macaroon, "login", user_id)
|
||||
res = self.macaroon_gen.verify_short_term_login_token(login_token)
|
||||
except Exception:
|
||||
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
|
||||
|
||||
await self.auth.check_auth_blocking(user_id)
|
||||
return user_id
|
||||
await self.auth.check_auth_blocking(res.user_id)
|
||||
return res
|
||||
|
||||
async def delete_access_token(self, access_token: str):
|
||||
"""Invalidate a single access token
|
||||
@@ -1397,6 +1406,7 @@ class AuthHandler(BaseHandler):
|
||||
async def complete_sso_login(
|
||||
self,
|
||||
registered_user_id: str,
|
||||
auth_provider_id: str,
|
||||
request: Request,
|
||||
client_redirect_url: str,
|
||||
extra_attributes: Optional[JsonDict] = None,
|
||||
@@ -1406,6 +1416,9 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
Args:
|
||||
registered_user_id: The registered user ID to complete SSO login for.
|
||||
auth_provider_id: The id of the SSO Identity provider that was used for
|
||||
login. This will be stored in the login token for future tracking in
|
||||
prometheus metrics.
|
||||
request: The request to complete.
|
||||
client_redirect_url: The URL to which to redirect the user at the end of the
|
||||
process.
|
||||
@@ -1427,6 +1440,7 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
self._complete_sso_login(
|
||||
registered_user_id,
|
||||
auth_provider_id,
|
||||
request,
|
||||
client_redirect_url,
|
||||
extra_attributes,
|
||||
@@ -1437,6 +1451,7 @@ class AuthHandler(BaseHandler):
|
||||
def _complete_sso_login(
|
||||
self,
|
||||
registered_user_id: str,
|
||||
auth_provider_id: str,
|
||||
request: Request,
|
||||
client_redirect_url: str,
|
||||
extra_attributes: Optional[JsonDict] = None,
|
||||
@@ -1463,7 +1478,7 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
# Create a login token
|
||||
login_token = self.macaroon_gen.generate_short_term_login_token(
|
||||
registered_user_id
|
||||
registered_user_id, auth_provider_id=auth_provider_id
|
||||
)
|
||||
|
||||
# Append the login token to the original redirect URL (i.e. with its query
|
||||
@@ -1569,15 +1584,48 @@ class MacaroonGenerator:
|
||||
return macaroon.serialize()
|
||||
|
||||
def generate_short_term_login_token(
|
||||
self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
|
||||
self,
|
||||
user_id: str,
|
||||
auth_provider_id: str,
|
||||
duration_in_ms: int = (2 * 60 * 1000),
|
||||
) -> str:
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = login")
|
||||
now = self.hs.get_clock().time_msec()
|
||||
expiry = now + duration_in_ms
|
||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
|
||||
return macaroon.serialize()
|
||||
|
||||
def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
|
||||
"""Verify a short-term-login macaroon
|
||||
|
||||
Checks that the given token is a valid, unexpired short-term-login token
|
||||
minted by this server.
|
||||
|
||||
Args:
|
||||
token: the login token to verify
|
||||
|
||||
Returns:
|
||||
the user_id that this token is valid for
|
||||
|
||||
Raises:
|
||||
MacaroonVerificationFailedException if the verification failed
|
||||
"""
|
||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||
user_id = get_value_from_macaroon(macaroon, "user_id")
|
||||
auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
|
||||
|
||||
v = pymacaroons.Verifier()
|
||||
v.satisfy_exact("gen = 1")
|
||||
v.satisfy_exact("type = login")
|
||||
v.satisfy_general(lambda c: c.startswith("user_id = "))
|
||||
v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
|
||||
satisfy_expiry(v, self.hs.get_clock().time_msec)
|
||||
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
|
||||
|
||||
return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
|
||||
|
||||
def generate_delete_pusher_token(self, user_id: str) -> str:
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = delete_pusher")
|
||||
|
||||
@@ -48,7 +48,7 @@ class InitialSyncHandler(BaseHandler):
|
||||
self.clock = hs.get_clock()
|
||||
self.validator = EventValidator()
|
||||
self.snapshot_cache = ResponseCache(
|
||||
hs, "initial_sync_cache"
|
||||
hs.get_clock(), "initial_sync_cache"
|
||||
) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
self.storage = hs.get_storage()
|
||||
|
||||
@@ -42,6 +42,7 @@ from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
||||
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -211,7 +212,7 @@ class OidcHandler:
|
||||
session_data = self._token_generator.verify_oidc_session_token(
|
||||
session, state
|
||||
)
|
||||
except (MacaroonDeserializationException, ValueError) as e:
|
||||
except (MacaroonDeserializationException, KeyError) as e:
|
||||
logger.exception("Invalid session for OIDC callback")
|
||||
self._sso_handler.render_error(request, "invalid_session", str(e))
|
||||
return
|
||||
@@ -745,7 +746,7 @@ class OidcProvider:
|
||||
idp_id=self.idp_id,
|
||||
nonce=nonce,
|
||||
client_redirect_url=client_redirect_url.decode(),
|
||||
ui_auth_session_id=ui_auth_session_id,
|
||||
ui_auth_session_id=ui_auth_session_id or "",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1020,10 +1021,9 @@ class OidcSessionTokenGenerator:
|
||||
macaroon.add_first_party_caveat(
|
||||
"client_redirect_url = %s" % (session_data.client_redirect_url,)
|
||||
)
|
||||
if session_data.ui_auth_session_id:
|
||||
macaroon.add_first_party_caveat(
|
||||
"ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
|
||||
)
|
||||
macaroon.add_first_party_caveat(
|
||||
"ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
|
||||
)
|
||||
now = self._clock.time_msec()
|
||||
expiry = now + duration_in_ms
|
||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
@@ -1046,7 +1046,7 @@ class OidcSessionTokenGenerator:
|
||||
The data extracted from the session cookie
|
||||
|
||||
Raises:
|
||||
ValueError if an expected caveat is missing from the macaroon.
|
||||
KeyError if an expected caveat is missing from the macaroon.
|
||||
"""
|
||||
macaroon = pymacaroons.Macaroon.deserialize(session)
|
||||
|
||||
@@ -1057,26 +1057,16 @@ class OidcSessionTokenGenerator:
|
||||
v.satisfy_general(lambda c: c.startswith("nonce = "))
|
||||
v.satisfy_general(lambda c: c.startswith("idp_id = "))
|
||||
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
|
||||
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
|
||||
# to always satisfy this.
|
||||
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
|
||||
v.satisfy_general(self._verify_expiry)
|
||||
satisfy_expiry(v, self._clock.time_msec)
|
||||
|
||||
v.verify(macaroon, self._macaroon_secret_key)
|
||||
|
||||
# Extract the session data from the token.
|
||||
nonce = self._get_value_from_macaroon(macaroon, "nonce")
|
||||
idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
|
||||
client_redirect_url = self._get_value_from_macaroon(
|
||||
macaroon, "client_redirect_url"
|
||||
)
|
||||
try:
|
||||
ui_auth_session_id = self._get_value_from_macaroon(
|
||||
macaroon, "ui_auth_session_id"
|
||||
) # type: Optional[str]
|
||||
except ValueError:
|
||||
ui_auth_session_id = None
|
||||
|
||||
nonce = get_value_from_macaroon(macaroon, "nonce")
|
||||
idp_id = get_value_from_macaroon(macaroon, "idp_id")
|
||||
client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
|
||||
ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
|
||||
return OidcSessionData(
|
||||
nonce=nonce,
|
||||
idp_id=idp_id,
|
||||
@@ -1084,33 +1074,6 @@ class OidcSessionTokenGenerator:
|
||||
ui_auth_session_id=ui_auth_session_id,
|
||||
)
|
||||
|
||||
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
|
||||
"""Extracts a caveat value from a macaroon token.
|
||||
|
||||
Args:
|
||||
macaroon: the token
|
||||
key: the key of the caveat to extract
|
||||
|
||||
Returns:
|
||||
The extracted value
|
||||
|
||||
Raises:
|
||||
ValueError: if the caveat was not in the macaroon
|
||||
"""
|
||||
prefix = key + " = "
|
||||
for caveat in macaroon.caveats:
|
||||
if caveat.caveat_id.startswith(prefix):
|
||||
return caveat.caveat_id[len(prefix) :]
|
||||
raise ValueError("No %s caveat in macaroon" % (key,))
|
||||
|
||||
def _verify_expiry(self, caveat: str) -> bool:
|
||||
prefix = "time < "
|
||||
if not caveat.startswith(prefix):
|
||||
return False
|
||||
expiry = int(caveat[len(prefix) :])
|
||||
now = self._clock.time_msec()
|
||||
return now < expiry
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True)
|
||||
class OidcSessionData:
|
||||
@@ -1125,8 +1088,8 @@ class OidcSessionData:
|
||||
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
|
||||
client_redirect_url = attr.ib(type=str)
|
||||
|
||||
# The session ID of the ongoing UI Auth (None if this is a login)
|
||||
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
|
||||
# The session ID of the ongoing UI Auth ("" if this is a login)
|
||||
ui_auth_session_id = attr.ib(type=str)
|
||||
|
||||
|
||||
UserAttributeDict = TypedDict(
|
||||
|
||||
@@ -18,6 +18,8 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from synapse import types
|
||||
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
|
||||
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
|
||||
@@ -41,6 +43,19 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
registration_counter = Counter(
|
||||
"synapse_user_registrations_total",
|
||||
"Number of new users registered (since restart)",
|
||||
["guest", "shadow_banned", "auth_provider"],
|
||||
)
|
||||
|
||||
login_counter = Counter(
|
||||
"synapse_user_logins_total",
|
||||
"Number of user logins (since restart)",
|
||||
["guest", "auth_provider"],
|
||||
)
|
||||
|
||||
|
||||
class RegistrationHandler(BaseHandler):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
@@ -171,6 +186,7 @@ class RegistrationHandler(BaseHandler):
|
||||
bind_emails: Iterable[str] = [],
|
||||
by_admin: bool = False,
|
||||
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
|
||||
auth_provider_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Registers a new client on the server.
|
||||
|
||||
@@ -196,8 +212,10 @@ class RegistrationHandler(BaseHandler):
|
||||
admin api, otherwise False.
|
||||
user_agent_ips: Tuples of IP addresses and user-agents used
|
||||
during the registration process.
|
||||
auth_provider_id: The SSO IdP the user used, if any (just used for the
|
||||
prometheus metrics).
|
||||
Returns:
|
||||
The registere user_id.
|
||||
The registered user_id.
|
||||
Raises:
|
||||
SynapseError if there was a problem registering.
|
||||
"""
|
||||
@@ -304,6 +322,12 @@ class RegistrationHandler(BaseHandler):
|
||||
# if user id is taken, just generate another
|
||||
fail_count += 1
|
||||
|
||||
registration_counter.labels(
|
||||
guest=make_guest,
|
||||
shadow_banned=shadow_banned,
|
||||
auth_provider=(auth_provider_id or ""),
|
||||
).inc()
|
||||
|
||||
if not self.hs.config.user_consent_at_registration:
|
||||
if not self.hs.config.auto_join_rooms_for_guests and make_guest:
|
||||
logger.info(
|
||||
@@ -718,6 +742,7 @@ class RegistrationHandler(BaseHandler):
|
||||
initial_display_name: Optional[str],
|
||||
is_guest: bool = False,
|
||||
is_appservice_ghost: bool = False,
|
||||
auth_provider_id: Optional[str] = None,
|
||||
) -> Tuple[str, str]:
|
||||
"""Register a device for a user and generate an access token.
|
||||
|
||||
@@ -728,7 +753,8 @@ class RegistrationHandler(BaseHandler):
|
||||
device_id: The device ID to check, or None to generate a new one.
|
||||
initial_display_name: An optional display name for the device.
|
||||
is_guest: Whether this is a guest account
|
||||
|
||||
auth_provider_id: The SSO IdP the user used, if any (just used for the
|
||||
prometheus metrics).
|
||||
Returns:
|
||||
Tuple of device ID and access token
|
||||
"""
|
||||
@@ -767,6 +793,11 @@ class RegistrationHandler(BaseHandler):
|
||||
is_appservice_ghost=is_appservice_ghost,
|
||||
)
|
||||
|
||||
login_counter.labels(
|
||||
guest=is_guest,
|
||||
auth_provider=(auth_provider_id or ""),
|
||||
).inc()
|
||||
|
||||
return (registered_device_id, access_token)
|
||||
|
||||
async def post_registration_actions(
|
||||
|
||||
@@ -121,7 +121,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
# succession, only process the first attempt and return its result to
|
||||
# subsequent requests
|
||||
self._upgrade_response_cache = ResponseCache(
|
||||
hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
|
||||
hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
|
||||
) # type: ResponseCache[Tuple[str, str]]
|
||||
self._server_notices_mxid = hs.config.server_notices_mxid
|
||||
|
||||
|
||||
@@ -44,10 +44,10 @@ class RoomListHandler(BaseHandler):
|
||||
super().__init__(hs)
|
||||
self.enable_room_list_search = hs.config.enable_room_list_search
|
||||
self.response_cache = ResponseCache(
|
||||
hs, "room_list"
|
||||
hs.get_clock(), "room_list"
|
||||
) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
|
||||
self.remote_response_cache = ResponseCache(
|
||||
hs, "remote_room_list", timeout_ms=30 * 1000
|
||||
hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
|
||||
) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
|
||||
|
||||
async def get_local_public_room_list(
|
||||
|
||||
@@ -456,6 +456,7 @@ class SsoHandler:
|
||||
|
||||
await self._auth_handler.complete_sso_login(
|
||||
user_id,
|
||||
auth_provider_id,
|
||||
request,
|
||||
client_redirect_url,
|
||||
extra_login_attributes,
|
||||
@@ -605,6 +606,7 @@ class SsoHandler:
|
||||
default_display_name=attributes.display_name,
|
||||
bind_emails=attributes.emails,
|
||||
user_agent_ips=[(user_agent, ip_address)],
|
||||
auth_provider_id=auth_provider_id,
|
||||
)
|
||||
|
||||
await self._store.record_user_external_id(
|
||||
@@ -886,6 +888,7 @@ class SsoHandler:
|
||||
|
||||
await self._auth_handler.complete_sso_login(
|
||||
user_id,
|
||||
session.auth_provider_id,
|
||||
request,
|
||||
session.client_redirect_url,
|
||||
session.extra_login_attributes,
|
||||
|
||||
@@ -258,7 +258,7 @@ class SyncHandler:
|
||||
self.event_sources = hs.get_event_sources()
|
||||
self.clock = hs.get_clock()
|
||||
self.response_cache = ResponseCache(
|
||||
hs, "sync"
|
||||
hs.get_clock(), "sync"
|
||||
) # type: ResponseCache[Tuple[Any, ...]]
|
||||
self.state = hs.get_state_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@@ -63,6 +63,7 @@ from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_u
|
||||
from synapse.http.proxyagent import ProxyAgent
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
||||
from synapse.types import ISynapseReactor
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.async_helpers import timeout_deferred
|
||||
|
||||
@@ -199,7 +200,7 @@ class _IPBlacklistingResolver:
|
||||
return r
|
||||
|
||||
|
||||
@implementer(IReactorPluggableNameResolver)
|
||||
@implementer(ISynapseReactor)
|
||||
class BlacklistingReactorWrapper:
|
||||
"""
|
||||
A Reactor wrapper which will prevent DNS resolution to blacklisted IP
|
||||
@@ -324,7 +325,7 @@ class SimpleHttpClient:
|
||||
# filters out blacklisted IP addresses, to prevent DNS rebinding.
|
||||
self.reactor = BlacklistingReactorWrapper(
|
||||
hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
|
||||
)
|
||||
) # type: ISynapseReactor
|
||||
else:
|
||||
self.reactor = hs.get_reactor()
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ from synapse.http.client import BlacklistingAgentWrapper
|
||||
from synapse.http.federation.srv_resolver import Server, SrvResolver
|
||||
from synapse.http.federation.well_known_resolver import WellKnownResolver
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.types import ISynapseReactor
|
||||
from synapse.util import Clock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -68,7 +69,7 @@ class MatrixFederationAgent:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reactor: IReactorCore,
|
||||
reactor: ISynapseReactor,
|
||||
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
|
||||
user_agent: bytes,
|
||||
ip_blacklist: IPSet,
|
||||
|
||||
@@ -59,7 +59,7 @@ from synapse.logging.opentracing import (
|
||||
start_active_span,
|
||||
tags,
|
||||
)
|
||||
from synapse.types import JsonDict
|
||||
from synapse.types import ISynapseReactor, JsonDict
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.async_helpers import timeout_deferred
|
||||
from synapse.util.metrics import Measure
|
||||
@@ -237,14 +237,14 @@ class MatrixFederationHttpClient:
|
||||
# addresses, to prevent DNS rebinding.
|
||||
self.reactor = BlacklistingReactorWrapper(
|
||||
hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
|
||||
)
|
||||
) # type: ISynapseReactor
|
||||
|
||||
user_agent = hs.version_string
|
||||
if hs.config.user_agent_suffix:
|
||||
user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
|
||||
user_agent = user_agent.encode("ascii")
|
||||
|
||||
self.agent = MatrixFederationAgent(
|
||||
federation_agent = MatrixFederationAgent(
|
||||
self.reactor,
|
||||
tls_client_options_factory,
|
||||
user_agent,
|
||||
@@ -254,7 +254,7 @@ class MatrixFederationHttpClient:
|
||||
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
|
||||
# blacklist via IP literals in server names
|
||||
self.agent = BlacklistingAgentWrapper(
|
||||
self.agent,
|
||||
federation_agent,
|
||||
ip_blacklist=hs.config.federation_ip_range_blacklist,
|
||||
)
|
||||
|
||||
|
||||
@@ -203,11 +203,26 @@ class ModuleApi:
|
||||
)
|
||||
|
||||
def generate_short_term_login_token(
|
||||
self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
|
||||
self,
|
||||
user_id: str,
|
||||
duration_in_ms: int = (2 * 60 * 1000),
|
||||
auth_provider_id: str = "",
|
||||
) -> str:
|
||||
"""Generate a login token suitable for m.login.token authentication"""
|
||||
"""Generate a login token suitable for m.login.token authentication
|
||||
|
||||
Args:
|
||||
user_id: gives the ID of the user that the token is for
|
||||
|
||||
duration_in_ms: the time that the token will be valid for
|
||||
|
||||
auth_provider_id: the ID of the SSO IdP that the user used to authenticate
|
||||
to get this token, if any. This is encoded in the token so that
|
||||
/login can report stats on number of successful logins by IdP.
|
||||
"""
|
||||
return self._hs.get_macaroon_generator().generate_short_term_login_token(
|
||||
user_id, duration_in_ms
|
||||
user_id,
|
||||
auth_provider_id,
|
||||
duration_in_ms,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -276,6 +291,7 @@ class ModuleApi:
|
||||
"""
|
||||
self._auth_handler._complete_sso_login(
|
||||
registered_user_id,
|
||||
"<unknown>",
|
||||
request,
|
||||
client_redirect_url,
|
||||
)
|
||||
@@ -286,6 +302,7 @@ class ModuleApi:
|
||||
request: SynapseRequest,
|
||||
client_redirect_url: str,
|
||||
new_user: bool = False,
|
||||
auth_provider_id: str = "<unknown>",
|
||||
):
|
||||
"""Complete a SSO login by redirecting the user to a page to confirm whether they
|
||||
want their access token sent to `client_redirect_url`, or redirect them to that
|
||||
@@ -299,9 +316,15 @@ class ModuleApi:
|
||||
redirect them directly if whitelisted).
|
||||
new_user: set to true to use wording for the consent appropriate to a user
|
||||
who has just registered.
|
||||
auth_provider_id: the ID of the SSO IdP which was used to log in. This
|
||||
is used to track counts of sucessful logins by IdP.
|
||||
"""
|
||||
await self._auth_handler.complete_sso_login(
|
||||
registered_user_id, request, client_redirect_url, new_user=new_user
|
||||
registered_user_id,
|
||||
auth_provider_id,
|
||||
request,
|
||||
client_redirect_url,
|
||||
new_user=new_user,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
||||
@@ -18,7 +18,7 @@ import logging
|
||||
import re
|
||||
import urllib
|
||||
from inspect import signature
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple
|
||||
|
||||
from prometheus_client import Counter, Gauge
|
||||
|
||||
@@ -28,6 +28,9 @@ from synapse.logging.opentracing import inject_active_span_byte_dict, trace
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_pending_outgoing_requests = Gauge(
|
||||
@@ -88,10 +91,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||
CACHE = True
|
||||
RETRY_ON_TIMEOUT = True
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
if self.CACHE:
|
||||
self.response_cache = ResponseCache(
|
||||
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
|
||||
hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000
|
||||
) # type: ResponseCache[str]
|
||||
|
||||
# We reserve `instance_name` as a parameter to sending requests, so we
|
||||
|
||||
@@ -328,6 +328,6 @@ def lazyConnection(
|
||||
factory.continueTrying = reconnect
|
||||
|
||||
reactor = hs.get_reactor()
|
||||
reactor.connectTCP(host, port, factory, 30)
|
||||
reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None)
|
||||
|
||||
return factory.handler
|
||||
|
||||
@@ -12,13 +12,20 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.admin import assert_requester_is_admin
|
||||
from synapse.rest.admin._base import admin_patterns
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
class PurgeRoomServlet(RestServlet):
|
||||
@@ -36,16 +43,12 @@ class PurgeRoomServlet(RestServlet):
|
||||
|
||||
PATTERNS = admin_patterns("/purge_room$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.pagination_handler = hs.get_pagination_handler()
|
||||
|
||||
async def on_POST(self, request):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
@@ -12,17 +12,24 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.admin import assert_requester_is_admin
|
||||
from synapse.rest.admin._base import admin_patterns
|
||||
from synapse.rest.client.transactions import HttpTransactionCache
|
||||
from synapse.types import UserID
|
||||
from synapse.types import JsonDict, UserID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
class SendServerNoticeServlet(RestServlet):
|
||||
@@ -44,17 +51,13 @@ class SendServerNoticeServlet(RestServlet):
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.txns = HttpTransactionCache(hs)
|
||||
self.snm = hs.get_server_notices_manager()
|
||||
|
||||
def register(self, json_resource):
|
||||
def register(self, json_resource: HttpServer):
|
||||
PATTERN = "/send_server_notice"
|
||||
json_resource.register_paths(
|
||||
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
|
||||
@@ -66,7 +69,9 @@ class SendServerNoticeServlet(RestServlet):
|
||||
self.__class__.__name__,
|
||||
)
|
||||
|
||||
async def on_POST(self, request, txn_id=None):
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, txn_id: Optional[str] = None
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
body = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(body, ("user_id", "content"))
|
||||
@@ -90,7 +95,7 @@ class SendServerNoticeServlet(RestServlet):
|
||||
|
||||
return 200, {"event_id": event.event_id}
|
||||
|
||||
def on_PUT(self, request, txn_id):
|
||||
def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]:
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, txn_id
|
||||
)
|
||||
|
||||
@@ -219,6 +219,7 @@ class LoginRestServlet(RestServlet):
|
||||
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
|
||||
create_non_existent_users: bool = False,
|
||||
ratelimit: bool = True,
|
||||
auth_provider_id: Optional[str] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""Called when we've successfully authed the user and now need to
|
||||
actually login them in (e.g. create devices). This gets called on
|
||||
@@ -234,6 +235,8 @@ class LoginRestServlet(RestServlet):
|
||||
create_non_existent_users: Whether to create the user if they don't
|
||||
exist. Defaults to False.
|
||||
ratelimit: Whether to ratelimit the login request.
|
||||
auth_provider_id: The SSO IdP the user used, if any (just used for the
|
||||
prometheus metrics).
|
||||
|
||||
Returns:
|
||||
result: Dictionary of account information after successful login.
|
||||
@@ -256,7 +259,7 @@ class LoginRestServlet(RestServlet):
|
||||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = await self.registration_handler.register_device(
|
||||
user_id, device_id, initial_display_name
|
||||
user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
|
||||
)
|
||||
|
||||
result = {
|
||||
@@ -283,12 +286,13 @@ class LoginRestServlet(RestServlet):
|
||||
"""
|
||||
token = login_submission["token"]
|
||||
auth_handler = self.auth_handler
|
||||
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
token
|
||||
)
|
||||
res = await auth_handler.validate_short_term_login_token(token)
|
||||
|
||||
return await self._complete_login(
|
||||
user_id, login_submission, self.auth_handler._sso_login_callback
|
||||
res.user_id,
|
||||
login_submission,
|
||||
self.auth_handler._sso_login_callback,
|
||||
auth_provider_id=res.auth_provider_id,
|
||||
)
|
||||
|
||||
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
||||
|
||||
+2
-3
@@ -36,7 +36,6 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
import twisted.internet.base
|
||||
import twisted.internet.tcp
|
||||
from twisted.internet import defer
|
||||
from twisted.mail.smtp import sendmail
|
||||
@@ -130,7 +129,7 @@ from synapse.server_notices.worker_server_notices_sender import (
|
||||
from synapse.state import StateHandler, StateResolutionHandler
|
||||
from synapse.storage import Databases, DataStore, Storage
|
||||
from synapse.streams.events import EventSources
|
||||
from synapse.types import DomainSpecificString
|
||||
from synapse.types import DomainSpecificString, ISynapseReactor
|
||||
from synapse.util import Clock
|
||||
from synapse.util.distributor import Distributor
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
@@ -291,7 +290,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
|
||||
getattr(self, "get_" + i + "_handler")()
|
||||
|
||||
def get_reactor(self) -> twisted.internet.base.ReactorBase:
|
||||
def get_reactor(self) -> ISynapseReactor:
|
||||
"""
|
||||
Fetch the Twisted reactor in use by this HomeServer.
|
||||
"""
|
||||
|
||||
@@ -36,6 +36,14 @@ import attr
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from six.moves import filter
|
||||
from unpaddedbase64 import decode_base64
|
||||
from zope.interface import Interface
|
||||
|
||||
from twisted.internet.interfaces import (
|
||||
IReactorCore,
|
||||
IReactorPluggableNameResolver,
|
||||
IReactorTCP,
|
||||
IReactorTime,
|
||||
)
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.util.stringutils import parse_and_validate_server_name
|
||||
@@ -68,6 +76,14 @@ MutableStateMap = MutableMapping[StateKey, T]
|
||||
JsonDict = Dict[str, Any]
|
||||
|
||||
|
||||
# Note that this seems to require inheriting *directly* from Interface in order
|
||||
# for mypy-zope to realize it is an interface.
|
||||
class ISynapseReactor(
|
||||
IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface
|
||||
):
|
||||
"""The interfaces necessary for Synapse to function."""
|
||||
|
||||
|
||||
class Requester(
|
||||
namedtuple(
|
||||
"Requester",
|
||||
|
||||
@@ -76,11 +76,16 @@ class ObservableDeferred:
|
||||
def callback(r):
|
||||
object.__setattr__(self, "_result", (True, r))
|
||||
while self._observers:
|
||||
observer = self._observers.pop()
|
||||
try:
|
||||
# TODO: Handle errors here.
|
||||
self._observers.pop().callback(r)
|
||||
except Exception:
|
||||
pass
|
||||
observer.callback(r)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"%r threw an exception on .callback(%r), ignoring...",
|
||||
observer,
|
||||
r,
|
||||
exc_info=e,
|
||||
)
|
||||
return r
|
||||
|
||||
def errback(f):
|
||||
@@ -90,11 +95,16 @@ class ObservableDeferred:
|
||||
# traces when we `await` on one of the observer deferreds.
|
||||
f.value.__failure__ = f
|
||||
|
||||
observer = self._observers.pop()
|
||||
try:
|
||||
# TODO: Handle errors here.
|
||||
self._observers.pop().errback(f)
|
||||
except Exception:
|
||||
pass
|
||||
observer.errback(f)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"%r threw an exception on .errback(%r), ignoring...",
|
||||
observer,
|
||||
f,
|
||||
exc_info=e,
|
||||
)
|
||||
|
||||
if consumeErrors:
|
||||
return None
|
||||
|
||||
@@ -13,17 +13,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
|
||||
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.util import Clock
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
from synapse.util.caches import register_cache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -37,11 +35,11 @@ class ResponseCache(Generic[T]):
|
||||
used rather than trying to compute a new response.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
|
||||
def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
|
||||
# Requests that haven't finished yet.
|
||||
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
self.clock = clock
|
||||
self.timeout_sec = timeout_ms / 1000.0
|
||||
|
||||
self._name = name
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 Quentin Gliech
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utilities for manipulating macaroons"""
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import pymacaroons
|
||||
from pymacaroons.exceptions import MacaroonVerificationFailedException
|
||||
|
||||
|
||||
def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
|
||||
"""Extracts a caveat value from a macaroon token.
|
||||
|
||||
Checks that there is exactly one caveat of the form "key = <val>" in the macaroon,
|
||||
and returns the extracted value.
|
||||
|
||||
Args:
|
||||
macaroon: the token
|
||||
key: the key of the caveat to extract
|
||||
|
||||
Returns:
|
||||
The extracted value
|
||||
|
||||
Raises:
|
||||
MacaroonVerificationFailedException: if there are conflicting values for the
|
||||
caveat in the macaroon, or if the caveat was not found in the macaroon.
|
||||
"""
|
||||
prefix = key + " = "
|
||||
result = None # type: Optional[str]
|
||||
for caveat in macaroon.caveats:
|
||||
if not caveat.caveat_id.startswith(prefix):
|
||||
continue
|
||||
|
||||
val = caveat.caveat_id[len(prefix) :]
|
||||
|
||||
if result is None:
|
||||
# first time we found this caveat: record the value
|
||||
result = val
|
||||
elif val != result:
|
||||
# on subsequent occurrences, raise if the value is different.
|
||||
raise MacaroonVerificationFailedException(
|
||||
"Conflicting values for caveat " + key
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# If the caveat is not there, we raise a MacaroonVerificationFailedException.
|
||||
# Note that it is insecure to generate a macaroon without all the caveats you
|
||||
# might need (because there is nothing stopping people from adding extra caveats),
|
||||
# so if the caveat isn't there, something odd must be going on.
|
||||
raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,))
|
||||
|
||||
|
||||
def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
|
||||
"""Make a macaroon verifier which accepts 'time' caveats
|
||||
|
||||
Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to
|
||||
the given macaroon verifier.
|
||||
|
||||
Args:
|
||||
v: the macaroon verifier
|
||||
get_time_ms: a callable which will return the timestamp after which the caveat
|
||||
should be considered expired. Normally the current time.
|
||||
"""
|
||||
|
||||
def verify_expiry_caveat(caveat: str):
|
||||
time_msec = get_time_ms()
|
||||
prefix = "time < "
|
||||
if not caveat.startswith(prefix):
|
||||
return False
|
||||
expiry = int(caveat[len(prefix) :])
|
||||
return time_msec < expiry
|
||||
|
||||
v.satisfy_general(verify_expiry_caveat)
|
||||
+29
-20
@@ -68,38 +68,45 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||
|
||||
def test_short_term_login_token_gives_user_id(self):
|
||||
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
|
||||
user_id = self.get_success(
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
"a_user", "", 5000
|
||||
)
|
||||
self.assertEqual("a_user", user_id)
|
||||
res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
|
||||
self.assertEqual("a_user", res.user_id)
|
||||
self.assertEqual("", res.auth_provider_id)
|
||||
|
||||
# when we advance the clock, the token should be rejected
|
||||
self.reactor.advance(6)
|
||||
self.get_failure(
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
|
||||
self.auth_handler.validate_short_term_login_token(token),
|
||||
AuthError,
|
||||
)
|
||||
|
||||
def test_short_term_login_token_gives_auth_provider(self):
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
"a_user", auth_provider_id="my_idp"
|
||||
)
|
||||
res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
|
||||
self.assertEqual("a_user", res.user_id)
|
||||
self.assertEqual("my_idp", res.auth_provider_id)
|
||||
|
||||
def test_short_term_login_token_cannot_replace_user_id(self):
|
||||
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
"a_user", "", 5000
|
||||
)
|
||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||
|
||||
user_id = self.get_success(
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
macaroon.serialize()
|
||||
)
|
||||
res = self.get_success(
|
||||
self.auth_handler.validate_short_term_login_token(macaroon.serialize())
|
||||
)
|
||||
self.assertEqual("a_user", user_id)
|
||||
self.assertEqual("a_user", res.user_id)
|
||||
|
||||
# add another "user_id" caveat, which might allow us to override the
|
||||
# user_id.
|
||||
macaroon.add_first_party_caveat("user_id = b_user")
|
||||
|
||||
self.get_failure(
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
macaroon.serialize()
|
||||
),
|
||||
self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
|
||||
AuthError,
|
||||
)
|
||||
|
||||
@@ -113,7 +120,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.get_success(
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
self.auth_handler.validate_short_term_login_token(
|
||||
self._get_macaroon().serialize()
|
||||
)
|
||||
)
|
||||
@@ -135,7 +142,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
return_value=make_awaitable(self.large_number_of_users)
|
||||
)
|
||||
self.get_failure(
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
self.auth_handler.validate_short_term_login_token(
|
||||
self._get_macaroon().serialize()
|
||||
),
|
||||
ResourceLimitError,
|
||||
@@ -159,7 +166,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
ResourceLimitError,
|
||||
)
|
||||
self.get_failure(
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
self.auth_handler.validate_short_term_login_token(
|
||||
self._get_macaroon().serialize()
|
||||
),
|
||||
ResourceLimitError,
|
||||
@@ -175,7 +182,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
self.get_success(
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
self.auth_handler.validate_short_term_login_token(
|
||||
self._get_macaroon().serialize()
|
||||
)
|
||||
)
|
||||
@@ -197,11 +204,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
return_value=make_awaitable(self.small_number_of_users)
|
||||
)
|
||||
self.get_success(
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
self.auth_handler.validate_short_term_login_token(
|
||||
self._get_macaroon().serialize()
|
||||
)
|
||||
)
|
||||
|
||||
def _get_macaroon(self):
|
||||
token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
"user_a", "", 5000
|
||||
)
|
||||
return pymacaroons.Macaroon.deserialize(token)
|
||||
|
||||
@@ -66,7 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
# check that the auth handler got called as expected
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user:test", request, "redirect_uri", None, new_user=True
|
||||
"@test_user:test", "cas", request, "redirect_uri", None, new_user=True
|
||||
)
|
||||
|
||||
def test_map_cas_user_to_existing_user(self):
|
||||
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
# check that the auth handler got called as expected
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user:test", request, "redirect_uri", None, new_user=False
|
||||
"@test_user:test", "cas", request, "redirect_uri", None, new_user=False
|
||||
)
|
||||
|
||||
# Subsequent calls should map to the same mxid.
|
||||
@@ -98,7 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
|
||||
)
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user:test", request, "redirect_uri", None, new_user=False
|
||||
"@test_user:test", "cas", request, "redirect_uri", None, new_user=False
|
||||
)
|
||||
|
||||
def test_map_cas_user_to_invalid_localpart(self):
|
||||
@@ -116,7 +116,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
# check that the auth handler got called as expected
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
|
||||
"@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
|
||||
)
|
||||
|
||||
@override_config(
|
||||
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
# check that the auth handler got called as expected
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user:test", request, "redirect_uri", None, new_user=True
|
||||
"@test_user:test", "cas", request, "redirect_uri", None, new_user=True
|
||||
)
|
||||
|
||||
|
||||
|
||||
+16
-20
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from mock import ANY, Mock, patch
|
||||
@@ -23,6 +22,7 @@ import pymacaroons
|
||||
from synapse.handlers.sso import MappingException
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import UserID
|
||||
from synapse.util.macaroons import get_value_from_macaroon
|
||||
|
||||
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
|
||||
from tests.unittest import HomeserverTestCase, override_config
|
||||
@@ -360,15 +360,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.assertEqual(name, b"oidc_session")
|
||||
|
||||
macaroon = pymacaroons.Macaroon.deserialize(cookie)
|
||||
state = self.handler._token_generator._get_value_from_macaroon(
|
||||
macaroon, "state"
|
||||
)
|
||||
nonce = self.handler._token_generator._get_value_from_macaroon(
|
||||
macaroon, "nonce"
|
||||
)
|
||||
redirect = self.handler._token_generator._get_value_from_macaroon(
|
||||
macaroon, "client_redirect_url"
|
||||
)
|
||||
state = get_value_from_macaroon(macaroon, "state")
|
||||
nonce = get_value_from_macaroon(macaroon, "nonce")
|
||||
redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
|
||||
|
||||
self.assertEqual(params["state"], [state])
|
||||
self.assertEqual(params["nonce"], [nonce])
|
||||
@@ -434,7 +428,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
expected_user_id, request, client_redirect_url, None, new_user=True
|
||||
expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
|
||||
)
|
||||
self.provider._exchange_code.assert_called_once_with(code)
|
||||
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
||||
@@ -465,7 +459,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.get_success(self.handler.handle_oidc_callback(request))
|
||||
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
expected_user_id, request, client_redirect_url, None, new_user=False
|
||||
expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
|
||||
)
|
||||
self.provider._exchange_code.assert_called_once_with(code)
|
||||
self.provider._parse_id_token.assert_not_called()
|
||||
@@ -651,6 +645,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@foo:test",
|
||||
"oidc",
|
||||
request,
|
||||
client_redirect_url,
|
||||
{"phone": "1234567"},
|
||||
@@ -668,7 +663,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user:test", ANY, ANY, None, new_user=True
|
||||
"@test_user:test", "oidc", ANY, ANY, None, new_user=True
|
||||
)
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
|
||||
@@ -679,7 +674,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user_2:test", ANY, ANY, None, new_user=True
|
||||
"@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
|
||||
)
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
|
||||
@@ -716,14 +711,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
user.to_string(), ANY, ANY, None, new_user=False
|
||||
user.to_string(), "oidc", ANY, ANY, None, new_user=False
|
||||
)
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
|
||||
# Subsequent calls should map to the same mxid.
|
||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
user.to_string(), ANY, ANY, None, new_user=False
|
||||
user.to_string(), "oidc", ANY, ANY, None, new_user=False
|
||||
)
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
|
||||
@@ -738,7 +733,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
}
|
||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
user.to_string(), ANY, ANY, None, new_user=False
|
||||
user.to_string(), "oidc", ANY, ANY, None, new_user=False
|
||||
)
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
|
||||
@@ -774,7 +769,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@TEST_USER_2:test", ANY, ANY, None, new_user=False
|
||||
"@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
|
||||
)
|
||||
|
||||
def test_map_userinfo_to_invalid_localpart(self):
|
||||
@@ -810,7 +805,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
# test_user is already taken, so test_user1 gets registered instead.
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user1:test", ANY, ANY, None, new_user=True
|
||||
"@test_user1:test", "oidc", ANY, ANY, None, new_user=True
|
||||
)
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
|
||||
@@ -866,7 +861,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
state: str,
|
||||
nonce: str,
|
||||
client_redirect_url: str,
|
||||
ui_auth_session_id: Optional[str] = None,
|
||||
ui_auth_session_id: str = "",
|
||||
) -> str:
|
||||
from synapse.handlers.oidc_handler import OidcSessionData
|
||||
|
||||
@@ -909,6 +904,7 @@ async def _make_callback_with_userinfo(
|
||||
idp_id="oidc",
|
||||
nonce="nonce",
|
||||
client_redirect_url=client_redirect_url,
|
||||
ui_auth_session_id="",
|
||||
),
|
||||
)
|
||||
request = _build_callback_request("code", state, session)
|
||||
|
||||
@@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
# check that the auth handler got called as expected
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user:test", request, "redirect_uri", None, new_user=True
|
||||
"@test_user:test", "saml", request, "redirect_uri", None, new_user=True
|
||||
)
|
||||
|
||||
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
|
||||
@@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
# check that the auth handler got called as expected
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user:test", request, "", None, new_user=False
|
||||
"@test_user:test", "saml", request, "", None, new_user=False
|
||||
)
|
||||
|
||||
# Subsequent calls should map to the same mxid.
|
||||
@@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
self.handler._handle_authn_response(request, saml_response, "")
|
||||
)
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user:test", request, "", None, new_user=False
|
||||
"@test_user:test", "saml", request, "", None, new_user=False
|
||||
)
|
||||
|
||||
def test_map_saml_response_to_invalid_localpart(self):
|
||||
@@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
# test_user is already taken, so test_user1 gets registered instead.
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user1:test", request, "", None, new_user=True
|
||||
"@test_user1:test", "saml", request, "", None, new_user=True
|
||||
)
|
||||
auth_handler.complete_sso_login.reset_mock()
|
||||
|
||||
@@ -310,7 +310,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||
|
||||
# check that the auth handler got called as expected
|
||||
auth_handler.complete_sso_login.assert_called_once_with(
|
||||
"@test_user:test", request, "redirect_uri", None, new_user=True
|
||||
"@test_user:test", "saml", request, "redirect_uri", None, new_user=True
|
||||
)
|
||||
|
||||
|
||||
|
||||
+45
-26
@@ -13,9 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import NotFoundError
|
||||
from synapse.api.errors import NotFoundError, SynapseError
|
||||
from synapse.rest.client.v1 import room
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
@@ -33,9 +31,12 @@ class PurgeTests(HomeserverTestCase):
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
def test_purge(self):
|
||||
self.store = hs.get_datastore()
|
||||
self.storage = self.hs.get_storage()
|
||||
|
||||
def test_purge_history(self):
|
||||
"""
|
||||
Purging a room will delete everything before the topological point.
|
||||
Purging a room history will delete everything before the topological point.
|
||||
"""
|
||||
# Send four messages to the room
|
||||
first = self.helper.send(self.room_id, body="test1")
|
||||
@@ -43,30 +44,27 @@ class PurgeTests(HomeserverTestCase):
|
||||
third = self.helper.send(self.room_id, body="test3")
|
||||
last = self.helper.send(self.room_id, body="test4")
|
||||
|
||||
store = self.hs.get_datastore()
|
||||
storage = self.hs.get_storage()
|
||||
|
||||
# Get the topological token
|
||||
token = self.get_success(
|
||||
store.get_topological_token_for_event(last["event_id"])
|
||||
self.store.get_topological_token_for_event(last["event_id"])
|
||||
)
|
||||
token_str = self.get_success(token.to_string(self.hs.get_datastore()))
|
||||
|
||||
# Purge everything before this topological token
|
||||
self.get_success(
|
||||
storage.purge_events.purge_history(self.room_id, token_str, True)
|
||||
self.storage.purge_events.purge_history(self.room_id, token_str, True)
|
||||
)
|
||||
|
||||
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
|
||||
# and last is not.
|
||||
self.get_failure(store.get_event(first["event_id"]), NotFoundError)
|
||||
self.get_failure(store.get_event(second["event_id"]), NotFoundError)
|
||||
self.get_failure(store.get_event(third["event_id"]), NotFoundError)
|
||||
self.get_success(store.get_event(last["event_id"]))
|
||||
self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
|
||||
self.get_failure(self.store.get_event(second["event_id"]), NotFoundError)
|
||||
self.get_failure(self.store.get_event(third["event_id"]), NotFoundError)
|
||||
self.get_success(self.store.get_event(last["event_id"]))
|
||||
|
||||
def test_purge_wont_delete_extrems(self):
|
||||
def test_purge_history_wont_delete_extrems(self):
|
||||
"""
|
||||
Purging a room will delete everything before the topological point.
|
||||
Purging a room history will delete everything before the topological point.
|
||||
"""
|
||||
# Send four messages to the room
|
||||
first = self.helper.send(self.room_id, body="test1")
|
||||
@@ -74,22 +72,43 @@ class PurgeTests(HomeserverTestCase):
|
||||
third = self.helper.send(self.room_id, body="test3")
|
||||
last = self.helper.send(self.room_id, body="test4")
|
||||
|
||||
storage = self.hs.get_datastore()
|
||||
|
||||
# Set the topological token higher than it should be
|
||||
token = self.get_success(
|
||||
storage.get_topological_token_for_event(last["event_id"])
|
||||
self.store.get_topological_token_for_event(last["event_id"])
|
||||
)
|
||||
event = "t{}-{}".format(token.topological + 1, token.stream + 1)
|
||||
|
||||
# Purge everything before this topological token
|
||||
purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
|
||||
self.pump()
|
||||
f = self.failureResultOf(purge)
|
||||
f = self.get_failure(
|
||||
self.storage.purge_events.purge_history(self.room_id, event, True),
|
||||
SynapseError,
|
||||
)
|
||||
self.assertIn("greater than forward", f.value.args[0])
|
||||
|
||||
# Try and get the events
|
||||
self.get_success(storage.get_event(first["event_id"]))
|
||||
self.get_success(storage.get_event(second["event_id"]))
|
||||
self.get_success(storage.get_event(third["event_id"]))
|
||||
self.get_success(storage.get_event(last["event_id"]))
|
||||
self.get_success(self.store.get_event(first["event_id"]))
|
||||
self.get_success(self.store.get_event(second["event_id"]))
|
||||
self.get_success(self.store.get_event(third["event_id"]))
|
||||
self.get_success(self.store.get_event(last["event_id"]))
|
||||
|
||||
def test_purge_room(self):
|
||||
"""
|
||||
Purging a room will delete everything about it.
|
||||
"""
|
||||
# Send four messages to the room
|
||||
first = self.helper.send(self.room_id, body="test1")
|
||||
|
||||
# Get the current room state.
|
||||
state_handler = self.hs.get_state_handler()
|
||||
create_event = self.get_success(
|
||||
state_handler.get_current_state(self.room_id, "m.room.create", "")
|
||||
)
|
||||
self.assertIsNotNone(create_event)
|
||||
|
||||
# Purge everything before this topological token
|
||||
self.get_success(self.storage.purge_events.purge_room(self.room_id))
|
||||
|
||||
# The events aren't found.
|
||||
self.store._invalidate_get_event_cache(create_event.event_id)
|
||||
self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
|
||||
self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
|
||||
from tests.server import get_clock
|
||||
from tests.unittest import TestCase
|
||||
|
||||
|
||||
class DeferredCacheTestCase(TestCase):
|
||||
"""
|
||||
A TestCase class for ResponseCache.
|
||||
|
||||
The test-case function naming has some logic to it in it's parts, here's some notes about it:
|
||||
wait: Denotes tests that have an element of "waiting" before its wrapped result becomes available
|
||||
(Generally these just use .delayed_return instead of .instant_return in it's wrapped call.)
|
||||
expire: Denotes tests that test expiry after assured existence.
|
||||
(These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock)
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.reactor, self.clock = get_clock()
|
||||
|
||||
def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
|
||||
return ResponseCache(self.clock, name, timeout_ms=ms)
|
||||
|
||||
@staticmethod
|
||||
async def instant_return(o: str) -> str:
|
||||
return o
|
||||
|
||||
async def delayed_return(self, o: str) -> str:
|
||||
await self.clock.sleep(1)
|
||||
return o
|
||||
|
||||
def test_cache_hit(self):
|
||||
cache = self.with_cache("keeping_cache", ms=9001)
|
||||
|
||||
expected_result = "howdy"
|
||||
|
||||
wrap_d = cache.wrap(0, self.instant_return, expected_result)
|
||||
|
||||
self.assertEqual(
|
||||
expected_result,
|
||||
self.successResultOf(wrap_d),
|
||||
"initial wrap result should be the same",
|
||||
)
|
||||
self.assertEqual(
|
||||
expected_result,
|
||||
self.successResultOf(cache.get(0)),
|
||||
"cache should have the result",
|
||||
)
|
||||
|
||||
def test_cache_miss(self):
|
||||
cache = self.with_cache("trashing_cache", ms=0)
|
||||
|
||||
expected_result = "howdy"
|
||||
|
||||
wrap_d = cache.wrap(0, self.instant_return, expected_result)
|
||||
|
||||
self.assertEqual(
|
||||
expected_result,
|
||||
self.successResultOf(wrap_d),
|
||||
"initial wrap result should be the same",
|
||||
)
|
||||
self.assertIsNone(cache.get(0), "cache should not have the result now")
|
||||
|
||||
def test_cache_expire(self):
|
||||
cache = self.with_cache("short_cache", ms=1000)
|
||||
|
||||
expected_result = "howdy"
|
||||
|
||||
wrap_d = cache.wrap(0, self.instant_return, expected_result)
|
||||
|
||||
self.assertEqual(expected_result, self.successResultOf(wrap_d))
|
||||
self.assertEqual(
|
||||
expected_result,
|
||||
self.successResultOf(cache.get(0)),
|
||||
"cache should still have the result",
|
||||
)
|
||||
|
||||
# cache eviction timer is handled
|
||||
self.reactor.pump((2,))
|
||||
|
||||
self.assertIsNone(cache.get(0), "cache should not have the result now")
|
||||
|
||||
def test_cache_wait_hit(self):
|
||||
cache = self.with_cache("neutral_cache")
|
||||
|
||||
expected_result = "howdy"
|
||||
|
||||
wrap_d = cache.wrap(0, self.delayed_return, expected_result)
|
||||
self.assertNoResult(wrap_d)
|
||||
|
||||
# function wakes up, returns result
|
||||
self.reactor.pump((2,))
|
||||
|
||||
self.assertEqual(expected_result, self.successResultOf(wrap_d))
|
||||
|
||||
def test_cache_wait_expire(self):
|
||||
cache = self.with_cache("medium_cache", ms=3000)
|
||||
|
||||
expected_result = "howdy"
|
||||
|
||||
wrap_d = cache.wrap(0, self.delayed_return, expected_result)
|
||||
self.assertNoResult(wrap_d)
|
||||
|
||||
# stop at 1 second to callback cache eviction callLater at that time, then another to set time at 2
|
||||
self.reactor.pump((1, 1))
|
||||
|
||||
self.assertEqual(expected_result, self.successResultOf(wrap_d))
|
||||
self.assertEqual(
|
||||
expected_result,
|
||||
self.successResultOf(cache.get(0)),
|
||||
"cache should still have the result",
|
||||
)
|
||||
|
||||
# (1 + 1 + 2) > 3.0, cache eviction timer is handled
|
||||
self.reactor.pump((2,))
|
||||
|
||||
self.assertIsNone(cache.get(0), "cache should not have the result now")
|
||||
Reference in New Issue
Block a user